Coverage for faststream / _internal / endpoint / call_wrapper.py: 96%
58 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import asyncio
2from collections.abc import Awaitable, Callable, Reversible, Sequence
3from typing import (
4 TYPE_CHECKING,
5 Any,
6 Generic,
7 Optional,
8)
9from unittest.mock import MagicMock
11import anyio
13from faststream._internal.types import P_HandlerParams, T_HandlerReturn
14from faststream.exceptions import SetupError
16if TYPE_CHECKING:
17 from fast_depends.core import CallModel
18 from fast_depends.dependencies import Dependant
20 from faststream._internal.basic_types import Decorator
21 from faststream._internal.di import FastDependsConfig
22 from faststream._internal.endpoint.publisher import PublisherProto
23 from faststream._internal.endpoint.subscriber import SubscriberUsecase
24 from faststream.message import StreamMessage
def ensure_call_wrapper(
    call: Callable[P_HandlerParams, T_HandlerReturn],
) -> "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]":
    """Return ``call`` as a ``HandlerCallWrapper``.

    An object that already is a ``HandlerCallWrapper`` is handed back
    unchanged, so a handler is never wrapped twice; any other callable
    is wrapped in a new ``HandlerCallWrapper``.
    """
    return call if isinstance(call, HandlerCallWrapper) else HandlerCallWrapper(call)
class HandlerCallWrapper(Generic[P_HandlerParams, T_HandlerReturn]):
    """Wrap a handler callable together with its test-mode bookkeeping.

    The instance stays directly callable (it delegates to the original
    function), while ``call_wrapped`` runs the callable produced by
    ``set_wrapped``. When ``is_test`` is enabled, every message passed to
    ``call_wrapped`` is recorded on ``mock`` and ``trigger`` resolves
    ``future`` so that ``wait_call`` can await it.
    """

    # Resolved by ``trigger``; ``None`` until ``refresh`` creates it.
    future: Optional["asyncio.Future[Any]"]
    # Callable built by ``set_wrapped``; ``None`` before that call.
    _wrapped_call: Callable[..., Awaitable[Any]] | None
    # The handler function as supplied by the user.
    _original_call: Callable[P_HandlerParams, T_HandlerReturn]

    _publishers: list["PublisherProto[Any]"]

    # we have to store subscribers here
    # to protect them from garbage collection
    _subscribers: list["SubscriberUsecase[Any]"]

    # Records the decoded messages seen by ``call_wrapped`` in test mode.
    mock: MagicMock
    # Enables mock recording and ``trigger`` handling.
    is_test: bool

    __slots__ = (
        "_original_call",
        "_publishers",
        "_subscribers",
        "_wrapped_call",
        "future",
        "is_test",
        "mock",
    )

    def __init__(
        self,
        call: Callable[P_HandlerParams, T_HandlerReturn],
    ) -> None:
        """Initialize a handler around ``call`` with test mode disabled."""
        self._original_call = call
        self._wrapped_call = None

        self._publishers = []
        self._subscribers = []

        self.mock = MagicMock()
        self.future = None
        self.is_test = False

    def __call__(
        self,
        *args: P_HandlerParams.args,
        **kwargs: P_HandlerParams.kwargs,
    ) -> T_HandlerReturn:
        """Call the original handler function directly."""
        return self._original_call(*args, **kwargs)

    async def call_wrapped(
        self,
        message: "StreamMessage[Any]",
    ) -> Any:
        """Call the wrapped function with the given message.

        In test mode the decoded message is recorded on ``mock`` before
        the wrapped callable runs.

        Raises:
            AssertionError: If ``set_wrapped`` has not been called yet.
        """
        assert self._wrapped_call, "You should use `set_wrapped` first"
        if self.is_test:
            self.mock(await message.decode())
        return await self._wrapped_call(message)

    def set_wrapped(
        self,
        *,
        dependencies: Sequence["Dependant"],
        _call_decorators: Reversible["Decorator"],
        config: "FastDependsConfig",
    ) -> "CallModel":
        """Build the wrapped callable through ``config.build_call``.

        Stores the resulting ``original_call`` and ``wrapped_call`` on this
        instance and returns the built object's ``dependent`` attribute.
        """
        dependent = config.build_call(
            self._original_call,
            dependencies=dependencies,
            call_decorators=_call_decorators,
        )
        self._original_call = dependent.original_call
        self._wrapped_call = dependent.wrapped_call
        return dependent.dependent

    async def wait_call(self, timeout: float | None = None) -> None:
        """Wait until ``future`` is resolved, with an optional timeout.

        The wait runs inside ``anyio.fail_after(timeout)``; an exception
        stored on the future by ``trigger`` is re-raised here.

        Raises:
            AssertionError: If ``future`` has not been created yet.
        """
        assert self.future is not None, "You can use this method only with TestClient"
        with anyio.fail_after(timeout):
            await self.future

    def set_test(self) -> None:
        """Enable test mode, clear the mock and refresh the future."""
        self.is_test = True
        self.mock.reset_mock()
        self.refresh(with_mock=True)

    def reset_test(self) -> None:
        """Disable test mode, clear the mock and drop the future."""
        self.is_test = False
        self.mock.reset_mock()
        self.future = None

    def trigger(
        self,
        result: Any = None,
        error: BaseException | None = None,
    ) -> None:
        """Resolve ``future`` with ``result`` or ``error``.

        Does nothing outside test mode. A future that is already done is
        replaced by a fresh one before it is resolved.

        Raises:
            SetupError: If test mode is on but ``future`` was never created.
        """
        if not self.is_test:
            return

        if self.future is None:
            msg = "You can use this method only with TestClient"
            raise SetupError(msg)

        if self.future.done():
            self.future = asyncio.Future()

        # Compare against None explicitly: an exception type may define a
        # falsy ``__bool__``/``__len__`` and must still be reported as an error.
        if error is not None:
            self.future.set_exception(error)
        else:
            self.future.set_result(result)

    def refresh(self, with_mock: bool = False) -> None:
        """Replace ``future`` with a new one and optionally reset the mock.

        The future is only recreated when an event loop is currently
        running; otherwise it is left untouched.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: keep the current future as it is.
            pass
        else:
            self.future = asyncio.Future()

        if with_mock and self.mock is not None:
            self.mock.reset_mock()