Coverage for faststream / confluent / subscriber / usecase.py: 98%
95 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import logging
2from abc import abstractmethod
3from collections.abc import AsyncIterator, Sequence
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Optional,
8 cast,
9)
11import anyio
12from confluent_kafka import KafkaException, Message
13from typing_extensions import override
15from faststream._internal.endpoint.subscriber import SubscriberUsecase
16from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin, TasksMixin
17from faststream._internal.endpoint.utils import process_msg
18from faststream._internal.types import MsgType
19from faststream.confluent.parser import AsyncConfluentParser
20from faststream.confluent.publisher.fake import KafkaFakePublisher
21from faststream.confluent.schemas import TopicPartition
23if TYPE_CHECKING:
24 from faststream._internal.endpoint.publisher import PublisherProto
25 from faststream._internal.endpoint.subscriber import SubscriberSpecification
26 from faststream._internal.endpoint.subscriber.call_item import CallsCollection
27 from faststream.confluent.configs import KafkaBrokerConfig
28 from faststream.confluent.helpers.client import AsyncConfluentConsumer
29 from faststream.confluent.message import KafkaMessage
30 from faststream.message import StreamMessage
32 from .config import KafkaSubscriberConfig
class LogicSubscriber(TasksMixin, SubscriberUsecase[MsgType]):
    """A class to handle logic for consuming messages from Kafka.

    Holds the subscription settings taken from the subscriber config (topics,
    partitions, group id, polling interval, extra connection options), owns
    the consumer instance between ``start()`` and ``stop()``, and drives the
    polling loop. Subclasses supply ``get_msg`` and set ``parser``.
    """

    _outer_config: "KafkaBrokerConfig"

    group_id: str | None

    consumer: Optional["AsyncConfluentConsumer"]
    parser: AsyncConfluentParser

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[MsgType]",
    ) -> None:
        """Store subscription settings from ``config``; no consumer is created yet."""
        super().__init__(config, specification, calls)

        # Extra keyword arguments forwarded to the consumer builder in ``start``.
        self.__connection_data = config.connection_data

        self.group_id = config.group_id

        # Stored without the broker prefix; the prefix is applied on access.
        self._topics = config.topics
        self._partitions = config.partitions

        self.consumer = None
        self.polling_interval = config.polling_interval

    @property
    def client_id(self) -> str | None:
        """Client id taken from the broker-level configuration."""
        return self._outer_config.client_id

    @property
    def topics(self) -> list[str]:
        """Configured topic names with the broker prefix prepended."""
        return [f"{self._outer_config.prefix}{t}" for t in self._topics]

    @property
    def partitions(self) -> list[TopicPartition]:
        """Configured partitions passed through ``add_prefix`` with the broker prefix."""
        return [p.add_prefix(self._outer_config.prefix) for p in self._partitions]

    @override
    async def start(self) -> None:
        """Start the consumer.

        Builds the consumer through the broker config's ``builder``, hands it
        to the parser, starts it, and schedules the polling loop when handlers
        are registered.
        """
        await super().start()
        self.consumer = consumer = self._outer_config.builder(
            *self.topics,
            partitions=self.partitions,
            group_id=self.group_id,
            client_id=self.client_id,
            **self.__connection_data,
        )
        self.parser._setup(consumer)
        await consumer.start()

        self._post_start()

        # Without handlers the subscriber is read via ``get_one`` / iteration,
        # so no background polling task is scheduled.
        if self.calls:
            self.add_task(self._consume)

    async def stop(self) -> None:
        """Stop the subscriber and, if one exists, stop and drop the consumer."""
        await super().stop()

        if self.consumer is not None:
            await self.consumer.stop()
            self.consumer = None

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "KafkaMessage | None":
        """Fetch a single message and run it through the middlewares, parser and decoder.

        Args:
            timeout: Passed to the consumer's ``getone`` call.

        Returns:
            Whatever ``process_msg`` produces for the fetched raw message.

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self.consumer, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        raw_message = await self.consumer.getone(timeout=timeout)

        context = self._outer_config.fd_config.context

        async_parser, async_decoder = self._get_parser_and_decoder()

        return await process_msg(  # type: ignore[return-value]
            msg=raw_message,
            middlewares=(
                m(raw_message, context=context) for m in self._broker_middlewares
            ),
            parser=async_parser,
            decoder=async_decoder,
        )

    @override
    async def __aiter__(self) -> AsyncIterator["KafkaMessage"]:  # type: ignore[override]
        """Yield processed messages indefinitely, polling the consumer in a loop.

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self.consumer, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        timeout = 5.0
        while True:
            raw_message = await self.consumer.getone(timeout=timeout)

            # Nothing fetched on this poll: try again.
            if raw_message is None:
                continue

            yield cast(
                "KafkaMessage",
                await process_msg(
                    msg=raw_message,
                    middlewares=(
                        m(raw_message, context=context) for m in self._broker_middlewares
                    ),
                    parser=async_parser,
                    decoder=async_decoder,
                ),
            )

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Return a one-element tuple with a fake publisher targeting ``message.reply_to``."""
        return (
            KafkaFakePublisher(
                self._outer_config.producer,
                topic=message.reply_to,
            ),
        )

    async def consume_one(self, msg: MsgType) -> None:
        """Process one fetched item by delegating to ``consume``."""
        await self.consume(msg)

    @abstractmethod
    async def get_msg(self) -> MsgType | None:
        """Fetch the next item to process, or ``None`` when nothing is available."""
        raise NotImplementedError

    async def _consume(self) -> None:
        """Poll ``get_msg`` while the subscriber is running and dispatch the results.

        A ``KafkaException`` raised while fetching is logged at ERROR level and
        followed by a 5 second pause before the next attempt.

        Raises:
            AssertionError: If the consumer has not been created yet.
        """
        assert self.consumer, "You should start subscriber at first."

        while self.running:
            try:
                msg = await self.get_msg()

            except KafkaException as e:  # pragma: no cover  # noqa: PERF203
                self._log(
                    logging.ERROR,
                    message="Message fetch error",
                    exc_info=e,
                )

                # Pause before polling again so a failing fetch is not retried
                # in a tight loop.
                await anyio.sleep(5)

            else:
                if msg is not None:
                    await self.consume_one(msg)

    @property
    def topic_names(self) -> list[str]:
        """Prefixed names of the subscribed topics, or ``topic-partition`` labels.

        Built from the raw ``_topics`` / ``_partitions`` values so the broker
        prefix is prepended exactly once; the ``topics`` and ``partitions``
        properties already carry the prefix and must not be prefixed again.
        """
        prefix = self._outer_config.prefix
        names = self._topics or (
            f"{p.topic}-{p.partition}" for p in self._partitions
        )
        return [f"{prefix}{name}" for name in names]

    @staticmethod
    def build_log_context(
        message: Optional["StreamMessage[Any]"],
        topic: str,
        group_id: str | None = None,
    ) -> dict[str, str]:
        """Build the logging context mapping.

        Args:
            message: Message being handled, or ``None`` when there is none.
            topic: Topic name to report.
            group_id: Consumer group; ``None`` is reported as an empty string.

        Returns:
            A dict with ``topic``, ``group_id`` and ``message_id`` keys;
            ``message_id`` falls back to ``""`` when ``message`` has no such
            attribute.
        """
        return {
            "topic": topic,
            "group_id": group_id or "",
            "message_id": getattr(message, "message_id", ""),
        }
class DefaultSubscriber(LogicSubscriber[Message]):
    """Subscriber that fetches and handles Kafka messages one at a time."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Message]",
    ) -> None:
        """Create the single-message parser and register it on ``config``.

        ``config`` is updated with the parser's ``decode_message`` and
        ``parse_message`` before it is handed to the base initializer.
        """
        self.parser = AsyncConfluentParser(is_manual=not config.ack_first)
        config.decoder = self.parser.decode_message
        config.parser = self.parser.parse_message
        super().__init__(config, specification, calls)

    async def get_msg(self) -> Optional["Message"]:
        """Poll the consumer for one message, waiting up to ``polling_interval``.

        Raises:
            AssertionError: If the consumer has not been created yet.
        """
        assert self.consumer, "You should setup subscriber at first."
        return await self.consumer.getone(timeout=self.polling_interval)

    def get_log_context(
        self,
        message: Optional["StreamMessage[Message]"],
    ) -> dict[str, str]:
        """Build the logging context for ``message``.

        The topic is taken from the raw message when it reports one; otherwise
        (or when ``message`` is ``None``) it is the comma-joined
        ``topic_names``, so partition-only subscriptions still get a
        non-empty label.
        """
        topic = None if message is None else message.raw_message.topic()

        return self.build_log_context(
            message=message,
            topic=topic or ",".join(self.topic_names),
            group_id=self.group_id,
        )
class ConcurrentDefaultSubscriber(ConcurrentMixin["Message"], DefaultSubscriber):
    """Single-message subscriber that routes fetched messages through ``_put_msg``."""

    async def start(self) -> None:
        """Run the inherited start sequence, then call ``start_consume_task``."""
        await super().start()
        self.start_consume_task()

    async def consume_one(self, msg: "Message") -> None:
        """Pass ``msg`` to ``_put_msg`` rather than calling ``consume`` directly."""
        await self._put_msg(msg)
class BatchSubscriber(LogicSubscriber[tuple[Message, ...]]):
    """Subscriber that fetches and handles Kafka messages in batches."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[tuple[Message, ...]]",
        max_records: int | None,
    ) -> None:
        """Create the batch parser, register it on ``config`` and store ``max_records``.

        ``config`` receives the parser's ``decode_batch`` and ``parse_batch``
        before it is handed to the base initializer.
        """
        batch_parser = AsyncConfluentParser(is_manual=not config.ack_first)
        self.parser = batch_parser
        config.decoder = batch_parser.decode_batch
        config.parser = batch_parser.parse_batch
        super().__init__(config, specification, calls)

        self.max_records = max_records

    async def get_msg(self) -> tuple["Message", ...] | None:
        """Poll the consumer for a batch; a falsy result is reported as ``None``.

        Raises:
            AssertionError: If the consumer has not been created yet.
        """
        assert self.consumer, "You should setup subscriber at first."
        batch = await self.consumer.getmany(
            timeout=self.polling_interval,
            max_records=self.max_records,
        )
        return batch or None

    def get_log_context(
        self,
        message: Optional["StreamMessage[tuple[Message, ...]]"],
    ) -> dict[str, str]:
        """Build the logging context for a batch.

        The topic comes from the first raw message of the batch when it
        reports one; otherwise (or when ``message`` is ``None``) it is the
        comma-joined ``topic_names``.
        """
        topic = None if message is None else message.raw_message[0].topic()

        return self.build_log_context(
            message=message,
            topic=topic or ",".join(self.topic_names),
            group_id=self.group_id,
        )