Coverage for faststream / kafka / testing.py: 88%
88 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import re
2from collections.abc import Callable, Generator, Iterable, Iterator
3from contextlib import ExitStack, contextmanager
4from datetime import datetime, timezone
5from typing import TYPE_CHECKING, Any, Optional, cast
6from unittest.mock import AsyncMock, MagicMock
8import anyio
9from aiokafka import ConsumerRecord
10from typing_extensions import override
12from faststream._internal.endpoint.utils import ParserComposition
13from faststream._internal.testing.broker import TestBroker, change_producer
14from faststream.exceptions import SubscriberNotFound
15from faststream.kafka import TopicPartition
16from faststream.kafka.broker import KafkaBroker
17from faststream.kafka.message import KafkaMessage
18from faststream.kafka.parser import AioKafkaParser
19from faststream.kafka.publisher.producer import AioKafkaFastProducer
20from faststream.kafka.publisher.usecase import BatchPublisher
21from faststream.kafka.subscriber.usecase import BatchSubscriber
22from faststream.message import encode_message, gen_cor_id
24if TYPE_CHECKING:
25 from fast_depends.library.serializer import SerializerProto
27 from faststream._internal.basic_types import SendableMessage
28 from faststream.kafka.publisher.usecase import LogicPublisher
29 from faststream.kafka.response import KafkaPublishCommand
30 from faststream.kafka.subscriber.usecase import LogicSubscriber
# Public API of this module: only the test broker is exported by `import *`;
# the fake producer/consumer and the helper functions below are not listed.
__all__ = ("TestKafkaBroker",)
class TestKafkaBroker(TestBroker[KafkaBroker]):
    """A class to test Kafka brokers.

    It swaps the broker's producer for :class:`FakeProducer` and replaces the
    admin client and consumer builder with test doubles, so messages are routed
    in-process instead of through a real Kafka cluster.
    """

    @contextmanager
    def _patch_producer(self, broker: KafkaBroker) -> Iterator[None]:
        """Install a :class:`FakeProducer` on ``broker`` for the duration of the context."""
        fake_producer = FakeProducer(broker)

        with ExitStack() as es:
            es.enter_context(
                change_producer(broker.config.broker_config, fake_producer),
            )
            yield

    @staticmethod
    async def _fake_connect(  # type: ignore[override]
        broker: KafkaBroker,
        *args: Any,
        **kwargs: Any,
    ) -> Callable[..., AsyncMock]:
        """Replace the broker's admin client and consumer builder with mocks.

        Returns:
            The ``_fake_connection`` factory used in place of a real connection.
        """
        broker.config.broker_config._admin_client = AsyncMock()

        # Every call to the builder returns this single shared no-op consumer.
        builder = MagicMock(return_value=FakeConsumer())
        broker.config.broker_config.builder = builder

        return _fake_connection

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: KafkaBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber[Any]", bool]:
        """Find or create a subscriber that receives ``publisher``'s messages.

        Returns:
            ``(subscriber, is_real)``. If a registered subscriber already matches
            the publisher's topic/partition, it is returned with ``is_real=True``.
            Otherwise a new non-persistent subscriber is registered and returned
            with ``is_real=False``.
        """
        for handler in broker.subscribers:
            handler = cast("LogicSubscriber[Any]", handler)
            if _is_handler_matches(handler, publisher.topic, publisher.partition):
                return handler, True

        is_batch = isinstance(publisher, BatchPublisher)
        sub: "LogicSubscriber[Any]"

        # Compare against None (as `_is_handler_matches` does): partition 0 is a
        # real partition and must get a partition-scoped subscriber too, rather
        # than a whole-topic one that would also catch other partitions.
        if publisher.partition is not None:
            tp = TopicPartition(
                topic=publisher.topic,
                partition=publisher.partition,
            )
            sub = broker.subscriber(
                partitions=[tp],
                batch=is_batch,
                persistent=False,
            )
        else:
            sub = broker.subscriber(
                publisher.topic,
                batch=is_batch,
                persistent=False,
            )

        return sub, False
class FakeConsumer:
    """No-op consumer handed out by the mocked builder in ``TestKafkaBroker``.

    Every method accepts its arguments and does nothing.
    """

    async def start(self) -> None:
        """Pretend to start the consumer."""

    async def stop(self) -> None:
        """Pretend to stop the consumer."""

    def subscribe(self, *args: Any, **kwargs: Any) -> None:
        """Accept and ignore any subscription request."""
class FakeProducer(AioKafkaFastProducer):
    """Kafka producer double that delivers messages in-process.

    Instead of sending anything to a cluster, each publish call turns its command
    into a synthetic ``ConsumerRecord`` (via ``build_message``). That record is
    handed straight to the broker's matching subscribers.
    """

    def __init__(self, broker: KafkaBroker) -> None:
        self.broker = broker

        fallback = AioKafkaParser(msg_class=KafkaMessage, regex=None)

        # Combine the broker's own parser/decoder with the stock Kafka ones.
        self._parser = ParserComposition(broker._parser, fallback.parse_message)
        self._decoder = ParserComposition(broker._decoder, fallback.decode_message)

    def __bool__(self) -> bool:
        """Always report the producer as truthy."""
        return True

    @property
    def closed(self) -> bool:
        """The fake producer never reports itself as closed."""
        return False

    @override
    async def publish(self, cmd: "KafkaPublishCommand") -> None:
        """Deliver a single message to every matching subscriber."""
        serializer = self.broker.config.fd_config._serializer
        record = build_message(
            message=cmd.body,
            topic=cmd.destination,
            key=cmd.key,
            partition=cmd.partition,
            timestamp_ms=cmd.timestamp_ms,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            reply_to=cmd.reply_to,
            serializer=serializer,
        )

        subscribers = cast("list[LogicSubscriber[Any]]", self.broker.subscribers)
        for handler in _find_handler(subscribers, cmd.destination, cmd.partition):
            # Batch subscribers expect a list, even for a lone record.
            payload = [record] if isinstance(handler, BatchSubscriber) else record
            await self._execute_handler(payload, cmd.destination, handler)

    @override
    async def request(self, cmd: "KafkaPublishCommand") -> "ConsumerRecord":
        """Send a message to the first matching subscriber and return its reply.

        Raises:
            SubscriberNotFound: if no subscriber matches ``cmd.destination``.
        """
        serializer = self.broker.config.fd_config._serializer
        record = build_message(
            message=cmd.body,
            topic=cmd.destination,
            key=cmd.key,
            partition=cmd.partition,
            timestamp_ms=cmd.timestamp_ms,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            serializer=serializer,
        )

        subscribers = cast("list[LogicSubscriber[Any]]", self.broker.subscribers)
        target = next(
            _find_handler(subscribers, cmd.destination, cmd.partition),
            None,
        )
        if target is None:
            raise SubscriberNotFound

        payload = [record] if isinstance(target, BatchSubscriber) else record
        # Only the handler execution is bounded by the request timeout.
        with anyio.fail_after(cmd.timeout):
            return await self._execute_handler(payload, cmd.destination, target)

    @override
    async def publish_batch(
        self,
        cmd: "KafkaPublishCommand",
    ) -> None:
        """Deliver a batch of messages to every matching subscriber.

        Batch subscribers receive the whole batch in one call. Any other
        matching subscriber is invoked once per record.
        """
        subscribers = cast("list[LogicSubscriber[Any]]", self.broker.subscribers)
        for handler in _find_handler(subscribers, cmd.destination, cmd.partition):
            # A fresh lazy generator per handler; each record's key depends on
            # its position in the batch.
            records = (
                build_message(
                    message=body,
                    topic=cmd.destination,
                    partition=cmd.partition,
                    timestamp_ms=cmd.timestamp_ms,
                    key=cmd.key_for(position),
                    headers=cmd.headers,
                    correlation_id=cmd.correlation_id,
                    reply_to=cmd.reply_to,
                    serializer=self.broker.config.fd_config._serializer,
                )
                for position, body in enumerate(cmd.batch_bodies)
            )

            if not isinstance(handler, BatchSubscriber):
                for record in records:
                    await self._execute_handler(record, cmd.destination, handler)
                continue

            await self._execute_handler(list(records), cmd.destination, handler)

    async def _execute_handler(
        self,
        msg: Any,
        topic: str,
        handler: "LogicSubscriber[Any]",
    ) -> "ConsumerRecord":
        """Run ``handler`` on ``msg`` and wrap its response as a record on ``topic``."""
        result = await handler.process_message(msg)

        return build_message(
            topic=topic,
            message=result.body,
            headers=result.headers,
            correlation_id=result.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )
def build_message(
    message: "SendableMessage",
    topic: str,
    partition: int | None = None,
    timestamp_ms: int | None = None,
    key: bytes | None = None,
    headers: dict[str, str] | None = None,
    correlation_id: str | None = None,
    *,
    reply_to: str = "",
    serializer: Optional["SerializerProto"],
) -> "ConsumerRecord":
    """Build a Kafka ConsumerRecord for a sendable message.

    Args:
        message: Payload, encoded with ``encode_message`` using ``serializer``.
        topic: Topic name stored on the record.
        partition: Partition number; ``None`` is stored as ``0``.
        timestamp_ms: Record timestamp in milliseconds. When ``None``, the
            current UTC time is used. An explicit ``0`` is kept as-is.
        key: Record key; ``None`` or an empty key is stored as ``b""``.
        headers: Extra headers, merged last, so they override the generated
            ``content-type`` and ``correlation_id`` entries.
        correlation_id: Value for the ``correlation_id`` header; a fresh id
            from ``gen_cor_id`` is used when it is falsy.
        reply_to: If non-empty, stored as the ``reply_to`` header unless
            ``headers`` already provides one.
        serializer: Serializer forwarded to ``encode_message``.

    Returns:
        A ``ConsumerRecord`` at offset 0 with header values encoded to bytes.
    """
    msg, content_type = encode_message(message, serializer=serializer)

    k = key or b""

    record_headers = {
        "content-type": content_type or "",
        "correlation_id": correlation_id or gen_cor_id(),
        **(headers or {}),
    }

    if reply_to:
        # A reply_to header supplied explicitly by the caller takes precedence.
        record_headers.setdefault("reply_to", reply_to)

    # Compare with None so a deliberate timestamp of 0 is not replaced by "now".
    if timestamp_ms is None:
        timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)

    return ConsumerRecord(
        value=msg,
        topic=topic,
        partition=partition or 0,
        key=k,
        serialized_key_size=len(k),
        serialized_value_size=len(msg),
        checksum=sum(msg),
        offset=0,
        headers=[(i, j.encode()) for i, j in record_headers.items()],
        timestamp_type=1,
        timestamp=timestamp_ms,
    )
279def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
280 mock = AsyncMock()
281 mock.subscribe = MagicMock
282 mock.assign = MagicMock
283 return mock
def _find_handler(
    subscribers: Iterable["LogicSubscriber[Any]"],
    topic: str,
    partition: int | None,
) -> Generator["LogicSubscriber[Any]", None, None]:
    """Yield, in order, the subscribers that should receive a message on ``topic``/``partition``.

    A matching subscriber with no ``group_id`` is always yielded. Among matching
    subscribers that share a non-empty ``group_id``, only the first one is yielded.
    """
    seen_groups = set()
    for candidate in subscribers:  # pragma: no branch
        if not _is_handler_matches(candidate, topic, partition):
            continue

        if candidate.group_id:
            if candidate.group_id in seen_groups:
                continue
            # Record the group before yielding so later members are skipped.
            seen_groups.add(candidate.group_id)

        yield candidate
302def _is_handler_matches(
303 handler: "LogicSubscriber[Any]",
304 topic: str,
305 partition: int | None,
306) -> bool:
307 return bool(
308 any(
309 p.topic == topic and (partition is None or p.partition == partition)
310 for p in handler.partitions
311 )
312 or topic in handler.topics
313 or (handler.pattern and re.match(handler.pattern, topic)),
314 )