Coverage for tests / brokers / base / middlewares.py: 99%
468 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import asyncio
2from unittest.mock import MagicMock, call
4import pytest
6from faststream import Context
7from faststream._internal.basic_types import DecodedMessage
8from faststream.exceptions import SkipMessage
9from faststream.middlewares import BaseMiddleware, ExceptionMiddleware
11from .basic import BaseTestcaseConfig
@pytest.mark.asyncio()
class MiddlewaresOrderTestcase(BaseTestcaseConfig):
    """Verify the nesting order of broker-, router- and publisher-level middlewares.

    A middleware listed earlier in ``middlewares=[...]`` (or attached to an outer
    broker/router) wraps the ones listed later: its ``*_scope`` hooks and
    ``__aenter__`` run first, and its ``__aexit__`` runs last.
    """

    async def test_broker_middleware_order(self, queue: str, mock: MagicMock) -> None:
        class InnerMiddleware(BaseMiddleware):
            async def __aenter__(self) -> None:
                mock.enter_inner()
                mock.enter("inner")

            async def __aexit__(self, *args) -> None:
                mock.exit_inner()
                mock.exit("inner")

            # Scope hooks return whatever ``call_next`` returns, so they carry
            # no ``-> None`` annotation.
            async def consume_scope(self, call_next, msg):
                mock.consume_inner()
                mock.sub("inner")
                return await call_next(msg)

            async def publish_scope(self, call_next, cmd):
                mock.publish_inner()
                mock.pub("inner")
                return await call_next(cmd)

        class OuterMiddleware(BaseMiddleware):
            async def __aenter__(self) -> None:
                mock.enter_outer()
                mock.enter("outer")

            async def __aexit__(self, *args) -> None:
                mock.exit_outer()
                mock.exit("outer")

            async def consume_scope(self, call_next, msg):
                mock.consume_outer()
                mock.sub("outer")
                return await call_next(msg)

            async def publish_scope(self, call_next, cmd):
                mock.publish_outer()
                mock.pub("outer")
                return await call_next(cmd)

        broker = self.get_broker(middlewares=[OuterMiddleware, InnerMiddleware])

        args, kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*args, **kwargs)
        async def handler(msg):
            pass

        async with self.patch_broker(broker) as br:
            await br.publish(None, queue)

        mock.consume_inner.assert_called_once()
        mock.consume_outer.assert_called_once()
        mock.publish_inner.assert_called_once()
        mock.publish_outer.assert_called_once()
        mock.enter_inner.assert_called_once()
        mock.enter_outer.assert_called_once()
        mock.exit_inner.assert_called_once()
        mock.exit_outer.assert_called_once()

        # The first-listed middleware is outermost: entered first, exited last.
        assert [c.args[0] for c in mock.sub.call_args_list] == ["outer", "inner"]
        assert [c.args[0] for c in mock.pub.call_args_list] == ["outer", "inner"]
        assert [c.args[0] for c in mock.enter.call_args_list] == ["outer", "inner"]
        assert [c.args[0] for c in mock.exit.call_args_list] == ["inner", "outer"]

    async def test_publisher_middleware_order(
        self,
        queue: str,
        mock: MagicMock,
    ) -> None:
        class InnerMiddleware(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                mock.publish_inner()
                mock("inner")
                return await call_next(cmd)

        class MiddleMiddleware(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                mock.publish_middle()
                mock("middle")
                return await call_next(cmd)

        class OuterMiddleware(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                mock.publish_outer()
                mock("outer")
                return await call_next(cmd)

        broker = self.get_broker(
            middlewares=[OuterMiddleware, MiddleMiddleware, InnerMiddleware],
        )
        publisher = broker.publisher(queue)

        args, kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*args, **kwargs)
        async def handler(msg):
            pass

        async with self.patch_broker(broker):
            await publisher.publish(None, queue)

        mock.publish_inner.assert_called_once()
        mock.publish_middle.assert_called_once()
        mock.publish_outer.assert_called_once()

        assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"]

    async def test_publisher_with_router_middleware_order(
        self,
        queue: str,
        mock: MagicMock,
    ) -> None:
        class InnerMiddleware(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                mock.publish_inner()
                mock("inner")
                return await call_next(cmd)

        class MiddleMiddleware(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                mock.publish_middle()
                mock("middle")
                return await call_next(cmd)

        class OuterMiddleware(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                mock.publish_outer()
                mock("outer")
                return await call_next(cmd)

        # Nesting: broker (outer) -> router (middle) -> router2 (inner).
        broker = self.get_broker(middlewares=[OuterMiddleware])
        router = self.get_router(middlewares=[MiddleMiddleware])
        router2 = self.get_router(middlewares=[InnerMiddleware])

        publisher = router2.publisher(queue)

        args, kwargs = self.get_subscriber_params(queue)

        @router2.subscriber(*args, **kwargs)
        async def handler(msg):
            pass

        router.include_router(router2)
        broker.include_router(router)

        async with self.patch_broker(broker):
            await publisher.publish(None, queue)

        mock.publish_inner.assert_called_once()
        mock.publish_middle.assert_called_once()
        mock.publish_outer.assert_called_once()

        assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"]

    async def test_consume_middleware_order(self, queue: str, mock: MagicMock) -> None:
        class InnerMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                mock.consume_inner()
                mock("inner")
                return await call_next(msg)

        class MiddleMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                mock.consume_middle()
                mock("middle")
                return await call_next(msg)

        class OuterMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                mock.consume_outer()
                mock("outer")
                return await call_next(msg)

        broker = self.get_broker(
            middlewares=[OuterMiddleware, MiddleMiddleware, InnerMiddleware],
        )

        args, kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*args, **kwargs)
        async def handler(msg):
            pass

        async with self.patch_broker(broker) as br:
            await br.publish(None, queue)

        mock.consume_inner.assert_called_once()
        mock.consume_middle.assert_called_once()
        mock.consume_outer.assert_called_once()

        assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"]

    async def test_consume_with_router_middleware_order(
        self,
        queue: str,
        mock: MagicMock,
    ) -> None:
        class InnerMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                mock("inner")
                return await call_next(msg)

        class MiddleMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                mock("middle")
                return await call_next(msg)

        class OuterMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                mock("outer")
                return await call_next(msg)

        # Nesting: broker (outer) -> router (middle) -> router2 (inner).
        broker = self.get_broker(middlewares=[OuterMiddleware])
        router = self.get_router(middlewares=[MiddleMiddleware])
        router2 = self.get_router(middlewares=[InnerMiddleware])

        args, kwargs = self.get_subscriber_params(queue)

        @router2.subscriber(*args, **kwargs)
        async def handler(msg):
            pass

        router.include_router(router2)
        broker.include_router(router)
        async with self.patch_broker(broker) as br:
            await br.publish(None, queue)

        call_order = [c.args[0] for c in mock.call_args_list]
        assert call_order == ["outer", "middle", "inner"], call_order
@pytest.mark.asyncio()
class LocalMiddlewareTestcase(BaseTestcaseConfig):
    """Subscriber-level middleware checks; also inherited by ``MiddlewareTestcase``."""

    async def test_subscriber_middleware(
        self,
        queue: str,
        mock: MagicMock,
    ) -> None:
        """A consume-scope middleware sees the decoded body and wraps the handler."""
        # Set by the middleware once the wrapped handler call has returned.
        processed = asyncio.Event()

        class TapMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                body = await msg.decode()
                mock.start(body)
                handler_result = await call_next(msg)
                mock.end()
                processed.set()
                return handler_result

        broker = self.get_broker(middlewares=(TapMiddleware,))
        sub_args, sub_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*sub_args, **sub_kwargs)
        async def handler(m) -> str:
            mock.inner(m)
            return "end"

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("start", queue)),
                asyncio.create_task(processed.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        mock.start.assert_called_once_with("start")
        mock.inner.assert_called_once_with("start")

        assert processed.is_set()
        mock.end.assert_called_once()

    async def test_error_traceback(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """A consume-scope middleware observes the handler's exception and re-raises it."""

        class ErrorTraceMiddleware(BaseMiddleware):
            async def consume_scope(self, call_next, msg):
                try:
                    return await call_next(msg)
                except Exception as err:
                    # Record whether the exception reaching the middleware is
                    # a ValueError (what the handler raises), then propagate it.
                    mock(isinstance(err, ValueError))
                    raise

        broker = self.get_broker(middlewares=(ErrorTraceMiddleware,))
        sub_args, sub_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*sub_args, **sub_kwargs)
        async def handler2(m):
            event.set()
            raise ValueError

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        assert event.is_set()
        mock.assert_called_once_with(True)
@pytest.mark.asyncio()
class MiddlewareTestcase(LocalMiddlewareTestcase):
    """Broker-wide (global) middleware tests, plus the inherited local ones."""

    async def test_global_middleware(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """``on_receive``/``after_processed`` hooks run once per consumed message."""

        class Mid(BaseMiddleware):
            async def on_receive(self):
                mock.start(self.msg)
                return await super().on_receive()

            async def after_processed(self, exc_type, exc_val, exc_tb):
                mock.end()
                return await super().after_processed(exc_type, exc_val, exc_tb)

        broker = self.get_broker(
            middlewares=(Mid,),
        )

        args, kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*args, **kwargs)
        async def handler(m) -> str:
            event.set()
            return ""

        async with self.patch_broker(broker) as br:
            await br.start()
            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("", queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )
            assert event.is_set()

        mock.start.assert_called_once()
        mock.end.assert_called_once()

    async def test_add_global_middleware(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """``add_middleware`` applies to subscribers registered before and after it."""

        class Mid(BaseMiddleware):
            async def on_receive(self):
                mock.start(self.msg)
                return await super().on_receive()

            async def after_processed(self, exc_type, exc_val, exc_tb):
                mock.end()
                return await super().after_processed(exc_type, exc_val, exc_tb)

        broker = self.get_broker()

        # Subscriber registered before the middleware is added.
        args, kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*args, **kwargs)
        async def handler(m) -> str:
            event.set()
            return ""

        # Should affect both the already-registered and the new subscriber.
        broker.add_middleware(Mid)

        event2 = asyncio.Event()

        # Subscriber registered after the middleware is added. Use one name for
        # its queue so the subscription and the publish cannot drift apart.
        second_queue = f"{queue}1"
        args2, kwargs2 = self.get_subscriber_params(second_queue)

        @broker.subscriber(*args2, **kwargs2)
        async def handler2(m) -> str:
            event2.set()
            return ""

        async with self.patch_broker(broker) as br:
            await br.start()
            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("", queue)),
                    asyncio.create_task(br.publish("", second_queue)),
                    asyncio.create_task(event.wait()),
                    asyncio.create_task(event2.wait()),
                ),
                timeout=self.timeout,
            )

        # Check both handlers explicitly so a failure names the leg that broke.
        assert event.is_set()
        assert event2.is_set()
        assert mock.start.call_count == 2, mock.start.call_count
        assert mock.end.call_count == 2, mock.end.call_count

    async def test_patch_publish(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """A publish-scope middleware can mutate the outgoing command body."""

        class Mid(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                cmd.body *= 2
                return await call_next(cmd)

        broker = self.get_broker(middlewares=(Mid,))

        args, kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*args, **kwargs)
        async def handler(m):
            return m

        args2, kwargs2 = self.get_subscriber_params(queue + "r")

        @broker.subscriber(*args2, **kwargs2)
        async def handler_resp(m) -> None:
            mock(m)
            event.set()

        async with self.patch_broker(broker) as br:
            await br.start()

            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("r", queue, reply_to=queue + "r")),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        assert event.is_set()
        # Doubled once on the initial publish ("r" -> "rr") and again when the
        # handler's return value is published as the reply ("rr" -> "rrrr").
        mock.assert_called_once_with("rrrr")

    async def test_global_publisher_middleware(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """A broker-level publish middleware wraps every publish, including publisher outputs."""

        class Mid(BaseMiddleware):
            async def publish_scope(self, call_next, cmd):
                cmd.body *= 2
                mock.enter(cmd.body)
                try:
                    return await call_next(cmd)
                finally:
                    mock.end()
                    # Three publishes are expected: the initial one plus one
                    # per decorated publisher below.
                    if mock.end.call_count > 2:
                        event.set()

        broker = self.get_broker(middlewares=(Mid,))

        args, kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*args, **kwargs)
        @broker.publisher(queue + "1")
        @broker.publisher(queue + "2")
        async def handler(m):
            mock.inner(m)
            return m

        async with self.patch_broker(broker) as br:
            await br.start()
            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("1", queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        assert event.is_set()
        mock.inner.assert_called_once_with("11")
        assert mock.enter.call_count == 3
        mock.enter.assert_called_with("1111")
        assert mock.end.call_count == 3
@pytest.mark.asyncio()
class ExceptionMiddlewareTestcase(BaseTestcaseConfig):
    """Registration and dispatch behaviour of ``ExceptionMiddleware`` handlers."""

    async def test_exception_middleware_default_msg(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """A ``publish=True`` handler's return value is sent via the subscriber's publisher."""
        exc_mid = ExceptionMiddleware()

        @exc_mid.add_handler(ValueError, publish=True)
        async def value_error_handler(exc) -> str:
            return "value"

        broker = self.get_broker(apply_types=True, middlewares=(exc_mid,))
        reply_queue = queue + "1"

        src_args, src_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*src_args, **src_kwargs)
        @broker.publisher(reply_queue)
        async def subscriber1(m):
            raise ValueError

        dst_args, dst_kwargs = self.get_subscriber_params(reply_queue)

        @broker.subscriber(*dst_args, **dst_kwargs)
        async def subscriber2(msg=Context("message")) -> None:
            mock(await msg.decode())
            event.set()

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        assert event.is_set()
        assert mock.call_count == 1
        mock.assert_called_once_with("value")

    async def test_exception_middleware_skip_msg(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """A handler raising ``SkipMessage`` prevents anything reaching the publisher."""
        exc_mid = ExceptionMiddleware()

        @exc_mid.add_handler(ValueError, publish=True)
        async def value_error_handler(exc):
            event.set()
            raise SkipMessage

        broker = self.get_broker(middlewares=(exc_mid,))
        reply_queue = queue + "1"
        src_args, src_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*src_args, **src_kwargs)
        @broker.publisher(reply_queue)
        async def subscriber1(m):
            raise ValueError

        dst_args, dst_kwargs = self.get_subscriber_params(reply_queue)

        @broker.subscriber(*dst_args, **dst_kwargs)
        async def subscriber2(msg=Context("message")) -> None:
            mock(await msg.decode())

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        assert event.is_set()
        assert mock.call_count == 0

    async def test_exception_middleware_do_not_catch_skip_msg(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """``SkipMessage`` raised by a subscriber is not routed to an ``Exception`` handler."""
        exc_mid = ExceptionMiddleware()

        @exc_mid.add_handler(Exception)
        async def value_error_handler(exc) -> None:
            mock()

        broker = self.get_broker(middlewares=(exc_mid,))
        sub_args, sub_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*sub_args, **sub_kwargs)
        async def subscriber(m):
            event.set()
            raise SkipMessage

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)
            # Yield once more so a (wrongly) invoked handler would have run.
            await asyncio.sleep(0.001)

        assert event.is_set()
        assert mock.call_count == 0

    async def test_exception_middleware_reraise(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """A handler that re-raises means nothing is published downstream."""
        exc_mid = ExceptionMiddleware()

        @exc_mid.add_handler(ValueError, publish=True)
        async def value_error_handler(exc):
            event.set()
            raise exc

        broker = self.get_broker(middlewares=(exc_mid,))
        reply_queue = queue + "1"
        src_args, src_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*src_args, **src_kwargs)
        @broker.publisher(reply_queue)
        async def subscriber1(m):
            raise ValueError

        dst_args, dst_kwargs = self.get_subscriber_params(reply_queue)

        @broker.subscriber(*dst_args, **dst_kwargs)
        async def subscriber2(msg=Context("message")) -> None:
            mock(await msg.decode())

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        assert event.is_set()
        assert mock.call_count == 0

    async def test_exception_middleware_different_handler(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """Each exception type is dispatched to its own registered handler."""
        exc_mid = ExceptionMiddleware()

        @exc_mid.add_handler(ZeroDivisionError, publish=True)
        async def zero_error_handler(exc) -> str:
            return "zero"

        @exc_mid.add_handler(ValueError, publish=True)
        async def value_error_handler(exc) -> str:
            return "value"

        broker = self.get_broker(apply_types=True, middlewares=(exc_mid,))
        first_args, first_kwargs = self.get_subscriber_params(queue)

        # Both failing subscribers reply into the same collecting queue.
        shared_publisher = broker.publisher(queue + "2")

        @broker.subscriber(*first_args, **first_kwargs)
        @shared_publisher
        async def subscriber1(m):
            raise ZeroDivisionError

        second_args, second_kwargs = self.get_subscriber_params(queue + "1")

        @broker.subscriber(*second_args, **second_kwargs)
        @shared_publisher
        async def subscriber2(m):
            raise ValueError

        sink_args, sink_kwargs = self.get_subscriber_params(queue + "2")

        @broker.subscriber(*sink_args, **sink_kwargs)
        async def subscriber3(msg=Context("message")) -> None:
            mock(await msg.decode())
            if mock.call_count > 1:
                event.set()

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(patched.publish("", queue + "1")),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        assert event.is_set()
        assert mock.call_count == 2
        mock.assert_has_calls([call("zero"), call("value")], any_order=True)

    async def test_exception_middleware_init_handler_same(self) -> None:
        """``handlers=`` in the constructor registers the same keys as ``add_handler``."""
        via_decorator = ExceptionMiddleware()

        @via_decorator.add_handler(ValueError)
        async def value_error_handler(exc) -> str:
            return "value"

        via_init = ExceptionMiddleware(handlers={ValueError: value_error_handler})

        assert list(via_decorator._handlers.keys()) == list(via_init._handlers.keys())

    async def test_exception_middleware_init_publish_handler_same(self) -> None:
        """``publish_handlers=`` in the constructor matches ``add_handler(..., publish=True)``."""
        via_decorator = ExceptionMiddleware()

        @via_decorator.add_handler(ValueError, publish=True)
        async def value_error_handler(exc) -> str:
            return "value"

        via_init = ExceptionMiddleware(publish_handlers={ValueError: value_error_handler})

        assert list(via_decorator._publish_handlers.keys()) == list(
            via_init._publish_handlers.keys()
        )

    async def test_exception_middleware_decoder_error(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """Errors raised by a custom decoder are also routed to the exception handlers."""

        async def decoder(
            msg,
            original_decoder,
        ) -> DecodedMessage:
            raise ValueError

        exc_mid = ExceptionMiddleware()

        @exc_mid.add_handler(ValueError)
        async def value_error_handler(exc) -> None:
            event.set()

        broker = self.get_broker(middlewares=(exc_mid,), decoder=decoder)

        sub_args, sub_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*sub_args, **sub_kwargs)
        async def subscriber1(m):
            raise ZeroDivisionError

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        assert event.is_set()

    async def test_exception_middleware_mro_resolution(
        self,
        queue: str,
        mock: MagicMock,
        event: asyncio.Event,
    ) -> None:
        """Test MRO-based resolution picks the most specific handler."""

        class ExcAError(Exception):
            pass

        class ExcBError(ExcAError):
            pass

        exc_mid = ExceptionMiddleware()

        # The parent handler is registered BEFORE the child handler, so
        # registration order alone would pick the wrong one.
        @exc_mid.add_handler(ExcAError, publish=True)
        async def handle_a(exc) -> str:
            return "parent"

        @exc_mid.add_handler(ExcBError, publish=True)
        async def handle_b(exc) -> str:
            return "child"

        broker = self.get_broker(apply_types=True, middlewares=(exc_mid,))
        reply_queue = queue + "1"
        src_args, src_kwargs = self.get_subscriber_params(queue)

        @broker.subscriber(*src_args, **src_kwargs)
        @broker.publisher(reply_queue)
        async def subscriber1(m):
            raise ExcBError

        dst_args, dst_kwargs = self.get_subscriber_params(reply_queue)

        @broker.subscriber(*dst_args, **dst_kwargs)
        async def subscriber2(msg=Context("message")) -> None:
            mock(await msg.decode())
            event.set()

        async with self.patch_broker(broker) as patched:
            await patched.start()
            waiters = (
                asyncio.create_task(patched.publish("", queue)),
                asyncio.create_task(event.wait()),
            )
            await asyncio.wait(waiters, timeout=self.timeout)

        assert event.is_set()
        mock.assert_called_once_with("child")