Coverage for faststream / confluent / testing.py: 86%
111 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1from collections.abc import Callable, Generator, Iterable, Iterator
2from contextlib import ExitStack, contextmanager
3from datetime import datetime, timezone
4from typing import TYPE_CHECKING, Any, Optional, cast
5from unittest.mock import AsyncMock, MagicMock
7import anyio
8from typing_extensions import override
10from faststream._internal.endpoint.utils import ParserComposition
11from faststream._internal.testing.broker import TestBroker, change_producer
12from faststream.confluent.broker import KafkaBroker
13from faststream.confluent.parser import AsyncConfluentParser
14from faststream.confluent.publisher.producer import AsyncConfluentFastProducer
15from faststream.confluent.publisher.usecase import BatchPublisher
16from faststream.confluent.schemas import TopicPartition
17from faststream.confluent.subscriber.usecase import BatchSubscriber
18from faststream.exceptions import SubscriberNotFound
19from faststream.message import encode_message, gen_cor_id
21if TYPE_CHECKING:
22 from fast_depends.library.serializer import SerializerProto
24 from faststream._internal.basic_types import SendableMessage
25 from faststream.confluent.publisher.usecase import LogicPublisher
26 from faststream.confluent.response import KafkaPublishCommand
27 from faststream.confluent.subscriber.usecase import LogicSubscriber
# Public API of this module: only the test broker is exported.
__all__ = ("TestKafkaBroker",)


class TestKafkaBroker(TestBroker[KafkaBroker]):
    """A class to test Kafka brokers.

    Swaps the broker's producer for an in-memory ``FakeProducer`` and replaces
    the connection step with mocks, so tests run without a Kafka cluster.
    """

    @contextmanager
    def _patch_producer(self, broker: KafkaBroker) -> Iterator[None]:
        """Install a ``FakeProducer`` on ``broker`` for the lifetime of the context."""
        fake_producer = FakeProducer(broker)
        with change_producer(broker.config.broker_config, fake_producer):
            yield

    @staticmethod
    async def _fake_connect(  # type: ignore[override]
        broker: KafkaBroker,
        *args: Any,
        **kwargs: Any,
    ) -> Callable[..., AsyncMock]:
        """Mock out the admin client and return a factory of fake connections.

        ``args``/``kwargs`` are accepted only for signature compatibility and ignored.
        """
        broker.config.broker_config.admin.admin_client = MagicMock()
        return _fake_connection

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: KafkaBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber[Any]", bool]:
        """Return a subscriber receiving ``publisher``'s messages, creating one if needed.

        Returns:
            ``(subscriber, is_real)`` where ``is_real`` is ``True`` if an already
            registered broker subscriber matches the publisher's topic/partition,
            and ``False`` if a temporary (``persistent=False``) subscriber had to
            be registered on ``broker``.
        """
        for candidate in broker.subscribers:
            subscriber = cast("LogicSubscriber[Any]", candidate)
            if _is_handler_matches(
                subscriber,
                topic=publisher.topic,
                partition=publisher.partition,
            ):
                return subscriber, True

        is_batch = isinstance(publisher, BatchPublisher)

        # Compare against None explicitly: partition 0 is a real partition and
        # must get a partition-bound subscriber, not a topic-wide one.
        if publisher.partition is not None:
            fake_sub = broker.subscriber(
                partitions=[
                    TopicPartition(
                        topic=publisher.topic,
                        partition=publisher.partition,
                    ),
                ],
                batch=is_batch,
                auto_offset_reset="earliest",
                persistent=False,
            )
        else:
            fake_sub = broker.subscriber(
                publisher.topic,
                batch=is_batch,
                auto_offset_reset="earliest",
                persistent=False,
            )

        return fake_sub, False
class FakeProducer(AsyncConfluentFastProducer):
    """In-memory replacement for the confluent producer used in tests.

    Every publishing method turns the command into ``MockConfluentMessage``
    objects (via ``build_message``) and hands them directly to the broker's
    matching subscribers instead of sending anything over the network.
    """

    def __init__(self, broker: KafkaBroker) -> None:
        self.broker = broker

        fallback_parser = AsyncConfluentParser()
        # Broker-level parser/decoder are composed with the stock confluent ones.
        self._parser = ParserComposition(broker._parser, fallback_parser.parse_message)
        self._decoder = ParserComposition(broker._decoder, fallback_parser.decode_message)

    def __bool__(self) -> bool:
        # The fake producer is always truthy.
        return True

    async def ping(self, timeout: float) -> bool:
        # There is no real connection to probe; ``timeout`` is ignored.
        return True

    @override
    async def publish(self, cmd: "KafkaPublishCommand") -> None:
        """Deliver one message to every subscriber matching the command's destination."""
        incoming_msg = build_message(
            message=cmd.body,
            topic=cmd.destination,
            key=cmd.key,
            partition=cmd.partition,
            timestamp_ms=cmd.timestamp_ms,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            reply_to=cmd.reply_to,
            serializer=self.broker.config.fd_config._serializer,
        )

        all_subscribers = cast("Iterable[LogicSubscriber[Any]]", self.broker.subscribers)
        for subscriber in _find_handler(all_subscribers, cmd.destination, cmd.partition):
            # Batch subscribers consume a list; regular ones a single message.
            if isinstance(subscriber, BatchSubscriber):
                payload = [incoming_msg]
            else:
                payload = incoming_msg
            await self._execute_handler(payload, cmd.destination, subscriber)

    @override
    async def publish_batch(self, cmd: "KafkaPublishCommand") -> None:
        """Deliver every body of a batch command to each matching subscriber."""
        all_subscribers = cast("Iterable[LogicSubscriber[Any]]", self.broker.subscribers)

        for subscriber in _find_handler(all_subscribers, cmd.destination, cmd.partition):
            # A fresh lazy generator per subscriber, so each subscriber gets its
            # own message objects (each key comes from ``cmd.key_for(position)``).
            lazy_batch = (
                build_message(
                    message=body,
                    topic=cmd.destination,
                    partition=cmd.partition,
                    timestamp_ms=cmd.timestamp_ms,
                    key=cmd.key_for(position),
                    headers=cmd.headers,
                    correlation_id=cmd.correlation_id,
                    reply_to=cmd.reply_to,
                    serializer=self.broker.config.fd_config._serializer,
                )
                for position, body in enumerate(cmd.batch_bodies)
            )

            if not isinstance(subscriber, BatchSubscriber):
                # Non-batch subscribers get the messages one by one.
                for single_msg in lazy_batch:
                    await self._execute_handler(single_msg, cmd.destination, subscriber)
                continue

            await self._execute_handler(list(lazy_batch), cmd.destination, subscriber)

    @override
    async def request(self, cmd: "KafkaPublishCommand") -> "MockConfluentMessage":
        """Send the command to the first matching subscriber and return its reply.

        The subscriber call is bounded by ``anyio.fail_after(cmd.timeout)``.

        Raises:
            SubscriberNotFound: if no subscriber matches the destination.
        """
        incoming_msg = build_message(
            message=cmd.body,
            topic=cmd.destination,
            key=cmd.key,
            partition=cmd.partition,
            timestamp_ms=cmd.timestamp_ms,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )

        all_subscribers = cast("Iterable[LogicSubscriber[Any]]", self.broker.subscribers)
        target = next(_find_handler(all_subscribers, cmd.destination, cmd.partition), None)
        if target is None:
            raise SubscriberNotFound

        payload = [incoming_msg] if isinstance(target, BatchSubscriber) else incoming_msg
        with anyio.fail_after(cmd.timeout):
            return await self._execute_handler(payload, cmd.destination, target)

    async def _execute_handler(
        self,
        msg: Any,
        topic: str,
        handler: "LogicSubscriber[Any]",
    ) -> "MockConfluentMessage":
        """Run ``handler`` on ``msg`` and wrap its result into a ``MockConfluentMessage``."""
        outcome = await handler.process_message(msg)
        # Reuse the handler's correlation id if it set one, otherwise mint a new one.
        reply_correlation_id = outcome.correlation_id or gen_cor_id()

        return build_message(
            topic=topic,
            message=outcome.body,
            headers=outcome.headers,
            correlation_id=reply_correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )
219class MockConfluentMessage:
220 def __init__(
221 self,
222 raw_msg: bytes,
223 topic: str,
224 key: bytes | str,
225 headers: list[tuple[str, bytes]],
226 offset: int,
227 partition: int,
228 timestamp_type: int,
229 timestamp_ms: int,
230 error: str | None = None,
231 ) -> None:
232 self._raw_msg = raw_msg
233 self._topic = topic
235 if isinstance(key, str): 235 ↛ 236line 235 didn't jump to line 236 because the condition on line 235 was never true
236 self._key = key.encode()
237 else:
238 self._key = key
240 self._headers = headers
241 self._error = error
242 self._offset = offset
243 self._partition = partition
244 self._timestamp = (timestamp_type, timestamp_ms)
246 def len(self) -> int:
247 return len(self._raw_msg)
249 def error(self) -> str | None:
250 return self._error
252 def headers(self) -> list[tuple[str, bytes]]:
253 return self._headers
255 def key(self) -> bytes:
256 return self._key
258 def offset(self) -> int:
259 return self._offset
261 def partition(self) -> int:
262 return self._partition
264 def timestamp(self) -> tuple[int, int]:
265 return self._timestamp
267 def topic(self) -> str:
268 return self._topic
270 def value(self) -> bytes:
271 return self._raw_msg
def build_message(
    message: "SendableMessage",
    topic: str,
    *,
    correlation_id: str | None = None,
    partition: int | None = None,
    timestamp_ms: int | None = None,
    key: bytes | str | None = None,
    headers: dict[str, str] | None = None,
    reply_to: str = "",
    serializer: Optional["SerializerProto"] = None,
) -> MockConfluentMessage:
    """Build a mock confluent_kafka.Message for a sendable message.

    Args:
        message: Payload; encoded with ``encode_message`` using ``serializer``.
        topic: Topic reported by the resulting message.
        correlation_id: Written to the ``correlation_id`` header; a fresh id
            from ``gen_cor_id()`` is used when it is empty.
        partition: Partition reported by the message; ``None`` becomes ``0``.
        timestamp_ms: Timestamp in epoch milliseconds; a missing (or ``0``)
            value falls back to the current time.
        key: Message key; a missing or empty key becomes ``b""``.
        headers: Extra headers, merged *after* the service headers
            (``content-type``, ``correlation_id``, ``reply_to``), so they override
            those on a name collision. ``str`` values are encoded (UTF-8);
            ``bytes`` values are passed through unchanged.
        reply_to: Written to the ``reply_to`` header.
        serializer: Forwarded to ``encode_message``.

    Returns:
        A ``MockConfluentMessage`` with offset ``0``.
    """
    raw_body, content_type = encode_message(message, serializer)

    # Service headers first so caller-supplied headers can override them.
    merged_headers: dict[str, str | bytes] = {
        "content-type": content_type or "",
        "correlation_id": correlation_id or gen_cor_id(),
        "reply_to": reply_to,
        **(headers or {}),
    }

    # https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#confluent_kafka.Message.timestamp
    # Per the linked docs, timestamp type 1 is TIMESTAMP_CREATE_TIME.
    return MockConfluentMessage(
        raw_msg=raw_body,
        topic=topic,
        key=key or b"",
        headers=[
            # Tolerate bytes values (accepted by real confluent producers)
            # instead of failing on ``bytes.encode``.
            (name, value if isinstance(value, bytes) else value.encode())
            for name, value in merged_headers.items()
        ],
        offset=0,
        partition=partition or 0,
        timestamp_type=1,
        timestamp_ms=timestamp_ms or int(datetime.now(timezone.utc).timestamp() * 1000),
    )
309def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
310 mock = AsyncMock()
311 mock.getone.return_value = MagicMock()
312 mock.getmany.return_value = [MagicMock()]
313 return mock
def _find_handler(
    subscribers: Iterable["LogicSubscriber[Any]"],
    topic: str,
    partition: int | None,
) -> Generator["LogicSubscriber[Any]", None, None]:
    """Yield the subscribers that should receive a message for ``topic``/``partition``.

    Matching is delegated to ``_is_handler_matches``. Among matching
    subscribers that share a truthy ``group_id``, only the first one is
    yielded. Matching subscribers with a falsy ``group_id`` are always yielded.
    """
    seen_groups = set()

    for candidate in subscribers:  # pragma: no branch
        if not _is_handler_matches(candidate, topic, partition):
            continue

        group = candidate.group_id
        if group:
            if group in seen_groups:
                continue
            seen_groups.add(group)

        yield candidate
332def _is_handler_matches(
333 handler: "LogicSubscriber[Any]",
334 topic: str,
335 partition: int | None,
336) -> bool:
337 return bool(
338 any(
339 p.topic == topic and (partition is None or p.partition == partition)
340 for p in handler.partitions
341 )
342 or topic in handler.topics,
343 )