Coverage for faststream / kafka / subscriber / usecase.py: 92%
153 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import logging
2from abc import abstractmethod
3from collections.abc import AsyncIterator, Callable, Sequence
4from itertools import chain
5from typing import TYPE_CHECKING, Any, Optional, cast
7import anyio
8from aiokafka import ConsumerRecord, TopicPartition
9from aiokafka.errors import ConsumerStoppedError, KafkaError, UnsupportedCodecError
10from typing_extensions import override
12from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin, TasksMixin
13from faststream._internal.endpoint.subscriber.usecase import SubscriberUsecase
14from faststream._internal.endpoint.utils import process_msg
15from faststream._internal.types import MsgType
16from faststream._internal.utils.path import compile_path
17from faststream.kafka.helpers import make_logging_listener
18from faststream.kafka.message import KafkaAckableMessage, KafkaMessage, KafkaRawMessage
19from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser
20from faststream.kafka.publisher.fake import KafkaFakePublisher
22if TYPE_CHECKING:
23 from aiokafka import AIOKafkaConsumer
25 from faststream._internal.endpoint.publisher import PublisherProto
26 from faststream._internal.endpoint.subscriber import SubscriberSpecification
27 from faststream._internal.endpoint.subscriber.call_item import CallsCollection
28 from faststream.kafka.configs import KafkaBrokerConfig
29 from faststream.message import StreamMessage
31 from .config import KafkaSubscriberConfig
class LogicSubscriber(TasksMixin, SubscriberUsecase[MsgType]):
    """A class to handle logic for consuming messages from Kafka.

    Stores the subscription settings (topics, a topic pattern, or explicit
    partitions), builds an ``AIOKafkaConsumer`` on ``start()`` and, when
    handlers are registered, runs a background task that fetches messages
    via ``get_msg`` and dispatches them through ``consume_one``.
    Subclasses decide how a message is fetched by implementing ``get_msg``.
    """

    consumer: Optional["AIOKafkaConsumer"]

    batch: bool
    parser: AioKafkaParser

    _outer_config: "KafkaBrokerConfig"

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[MsgType]",
    ) -> None:
        """Copy the subscription settings out of ``config``.

        The consumer itself is not created here; it is built in ``start()``.
        """
        super().__init__(config, specification, calls)

        self._topics = config.topics
        self._partitions = config.partitions
        self.group_id = config.group_id

        self._pattern = config.pattern
        self._listener = config.listener
        self._connection_args = config.connection_args

        self.consumer = None

    @property
    def pattern(self) -> str | None:
        """Topic pattern with the broker prefix applied, or the falsy raw value if unset."""
        if not self._pattern:
            return self._pattern
        return f"{self._outer_config.prefix}{self._pattern}"

    @property
    def topics(self) -> list[str]:
        """Subscribed topic names with the broker prefix applied."""
        return [f"{self._outer_config.prefix}{t}" for t in self._topics]

    @property
    def partitions(self) -> list[TopicPartition]:
        """Explicitly assigned partitions, with the broker prefix applied to each topic."""
        return [
            TopicPartition(
                topic=f"{self._outer_config.prefix}{p.topic}",
                partition=p.partition,
            )
            for p in self._partitions
        ]

    @property
    def builder(self) -> Callable[..., "AIOKafkaConsumer"]:
        """Consumer factory taken from the broker configuration."""
        return self._outer_config.builder

    @property
    def client_id(self) -> str | None:
        """Client id taken from the broker configuration."""
        return self._outer_config.client_id

    async def start(self) -> None:
        """Start the consumer.

        Builds the consumer and binds the parser to it. It then either
        subscribes to topics/pattern (with a logging rebalance listener) or
        assigns explicit partitions, and starts the consumer. If handlers are
        registered, the consume loop is scheduled as a background task.
        """
        await super().start()

        self.consumer = consumer = self.builder(
            group_id=self.group_id,
            client_id=self.client_id,
            **self._connection_args,
        )

        self.parser._setup(consumer)

        if self.topics or self.pattern:
            consumer.subscribe(
                topics=self.topics,
                pattern=self.pattern,
                listener=make_logging_listener(
                    consumer=consumer,
                    logger=self._outer_config.logger.logger.logger,
                    log_extra=self.get_log_context(None),
                    listener=self._listener,
                ),
            )

        elif self.partitions:
            consumer.assign(partitions=self.partitions)

        await consumer.start()

        self._post_start()

        if self.calls:
            self.add_task(self._run_consume_loop, (self.consumer,))

    async def stop(self) -> None:
        """Stop the base subscriber machinery, then stop and drop the consumer."""
        await super().stop()

        if self.consumer is not None:
            await self.consumer.stop()
            self.consumer = None

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "KafkaMessage | None":
        """Fetch and parse a single message, waiting up to ``timeout`` seconds.

        Returns ``None`` if nothing arrives in time. Only usable on a started
        subscriber without registered handlers.
        """
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        assert self.consumer, "You should start subscriber at first."

        raw_messages = await self.consumer.getmany(
            timeout_ms=timeout * 1000,
            max_records=1,
        )

        if not raw_messages:
            return None

        # ``max_records=1`` caps the fetch at one record, so a single
        # partition holding a single record is expected here.
        ((raw_message,),) = raw_messages.values()

        context = self._outer_config.fd_config.context

        async_parser, async_decoder = self._get_parser_and_decoder()

        msg: KafkaMessage | None = await process_msg(  # type: ignore[assignment]
            msg=raw_message,
            middlewares=(
                m(raw_message, context=context) for m in self._broker_middlewares
            ),
            parser=async_parser,
            decoder=async_decoder,
        )
        return msg

    @override
    async def __aiter__(self) -> AsyncIterator["KafkaMessage"]:  # type: ignore[override]
        """Yield parsed messages from the consumer indefinitely.

        Only usable on a started subscriber without registered handlers.
        """
        assert self.consumer, "You should start subscriber at first."
        # Fixed: this message previously referred to `get_one`.
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        async for raw_message in self.consumer:
            msg: KafkaMessage = await process_msg(  # type: ignore[assignment]
                msg=raw_message,
                middlewares=(
                    m(raw_message, context=context) for m in self._broker_middlewares
                ),
                parser=async_parser,
                decoder=async_decoder,
            )
            yield msg

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Build a publisher that sends replies to ``message.reply_to``."""
        return (
            KafkaFakePublisher(
                self._outer_config.producer,
                topic=message.reply_to,
            ),
        )

    @abstractmethod
    async def get_msg(self, consumer: "AIOKafkaConsumer") -> MsgType:
        """Fetch the next unit of work (record or batch) from ``consumer``."""
        raise NotImplementedError

    async def _run_consume_loop(self, consumer: "AIOKafkaConsumer") -> None:
        """Fetch from ``consumer`` and dispatch messages while the subscriber runs.

        Codec and Kafka errors are logged and followed by a back-off sleep
        (15 s and 5 s respectively). ``ConsumerStoppedError`` ends the loop.
        """
        assert consumer, "You should start subscriber at first."

        while self.running:
            try:
                msg = await self.get_msg(consumer)

            except UnsupportedCodecError as e:  # noqa: PERF203
                # Listed before ``KafkaError`` so this specific handler wins
                # if it is a subclass of it.
                self._log(
                    logging.ERROR,
                    "There is no suitable compression library available. Please refer to the Kafka "
                    "documentation for more information - "
                    "https://aiokafka.readthedocs.io/en/stable/#installation",
                    exc_info=e,
                )
                await anyio.sleep(15)

            except KafkaError as e:
                self._log(logging.ERROR, "Kafka error occurred", exc_info=e)
                await anyio.sleep(5)

            except ConsumerStoppedError:
                # The consumer has been stopped: end the loop.
                return

            else:
                # Falsy results (e.g. an empty batch tuple) are skipped.
                if msg:
                    await self.consume_one(msg)

    async def consume_one(self, msg: MsgType) -> None:
        """Process one fetched message; subclasses may override this to hand it off."""
        await self.consume(msg)

    @property
    def topic_names(self) -> list[str]:
        """Human-readable subscription targets used for logging.

        Returns the pattern if set, else the topics, else ``"<topic>-<partition>"``
        for each assigned partition.
        """
        if self.pattern:
            topics = [self.pattern]

        elif self.topics:
            topics = self.topics

        else:
            topics = [f"{p.topic}-{p.partition}" for p in self.partitions]

        return topics

    @staticmethod
    def build_log_context(
        message: Optional["StreamMessage[Any]"],
        topic: str,
        group_id: str | None = None,
    ) -> dict[str, str]:
        """Build the log-extra mapping (topic, group id, message id) for a message."""
        return {
            "topic": topic,
            "group_id": group_id or "",
            "message_id": getattr(message, "message_id", ""),
        }
class DefaultSubscriber(LogicSubscriber["ConsumerRecord"]):
    """Subscriber that pulls Kafka records one at a time with ``getone()``."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[ConsumerRecord]",
    ) -> None:
        """Wire a single-record parser/decoder into ``config`` and initialise the base.

        When ``config.pattern`` is set, it is rewritten by ``compile_path``
        (escaped ``*`` wildcards become ``.*``). The rewritten pattern replaces
        ``config.pattern``, and the compiled regex is handed to the parser.
        """
        regex = None
        if config.pattern:
            regex, config.pattern = compile_path(
                config.pattern,
                replace_symbol=".*",
                patch_regex=lambda raw: raw.replace(r"\*", ".*"),
            )

        # ``ack_first`` selects the plain message class; otherwise the
        # ackable variant is used.
        message_cls = KafkaAckableMessage if not config.ack_first else KafkaMessage
        self.parser = AioKafkaParser(msg_class=message_cls, regex=regex)

        config.parser = self.parser.parse_message
        config.decoder = self.parser.decode_message

        super().__init__(config, specification, calls)

    async def get_msg(self, consumer: "AIOKafkaConsumer") -> "ConsumerRecord":
        """Wait for and return the next record from ``consumer``."""
        assert consumer, "You should setup subscriber at first."
        record = await consumer.getone()
        return record

    def get_log_context(
        self,
        message: Optional["StreamMessage[ConsumerRecord]"],
    ) -> dict[str, str]:
        """Log context using the record's topic, or all subscription targets when no message."""
        topic = (
            ",".join(self.topic_names)
            if message is None
            else message.raw_message.topic
        )
        return self.build_log_context(
            message=message,
            topic=topic,
            group_id=self.group_id,
        )
class BatchSubscriber(LogicSubscriber[tuple["ConsumerRecord", ...]]):
    """Subscriber that pulls Kafka records in batches with ``getmany()``."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[tuple[ConsumerRecord, ...]]",
        batch_timeout_ms: int,
        max_records: int | None,
    ) -> None:
        """Wire a batch parser/decoder into ``config`` and store the fetch limits.

        ``batch_timeout_ms`` and ``max_records`` are forwarded to every
        ``getmany()`` call. A configured topic pattern is rewritten by
        ``compile_path`` (escaped ``*`` wildcards become ``.*``), and the regex
        is handed to the parser.
        """
        regex = None
        if config.pattern:
            regex, config.pattern = compile_path(
                config.pattern,
                replace_symbol=".*",
                patch_regex=lambda raw: raw.replace(r"\*", ".*"),
            )

        message_cls = KafkaAckableMessage if not config.ack_first else KafkaMessage
        self.parser = AioKafkaBatchParser(msg_class=message_cls, regex=regex)

        config.decoder = self.parser.decode_batch
        config.parser = self.parser.parse_batch

        super().__init__(config, specification, calls)

        self.batch_timeout_ms = batch_timeout_ms
        self.max_records = max_records

    async def get_msg(
        self,
        consumer: "AIOKafkaConsumer",
    ) -> tuple["ConsumerRecord", ...]:
        """Fetch one batch and flatten the per-partition lists into a single tuple.

        Returns an empty tuple, after sleeping ``batch_timeout_ms``, when the
        fetch yields nothing.
        """
        assert consumer, "You should setup subscriber at first."

        batches = await consumer.getmany(
            timeout_ms=self.batch_timeout_ms,
            max_records=self.max_records,
        )

        if not batches:  # pragma: no cover
            # Nothing fetched: pause for the batch timeout before returning,
            # presumably so the consume loop does not poll in a tight spin.
            await anyio.sleep(self.batch_timeout_ms / 1000)
            return ()

        return tuple(chain.from_iterable(batches.values()))

    def get_log_context(
        self,
        message: Optional["StreamMessage[tuple[ConsumerRecord, ...]]"],
    ) -> dict[str, str]:
        """Log context using the first record's topic, or all subscription targets when no message."""
        if message is not None:
            topic = message.raw_message[0].topic
        else:
            topic = ",".join(self.topic_names)

        return self.build_log_context(
            message=message,
            topic=topic,
            group_id=self.group_id,
        )
class ConcurrentDefaultSubscriber(ConcurrentMixin["ConsumerRecord"], DefaultSubscriber):
    """``DefaultSubscriber`` that hands each record to ``ConcurrentMixin`` instead of consuming inline."""

    async def start(self) -> None:
        """Start the consumer as usual, then launch the mixin's consume task."""
        await super().start()
        self.start_consume_task()

    async def consume_one(self, msg: "ConsumerRecord") -> None:
        """Pass ``msg`` to ``ConcurrentMixin._put_msg`` rather than processing it here."""
        await self._put_msg(msg)
class ConcurrentBetweenPartitionsSubscriber(DefaultSubscriber):
    """``DefaultSubscriber`` that runs several consumers of one group side by side.

    With handlers registered, ``max_workers`` consumers are built with the
    same ``group_id``, subscribed to the same topics, and each gets its own
    consume loop. Without handlers, a single consumer is built so that
    ``get_one()`` and ``__aiter__`` keep working.
    """

    consumer_subgroup: list["AIOKafkaConsumer"]

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[ConsumerRecord]",
        max_workers: int,
    ) -> None:
        super().__init__(config, specification, calls)

        self.max_workers = max_workers
        self.consumer_subgroup = []

    async def start(self) -> None:
        """Start the consumer subgroup."""
        # Bypass ``LogicSubscriber.start`` (which builds one consumer) and run
        # the next ``start`` after it in the MRO.
        await super(LogicSubscriber, self).start()

        if self.calls:
            self.consumer_subgroup = [
                self.builder(
                    group_id=self.group_id,
                    client_id=self.client_id,
                    **self._connection_args,
                )
                for _ in range(self.max_workers)
            ]

        else:
            # We should create single consumer to support
            # `get_one()` and `__aiter__` methods
            self.consumer = self.builder(
                group_id=self.group_id,
                client_id=self.client_id,
                **self._connection_args,
            )
            # Those methods read records straight from ``self.consumer`` and
            # bypass ``get_msg`` (which tags each record with its consumer),
            # so bind the parser the same way ``LogicSubscriber.start`` does.
            self.parser._setup(self.consumer)
            self.consumer_subgroup = [self.consumer]

        # Subscribers starting should be called concurrently
        # to balance them correctly
        async with anyio.create_task_group() as tg:
            for c in self.consumer_subgroup:
                c.subscribe(
                    topics=self.topics,
                    listener=make_logging_listener(
                        consumer=c,
                        logger=self._outer_config.logger.logger.logger,
                        log_extra=self.get_log_context(None),
                        listener=self._listener,
                    ),
                )

                tg.start_soon(c.start)

        self._post_start()

        if self.calls:
            for c in self.consumer_subgroup:
                self.add_task(self._run_consume_loop, (c,))

    async def stop(self) -> None:
        """Stop every subgroup consumer concurrently, then run the parent ``stop``."""
        if self.consumer_subgroup:
            async with anyio.create_task_group() as tg:
                for consumer in self.consumer_subgroup:
                    tg.start_soon(consumer.stop)

            self.consumer_subgroup = []

        await super().stop()

    async def get_msg(self, consumer: "AIOKafkaConsumer") -> "KafkaRawMessage":
        """Fetch the next record and attach the consumer that produced it.

        Several consumers share one parser here, so each record carries its
        own ``consumer`` attribute.
        """
        assert consumer, "You should setup subscriber at first."
        message = await consumer.getone()
        message.consumer = consumer
        return cast("KafkaRawMessage", message)