Coverage for tests / brokers / base / fastapi.py: 99%
415 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import asyncio
2from contextlib import asynccontextmanager
3from typing import Annotated, Any, TypeVar
4from unittest.mock import Mock
6import pytest
7from fastapi import BackgroundTasks, Depends, FastAPI, Header
8from fastapi.exceptions import RequestValidationError
9from fastapi.testclient import TestClient
11from faststream import (
12 Context as FSContext,
13 Depends as FSDepends,
14 Response,
15)
16from faststream._internal.broker import BrokerUsecase
17from faststream._internal.broker.router import BrokerRouter
18from faststream._internal.fastapi.context import Context
19from faststream._internal.fastapi.route import StreamMessage
20from faststream._internal.fastapi.router import StreamRouter
21from faststream.exceptions import SetupError
23from .basic import BaseTestcaseConfig
# Generic type variable for broker types, bound to ``BrokerUsecase``.
# Not referenced in the visible part of this module — presumably kept for
# importers/subclasses of these testcases; verify before removing.
Broker = TypeVar("Broker", bound=BrokerUsecase)
@pytest.mark.asyncio()
class FastAPITestcase(BaseTestcaseConfig):
    """Real-broker integration tests for a FastAPI ``StreamRouter``.

    Concrete subclasses set ``router_class`` and ``broker_router_class`` to the
    broker-specific implementations. Each test builds a fresh router, registers
    handlers using the ``(args, kwargs)`` pair from ``self.get_subscriber_params``,
    starts ``router.broker``, publishes, and then waits at most ``self.timeout``
    seconds for the handler to signal completion through an ``asyncio.Event``.

    The ``mock``, ``queue`` and ``event`` arguments are pytest fixtures defined
    outside this module.
    """

    router_class: type[StreamRouter[BrokerUsecase]]
    broker_router_class: type[BrokerRouter[Any]]

    async def test_base_real(self, mock: Mock, queue: str) -> None:
        event = asyncio.Event()

        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg):
            event.set()
            return mock(msg)

        async with router.broker:
            await router.broker.start()
            await asyncio.wait(
                (
                    asyncio.create_task(router.broker.publish("hi", queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        assert event.is_set()
        mock.assert_called_with("hi")

    async def test_background(
        self,
        mock: Mock,
        queue: str,
        event: asyncio.Event,
    ) -> None:
        router = self.router_class()

        def task(msg: Any) -> None:
            event.set()
            return mock(msg)

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg, tasks: BackgroundTasks) -> None:
            # The work is deferred to a background task. The event is set only
            # once that task has actually run, not when the handler returns.
            tasks.add_task(task, msg)

        async with router.broker:
            await router.broker.start()
            await asyncio.wait(
                (
                    asyncio.create_task(router.broker.publish("hi", queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        # Checked first so that a timeout is reported as such rather than
        # as a confusing mock-call mismatch.
        assert event.is_set()
        mock.assert_called_with("hi")

    async def test_context(self, mock: Mock, queue: str, event: asyncio.Event) -> None:
        router = self.router_class()
        context = router.context

        context_key = "message.headers"

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg: Any = Context(context_key)) -> None:
            try:
                # The injected value must equal what the router context resolves
                # for the same key, and must carry the published header.
                mock(msg == context.resolve(context_key) and msg["1"] == "1")
            finally:
                event.set()

        async with router.broker:
            await router.broker.start()
            await asyncio.wait(
                (
                    asyncio.create_task(
                        router.broker.publish("", queue, headers={"1": "1"}),
                    ),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        assert event.is_set()
        mock.assert_called_with(True)

    async def test_context_annotated(
        self,
        mock: Mock,
        queue: str,
        event: asyncio.Event,
    ) -> None:
        router = self.router_class()
        context = router.context

        context_key = "message.headers"

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg: Annotated[Any, Context(context_key)]) -> None:
            try:
                mock(msg == context.resolve(context_key) and msg["1"] == "1")
            finally:
                event.set()

        async with router.broker:
            await router.broker.start()
            await asyncio.wait(
                (
                    asyncio.create_task(
                        router.broker.publish("", queue, headers={"1": "1"}),
                    ),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        assert event.is_set()
        mock.assert_called_with(True)

    @pytest.mark.flaky(reruns=3, reruns_delay=1)
    async def test_initial_context(self, queue: str, event: asyncio.Event) -> None:
        router = self.router_class()
        context = router.context

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg: int, data: set[int] = Context(queue, initial=set)) -> None:
            # ``data`` is created from ``initial`` on first use and then shared
            # by later calls, so two deliveries accumulate into one set.
            data.add(msg)
            if len(data) == 2:
                event.set()

        try:
            async with router.broker:
                await router.broker.start()
                await asyncio.wait(
                    (
                        asyncio.create_task(router.broker.publish(1, queue)),
                        asyncio.create_task(router.broker.publish(2, queue)),
                        asyncio.create_task(event.wait()),
                    ),
                    timeout=self.timeout,
                )

            assert context.get(queue) == {1, 2}
        finally:
            # Always drop the global entry. Previously it was only removed after
            # a passing assertion, so failed (and rerun) attempts left it behind.
            context.reset_global(queue)

    async def test_double_real(self, mock: Mock, queue: str) -> None:
        event = asyncio.Event()
        event2 = asyncio.Event()

        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)
        sub1 = router.subscriber(*args, **kwargs)

        args2, kwargs2 = self.get_subscriber_params(queue + "2")

        # One handler is registered on two queues. The first delivery sets
        # ``event`` and the second sets ``event2``.
        @sub1
        @router.subscriber(*args2, **kwargs2)
        async def hello(msg: str) -> None:
            if event.is_set():
                event2.set()
            else:
                event.set()
            mock()

        async with router.broker:
            await router.broker.start()
            await asyncio.wait(
                (
                    asyncio.create_task(router.broker.publish("hi", queue)),
                    asyncio.create_task(router.broker.publish("hi", queue + "2")),
                    asyncio.create_task(event.wait()),
                    asyncio.create_task(event2.wait()),
                ),
                timeout=self.timeout,
            )

        assert event.is_set()
        assert event2.is_set()
        assert mock.call_count == 2

    async def test_base_publisher_real(
        self,
        mock: Mock,
        queue: str,
    ) -> None:
        event = asyncio.Event()

        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        # The return value of ``m`` is republished to ``queue + "resp"``.
        @router.subscriber(*args, **kwargs)
        @router.publisher(queue + "resp")
        async def m() -> str:
            return "hi"

        args2, kwargs2 = self.get_subscriber_params(queue + "resp")

        @router.subscriber(*args2, **kwargs2)
        async def resp(msg) -> None:
            event.set()
            mock(msg)

        async with router.broker:
            await router.broker.start()

            await asyncio.wait(
                (
                    asyncio.create_task(router.broker.publish("", queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        assert event.is_set()
        mock.assert_called_once_with("hi")

    async def test_injection_fastapi(
        self,
        mock: Mock,
        queue: str,
        event: asyncio.Event,
    ) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def subscriber(msg: StreamMessage) -> None:
            # The FastAPI-style request object exposes an ASGI-like ``scope``
            # that should include the ``app`` key.
            mock("app" in msg.scope)
            event.set()

        async with router.broker:
            await router.broker.start()
            await asyncio.wait(
                (
                    asyncio.create_task(router.broker.publish(None, queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=self.timeout,
            )

        mock.assert_called_once_with(True)
@pytest.mark.asyncio()
class FastAPILocalTestcase(BaseTestcaseConfig):
    """FastAPI ``StreamRouter`` tests that run against a patched broker.

    Brokers are wrapped with ``self.patch_broker`` (presumably the broker's
    in-memory TestClient, so no real broker connection is needed). These tests
    cover dependency injection, context markers, lifespan hooks, router
    inclusion and the mocks that are attached to subscribers and publishers.
    """

    router_class: type[StreamRouter[BrokerUsecase]]
    # ``test_include`` instantiates this, so subclasses must provide it as well.
    broker_router_class: type[BrokerRouter[Any]]

    async def test_base(self, queue: str) -> None:
        router = self.router_class()

        app = FastAPI()
        app.include_router(router)

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello() -> str:
            return "hi"

        async with self.patch_broker(router.broker) as br:
            with TestClient(app) as client:
                # With the default ``setup_state`` the broker is exposed on the
                # application state.
                assert client.app_state["broker"] is br

                r = await br.request(
                    "hi",
                    queue,
                    timeout=0.5,
                )
                assert await r.decode() == "hi", r

    async def test_request(self, queue: str) -> None:
        """Check request/reply with a ``Response`` carrying custom headers.

        This test lives in the local testcase because ``request`` is supported
        by every broker's TestClient.
        """
        router = self.router_class(setup_state=False)

        app = FastAPI()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello():
            return Response("Hi!", headers={"x-header": "test"})

        async with self.patch_broker(router.broker) as br:
            with TestClient(app) as client:
                assert not client.app_state.get("broker")

                r = await br.request(
                    "hi",
                    queue,
                    timeout=0.5,
                )
                assert await r.decode() == "Hi!"
                assert r.headers["x-header"] == "test"

    async def test_base_without_state(self, queue: str) -> None:
        router = self.router_class(setup_state=False)

        app = FastAPI()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello() -> str:
            return "hi"

        async with self.patch_broker(router.broker) as br:
            with TestClient(app) as client:
                # ``setup_state=False`` keeps the broker off the app state.
                assert not client.app_state.get("broker")

                r = await br.request(
                    "hi",
                    queue,
                    timeout=0.5,
                )
                assert await r.decode() == "hi", r

    async def test_invalid(self, queue: str) -> None:
        router = self.router_class()

        app = FastAPI()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg: int) -> None: ...

        app.include_router(router)

        async with self.patch_broker(router.broker) as br:
            with TestClient(app):
                # The payload "hi" cannot be validated as ``int``, and FastAPI's
                # validation error surfaces to the publisher.
                with pytest.raises(RequestValidationError):
                    await br.publish("hi", queue)

    async def test_headers(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(w=Header()):
            return w

        async with self.patch_broker(router.broker) as br:
            r = await br.request(
                "",
                queue,
                headers={"w": "hi"},
                timeout=0.5,
            )
            assert await r.decode() == "hi", r

    async def test_depends(self, mock: Mock, queue: str) -> None:
        router = self.router_class()

        def dep(a):
            mock(a)
            return a

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(a, w=Depends(dep)):
            return w

        async with self.patch_broker(router.broker) as br:
            r = await br.request(
                {"a": "hi"},
                queue,
                timeout=0.5,
            )
            assert await r.decode() == "hi", r

        mock.assert_called_once_with("hi")

    async def test_mixed_depends(self, mock: Mock, queue: str) -> None:
        router = self.router_class()

        def dep(a: str) -> str:
            mock(a)
            return a

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(
            a: str,
            w: Annotated[
                str,
                Depends(dep),
                FSDepends(dep),  # will be ignored
            ],
        ) -> str:
            return w

        async with self.patch_broker(router.broker) as br:
            r = await br.request(
                {"a": "hi"},
                queue,
                timeout=0.5,
            )
            assert await r.decode() == "hi", r

        # ``dep`` appears twice but is called once: only the FastAPI marker is used.
        mock.assert_called_once_with("hi")

    async def test_depends_from_fastdepends_default(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        subscriber = router.subscriber(*args, **kwargs)

        @subscriber
        def sub(d: Any = FSDepends(lambda: 1)) -> None: ...

        app = FastAPI()
        app.include_router(router)

        # A FastDepends-only marker is rejected when the app starts up.
        with pytest.raises(SetupError):  # noqa: PT012
            async with self.patch_broker(router.broker):
                with TestClient(app):
                    ...

    async def test_depends_from_fastdepends_annotated(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        subscriber = router.subscriber(*args, **kwargs)

        @subscriber
        def sub(d: Annotated[Any, FSDepends(lambda: 1)]) -> None: ...

        app = FastAPI()
        app.include_router(router)

        with pytest.raises(SetupError):  # noqa: PT012
            async with self.patch_broker(router.broker):
                with TestClient(app):
                    ...

    async def test_depends_combined_annotated(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        subscriber = router.subscriber(*args, **kwargs)

        @subscriber
        def sub(
            d: Annotated[Any, FSDepends(lambda: 1), Depends(lambda: 1)],
        ) -> None: ...

        app = FastAPI()
        app.include_router(router)

        # Unlike the FastDepends-only cases above, startup must succeed when a
        # FastAPI ``Depends`` is present as well.
        async with self.patch_broker(router.broker):
            with TestClient(app):
                ...

    async def test_faststream_context(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg: Any = FSContext()) -> None: ...

        app = FastAPI()
        app.include_router(router)

        # The plain FastStream ``Context`` marker is rejected here; the
        # FastAPI-integration ``Context`` must be used instead.
        with pytest.raises(SetupError):  # noqa: PT012
            async with self.patch_broker(router.broker):
                with TestClient(app):
                    ...

    async def test_faststream_context_annotated(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(msg: Annotated[Any, FSContext()]) -> None: ...

        app = FastAPI()
        app.include_router(router)

        with pytest.raises(SetupError):  # noqa: PT012
            async with self.patch_broker(router.broker):
                with TestClient(app):
                    ...

    async def test_combined_context_annotated(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(
            msg: Annotated[
                Any,
                Context("message.headers"),
                FSContext("message.headers"),
            ],
        ) -> None: ...

        app = FastAPI()
        app.include_router(router)

        # Startup must succeed when both markers are present.
        async with self.patch_broker(router.broker):
            with TestClient(app):
                ...

    async def test_nested_combined_context_annotated(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(
            msg: Annotated[
                Annotated[Any, FSContext("message.headers")],
                Context("message.headers"),
            ],
        ) -> None: ...

        app = FastAPI()
        app.include_router(router)

        # Same as above, but the FastStream marker sits in a nested ``Annotated``.
        async with self.patch_broker(router.broker):
            with TestClient(app):
                ...

    async def test_yield_depends(self, mock: Mock, queue: str) -> None:
        router = self.router_class()

        def dep(a):
            mock.start()
            yield a
            mock.close()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(a, w=Depends(dep)):
            # Inside the handler the generator dependency has started but its
            # teardown (after ``yield``) has not run yet.
            mock.start.assert_called_once()
            assert not mock.close.call_count
            return w

        async with self.patch_broker(router.broker) as br:
            r = await br.request(
                {"a": "hi"},
                queue,
                timeout=0.5,
            )
            assert await r.decode() == "hi", r

        mock.start.assert_called_once()
        mock.close.assert_called_once()

    async def test_router_depends(self, mock: Mock, queue: str) -> None:
        def mock_dep() -> None:
            mock()

        router = self.router_class(dependencies=(Depends(mock_dep, use_cache=False),))

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello(a):
            return a

        async with self.patch_broker(router.broker) as br:
            r = await br.request("hi", queue, timeout=0.5)
            assert await r.decode() == "hi", r

        mock.assert_called_once()

    async def test_subscriber_depends(self, mock: Mock, queue: str) -> None:
        def mock_dep() -> None:
            mock()

        router = self.router_class()

        args, kwargs = self.get_subscriber_params(
            queue,
            dependencies=(Depends(mock_dep, use_cache=False),),
        )

        @router.subscriber(*args, **kwargs)
        async def hello(a):
            return a

        async with self.patch_broker(router.broker) as br:
            r = await br.request(
                "hi",
                queue,
                timeout=0.5,
            )
            assert await r.decode() == "hi", r

        mock.assert_called_once()

    async def test_hooks(self, mock: Mock) -> None:
        router = self.router_class()

        app = FastAPI()
        app.include_router(router)

        # Both sync and async callables are accepted for each hook type.
        @router.after_startup
        def test_sync(app) -> None:
            mock.sync_called()

        @router.after_startup
        async def test_async(app) -> None:
            mock.async_called()

        @router.on_broker_shutdown
        def test_shutdown_sync(app) -> None:
            mock.sync_shutdown_called()

        @router.on_broker_shutdown
        async def test_shutdown_async(app) -> None:
            mock.async_shutdown_called()

        async with self.patch_broker(router.broker), router.lifespan_context(app):
            pass

        mock.sync_called.assert_called_once()
        mock.async_called.assert_called_once()
        mock.sync_shutdown_called.assert_called_once()
        mock.async_shutdown_called.assert_called_once()

    async def test_existed_lifespan_startup(self, mock: Mock) -> None:
        @asynccontextmanager
        async def lifespan(app):
            mock.start()
            yield {"lifespan": True}
            mock.close()

        router = self.router_class(lifespan=lifespan)

        app = FastAPI()
        app.include_router(router)

        async with (
            self.patch_broker(router.broker),
            router.lifespan_context(
                app,
            ) as context,
        ):
            # The state yielded by the user lifespan is exposed via the router's.
            assert context["lifespan"]

        mock.start.assert_called_once()
        mock.close.assert_called_once()

    async def test_subscriber_mock(self, queue: str) -> None:
        router = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def m() -> str:
            return "hi"

        async with self.patch_broker(router.broker) as rb:
            await rb.publish("hello", queue)
            # Checked inside the patch context, where the subscriber mock exists.
            m.mock.assert_called_once_with("hello")

    async def test_publisher_mock(self, queue: str) -> None:
        router = self.router_class()

        publisher = router.publisher(queue + "resp")

        args, kwargs = self.get_subscriber_params(queue)
        sub = router.subscriber(*args, **kwargs)

        @publisher
        @sub
        async def m() -> str:
            return "response"

        async with self.patch_broker(router.broker) as rb:
            await rb.publish("hello", queue)
            publisher.mock.assert_called_with("response")

    async def test_include(self, queue: str) -> None:
        router = self.router_class()
        router2 = self.broker_router_class()

        app = FastAPI()

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello() -> str:
            return "hi"

        args2, kwargs2 = self.get_subscriber_params(queue + "1")

        @router2.subscriber(*args2, **kwargs2)
        async def hello_router2() -> str:
            return "hi"

        # A plain broker router is included into the StreamRouter, which is the
        # supported composition (contrast with the nested StreamRouter tests).
        router.include_router(router2)
        app.include_router(router)

        async with self.patch_broker(router.broker) as br:
            with TestClient(app) as client:
                assert client.app_state["broker"] is br

                r = await br.request(
                    "hi",
                    queue,
                    timeout=0.5,
                )
                assert await r.decode() == "hi", r

                r = await br.request(
                    "hi",
                    queue + "1",
                    timeout=0.5,
                )
                assert await r.decode() == "hi", r

    async def test_dependency_overrides(self, mock: Mock, queue: str) -> None:
        router = self.router_class()

        def dep1() -> None:
            raise AssertionError

        def dep2() -> None:
            mock()

        app = FastAPI()
        # ``dep1`` would fail the test if called, so a passing run proves the
        # app-level override is applied to stream handlers.
        app.dependency_overrides[dep1] = dep2

        args, kwargs = self.get_subscriber_params(queue)

        @router.subscriber(*args, **kwargs)
        async def hello_router2(dep: None = Depends(dep1)) -> str:
            return "hi"

        app.include_router(router)

        async with self.patch_broker(router.broker) as br:
            with TestClient(app) as client:
                assert client.app_state["broker"] is br

                r = await br.request(
                    "hi",
                    queue,
                    timeout=0.5,
                )
                assert await r.decode() == "hi", r

        mock.assert_called_once()

    async def test_nested_router(self, queue: str) -> None:
        router = self.router_class()
        router2 = self.router_class()

        args, kwargs = self.get_subscriber_params(queue)

        @router2.subscriber(*args, **kwargs)
        async def hello_router2() -> str:
            return "hi"

        with pytest.raises(
            TypeError,
            match="Including a StreamRouter into another StreamRouter",
        ):
            router.include_router(router2)

    async def test_nested_stream_router_raises(
        self,
        queue: str,
    ) -> None:
        """Including a StreamRouter into another StreamRouter must raise TypeError.

        This pattern is unsupported (issue #2657). Users should include a regular
        broker router (e.g. KafkaRouter) into the StreamRouter instead.
        """
        # Declared ``async`` like every other test in this asyncio-marked
        # class; pytest-asyncio warns about asyncio-marked sync tests.
        router = self.router_class()
        router2 = self.router_class()

        with pytest.raises(
            TypeError,
            match="Including a StreamRouter into another StreamRouter is not supported",
        ):
            router.include_router(router2)