Coverage for faststream / rabbit / testing.py: 96%
118 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1from collections.abc import Generator, Iterator, Mapping
2from contextlib import ExitStack, contextmanager
3from typing import TYPE_CHECKING, Any, Optional, Union, cast
4from unittest import mock
5from unittest.mock import AsyncMock
7import aiormq
8import anyio
9from aio_pika.message import IncomingMessage, encode_expiration
10from pamqp import commands as spec
11from pamqp.header import ContentHeader
12from typing_extensions import override
14from faststream._internal.endpoint.utils import ParserComposition
15from faststream._internal.testing.broker import TestBroker, change_producer
16from faststream.exceptions import SubscriberNotFound
17from faststream.message import gen_cor_id
18from faststream.rabbit.broker.broker import RabbitBroker
19from faststream.rabbit.parser import AioPikaParser
20from faststream.rabbit.publisher.producer import AioPikaFastProducer
21from faststream.rabbit.schemas import (
22 ExchangeType,
23 RabbitExchange,
24 RabbitQueue,
25)
27if TYPE_CHECKING:
28 from aio_pika.abc import DateType, HeadersType
29 from fast_depends.library.serializer import SerializerProto
31 from faststream.rabbit.publisher import RabbitPublisher
32 from faststream.rabbit.response import RabbitPublishCommand
33 from faststream.rabbit.subscriber import RabbitSubscriber
34 from faststream.rabbit.types import AioPikaSendableMessage
36__all__ = ("TestRabbitBroker",)
class TestRabbitBroker(TestBroker[RabbitBroker]):
    """A class to test RabbitMQ brokers."""

    @contextmanager
    def _patch_broker(self, broker: "RabbitBroker") -> Generator[None, None, None]:
        """Patch broker internals so no real RabbitMQ connection is needed.

        For the duration of the context, ``broker._channel`` and
        ``broker.config.declarer`` are replaced with ``AsyncMock`` objects. The
        base class's own broker patching is also applied.
        """
        with (
            mock.patch.object(broker, "_channel", new_callable=AsyncMock),
            mock.patch.object(broker.config, "declarer", new_callable=AsyncMock),
            super()._patch_broker(broker),
        ):
            yield

    @contextmanager
    def _patch_producer(self, broker: RabbitBroker) -> Iterator[None]:
        """Swap the broker's producer for a ``FakeProducer`` inside the context."""
        fake_producer = FakeProducer(broker)

        # A single context manager needs no ExitStack; enter it directly.
        with change_producer(broker.config.broker_config, fake_producer):
            yield

    @staticmethod
    async def _fake_connect(broker: "RabbitBroker", *args: Any, **kwargs: Any) -> None:
        """Accept any connection arguments and do nothing."""

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: "RabbitBroker",
        publisher: "RabbitPublisher",
    ) -> tuple["RabbitSubscriber", bool]:
        """Find or create a subscriber that receives what ``publisher`` sends.

        Returns:
            A ``(subscriber, is_real)`` pair. ``is_real`` is ``True`` when an
            existing broker subscriber already matches the publisher's routing
            key and exchange. It is ``False`` when a new non-persistent
            subscriber had to be registered on the broker.
        """
        for handler in broker.subscribers:
            handler = cast("RabbitSubscriber", handler)
            if _is_handler_matches(
                handler,
                publisher.routing(),
                {},
                publisher.exchange,
            ):
                return handler, True

        sub = broker.subscriber(
            queue=publisher.routing(),
            exchange=publisher.exchange,
            persistent=False,
        )
        return sub, False
class PatchedMessage(IncomingMessage):
    """An aio-pika ``IncomingMessage`` whose settlement calls are no-ops.

    ``build_message`` returns instances of this class, so any ack, nack or
    reject performed on them during a test does nothing.
    """

    routing_key: str

    async def ack(self, multiple: bool = False) -> None:
        """Pretend to acknowledge the message; nothing happens."""
        return None

    async def nack(self, multiple: bool = False, requeue: bool = True) -> None:
        """Pretend to negatively acknowledge the message; nothing happens."""
        return None

    async def reject(self, requeue: bool = False) -> None:
        """Pretend to reject the message; nothing happens."""
        return None
def build_message(
    message: "AioPikaSendableMessage" = "",
    queue: Union["RabbitQueue", str] = "",
    exchange: Union["RabbitExchange", str, None] = None,
    *,
    routing_key: str = "",
    persist: bool = False,
    reply_to: str | None = None,
    headers: Optional["HeadersType"] = None,
    content_type: str | None = None,
    content_encoding: str | None = None,
    priority: int | None = None,
    correlation_id: str | None = None,
    expiration: Optional["DateType"] = None,
    message_id: str | None = None,
    timestamp: Optional["DateType"] = None,
    message_type: str | None = None,
    user_id: str | None = None,
    app_id: str | None = None,
    serializer: Optional["SerializerProto"] = None,
) -> PatchedMessage:
    """Build a patched RabbitMQ message for testing.

    The payload and properties are encoded with ``AioPikaParser.encode_message``.
    The result is then wrapped in an ``aiormq`` delivered-message envelope, so it
    looks like a message consumed from ``exchange``.

    Args:
        message: Payload to encode.
        queue: Queue (or queue name). Its routing key is used when
            ``routing_key`` is empty.
        exchange: Exchange (or exchange name) recorded in the delivery frame.
            ``None`` becomes the empty (default) exchange name.
        routing_key: Explicit routing key; takes precedence over ``queue``.
        persist: Forwarded to the encoder; the resulting delivery mode is copied
            into the message properties.
        correlation_id: Generated with ``gen_cor_id()`` when not given.
        message_id: Defaults to the (possibly generated) correlation id.
        serializer: Serializer forwarded to the encoder.
        Remaining keyword arguments are forwarded to the encoder as message
        properties.

    Returns:
        A ``PatchedMessage`` whose channel is an ``AsyncMock``.
    """
    que = RabbitQueue.validate(queue)
    exch = RabbitExchange.validate(exchange)

    routing = routing_key or que.routing()

    # Always have a correlation id so message_id can fall back to it.
    correlation_id = correlation_id or gen_cor_id()
    msg = AioPikaParser.encode_message(
        message=message,
        persist=persist,
        reply_to=reply_to,
        headers=headers,
        content_type=content_type,
        content_encoding=content_encoding,
        priority=priority,
        correlation_id=correlation_id,
        expiration=expiration,
        message_id=message_id or correlation_id,
        timestamp=timestamp,
        message_type=message_type,
        user_id=user_id,
        app_id=app_id,
        serializer=serializer,
    )

    return PatchedMessage(
        aiormq.abc.DeliveredMessage(
            delivery=spec.Basic.Deliver(
                # A missing exchange (None) maps to the default "" exchange.
                exchange=getattr(exch, "name", ""),
                routing_key=routing,
            ),
            header=ContentHeader(
                properties=spec.Basic.Properties(
                    content_type=msg.content_type,
                    headers=msg.headers,
                    reply_to=msg.reply_to,
                    content_encoding=msg.content_encoding,
                    # Previously omitted, which silently dropped ``persist``.
                    # NOTE(review): assumes encode_message maps ``persist`` onto
                    # the aio-pika message's delivery_mode — confirm.
                    delivery_mode=msg.delivery_mode,
                    priority=msg.priority,
                    correlation_id=msg.correlation_id,
                    expiration=encode_expiration(msg.expiration),
                    message_id=msg.message_id,
                    timestamp=msg.timestamp,
                    message_type=message_type,
                    user_id=msg.user_id,
                    app_id=msg.app_id,
                ),
            ),
            body=msg.body,
            channel=AsyncMock(),
        ),
    )
class FakeProducer(AioPikaFastProducer):
    """In-memory replacement for the aio-pika producer used in tests.

    Each outgoing command becomes a ``PatchedMessage``, which is handed directly
    to the matching subscribers registered on ``broker`` instead of being sent
    to a RabbitMQ server.
    """

    def __init__(self, broker: RabbitBroker) -> None:
        # NOTE(review): the parent ``__init__`` is not called here; only the
        # attributes this fake needs are set up.
        self.broker = broker

        fallback = AioPikaParser()
        self._parser = ParserComposition(broker._parser, fallback.parse_message)
        self._decoder = ParserComposition(
            broker._decoder,
            fallback.decode_message,
        )

    def _iter_matching_handlers(
        self,
        incoming: PatchedMessage,
        exchange: Any,
    ) -> "Iterator[RabbitSubscriber]":
        """Lazily yield each broker subscriber bound to ``incoming``'s route."""
        for candidate in self.broker.subscribers:  # pragma: no branch
            subscriber = cast("RabbitSubscriber", candidate)
            matched = _is_handler_matches(
                subscriber,
                incoming.routing_key,
                incoming.headers,
                exchange,
            )
            if matched:
                yield subscriber

    @override
    async def publish(
        self,
        cmd: "RabbitPublishCommand",
    ) -> None:
        """Deliver ``cmd`` to every matching subscriber, in registration order.

        Raises:
            SubscriberNotFound: when no subscriber matches the message.
        """
        incoming = build_message(
            message=cmd.body,
            exchange=cmd.exchange,
            routing_key=cmd.destination,
            correlation_id=cmd.correlation_id,
            headers=cmd.headers,
            reply_to=cmd.reply_to,
            serializer=self.broker.config.fd_config._serializer,
            **cmd.message_options,
        )

        delivered = False
        for subscriber in self._iter_matching_handlers(incoming, cmd.exchange):
            delivered = True
            await self._execute_handler(incoming, subscriber)

        if not delivered:
            raise SubscriberNotFound

    @override
    async def request(
        self,
        cmd: "RabbitPublishCommand",
    ) -> "PatchedMessage":
        """Send ``cmd`` to the first matching subscriber and return its reply.

        The handler call is bounded by ``cmd.timeout`` via ``anyio.fail_after``.

        Raises:
            SubscriberNotFound: when no subscriber matches the message.
        """
        incoming = build_message(
            message=cmd.body,
            exchange=cmd.exchange,
            routing_key=cmd.destination,
            correlation_id=cmd.correlation_id,
            headers=cmd.headers,
            serializer=self.broker.config.fd_config._serializer,
            **cmd.message_options,
        )

        first = next(self._iter_matching_handlers(incoming, cmd.exchange), None)
        if first is None:
            raise SubscriberNotFound

        with anyio.fail_after(cmd.timeout):
            return await self._execute_handler(incoming, first)

    async def _execute_handler(
        self,
        msg: PatchedMessage,
        handler: "RabbitSubscriber",
    ) -> "PatchedMessage":
        """Run ``handler`` on ``msg`` and wrap its result in a new message."""
        outcome = await handler.process_message(msg)
        serializer = self.broker.config.fd_config._serializer
        return build_message(
            routing_key=msg.routing_key,
            message=outcome.body,
            headers=outcome.headers,
            correlation_id=outcome.correlation_id,
            serializer=serializer,
        )
def _is_handler_matches(
    handler: "RabbitSubscriber",
    routing_key: str,
    headers: Optional["Mapping[Any, Any]"] = None,
    exchange: Optional["RabbitExchange"] = None,
) -> bool:
    """Return whether ``handler`` would receive a message with this routing data.

    The handler's exchange must equal ``exchange`` (after validation). Matching
    then follows the exchange type:

    - no exchange / direct: exact routing-key equality;
    - fanout: always matches;
    - topic: ``apply_pattern`` on the handler's routing key;
    - headers: RabbitMQ ``x-match`` rules over the queue's bind arguments.

    Raises:
        AssertionError: for exchange types not handled above.
    """
    headers = headers or {}
    exchange = RabbitExchange.validate(exchange)

    if handler.exchange != exchange:
        return False

    if handler.exchange is None or handler.exchange.type == ExchangeType.DIRECT:
        return handler.routing() == routing_key

    if handler.exchange.type == ExchangeType.FANOUT:
        return True

    if handler.exchange.type == ExchangeType.TOPIC:
        return apply_pattern(handler.routing(), routing_key)

    if handler.exchange.type == ExchangeType.HEADERS:
        bind_arguments = dict(handler.queue.bind_arguments or {})
        match_rule = bind_arguments.pop("x-match", "all")

        # RabbitMQ: "all"/"any" ignore binding arguments starting with "x-";
        # the "*-with-x" variants include them.
        include_x_args = match_rule in {"all-with-x", "any-with-x"}
        results = [
            headers.get(key) == value
            for key, value in bind_arguments.items()
            if include_x_args or not str(key).startswith("x-")
        ]

        # With no conditions, "all" matches everything and "any" matches nothing,
        # mirroring the broker's behavior.
        if match_rule in {"any", "any-with-x"}:
            return any(results)
        return all(results)

    raise AssertionError
def apply_pattern(pattern: str, current: str) -> bool:
    """Apply a pattern to a routing key.

    Implements AMQP topic-exchange matching on dot-separated words: ``*`` matches
    exactly one word and ``#`` matches zero or more words. An empty string has
    zero words, as in RabbitMQ.

    Args:
        pattern: Binding key, possibly containing ``*`` / ``#`` wildcards.
        current: Concrete routing key of the message.

    Returns:
        ``True`` if ``current`` matches ``pattern``.
    """
    pattern_words = pattern.split(".") if pattern else []
    key_words = current.split(".") if current else []

    # reachable[j] is True when the pattern words processed so far can match
    # exactly the first j words of the routing key.
    reachable = [True] + [False] * len(key_words)

    for word in pattern_words:
        if word == "#":
            # "#" can absorb any number of further words: every prefix at or
            # after the first reachable one becomes reachable.
            absorbed = []
            seen = False
            for ok in reachable:
                seen = seen or ok
                absorbed.append(seen)
            reachable = absorbed
        else:
            reachable = [False] + [
                reachable[j] and (word == "*" or word == key_words[j])
                for j in range(len(key_words))
            ]

    return reachable[-1]