Coverage for faststream / _internal / endpoint / call_wrapper.py: 96%

58 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import asyncio 

2from collections.abc import Awaitable, Callable, Reversible, Sequence 

3from typing import ( 

4 TYPE_CHECKING, 

5 Any, 

6 Generic, 

7 Optional, 

8) 

9from unittest.mock import MagicMock 

10 

11import anyio 

12 

13from faststream._internal.types import P_HandlerParams, T_HandlerReturn 

14from faststream.exceptions import SetupError 

15 

16if TYPE_CHECKING: 

17 from fast_depends.core import CallModel 

18 from fast_depends.dependencies import Dependant 

19 

20 from faststream._internal.basic_types import Decorator 

21 from faststream._internal.di import FastDependsConfig 

22 from faststream._internal.endpoint.publisher import PublisherProto 

23 from faststream._internal.endpoint.subscriber import SubscriberUsecase 

24 from faststream.message import StreamMessage 

25 

26 

def ensure_call_wrapper(
    call: Callable[P_HandlerParams, T_HandlerReturn],
) -> "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]":
    """Return ``call`` as a ``HandlerCallWrapper``.

    An object that already is a ``HandlerCallWrapper`` is returned as-is
    (same instance); any other callable is wrapped in a new one.
    """
    already_wrapped = isinstance(call, HandlerCallWrapper)
    return call if already_wrapped else HandlerCallWrapper(call)

34 

35 

class HandlerCallWrapper(Generic[P_HandlerParams, T_HandlerReturn]):
    """A generic class to wrap handler calls.

    Holds the user's original callable (still directly callable through
    ``__call__``) together with the wrapped coroutine function produced by
    ``set_wrapped``.  It also carries the test-mode state (``is_test``,
    ``mock``, ``future``) used by ``call_wrapped``, ``trigger`` and
    ``wait_call``.
    """

    # Completed by `trigger` in test mode and awaited by `wait_call`.
    future: Optional["asyncio.Future[Any]"]
    # Set by `set_wrapped`; None until then.
    _wrapped_call: Callable[..., Awaitable[Any]] | None
    # The handler as given; replaced by the build result in `set_wrapped`.
    _original_call: Callable[P_HandlerParams, T_HandlerReturn]

    _publishers: list["PublisherProto[Any]"]

    # we have to store subscribers here
    # to protect them from garbage collection
    _subscribers: list["SubscriberUsecase[Any]"]

    # Records each decoded message passed to `call_wrapped` while `is_test`.
    mock: MagicMock
    # Enables mock recording in `call_wrapped` and future completion in `trigger`.
    is_test: bool

    __slots__ = (
        "_original_call",
        "_publishers",
        "_subscribers",
        "_wrapped_call",
        "future",
        "is_test",
        "mock",
    )

    def __init__(
        self,
        call: Callable[P_HandlerParams, T_HandlerReturn],
    ) -> None:
        """Initialize a handler.

        Args:
            call: The user handler to wrap.
        """
        self._original_call = call
        self._wrapped_call = None

        self._publishers = []
        self._subscribers = []

        self.mock = MagicMock()
        self.future = None
        self.is_test = False

    def __call__(
        self,
        *args: P_HandlerParams.args,
        **kwargs: P_HandlerParams.kwargs,
    ) -> T_HandlerReturn:
        """Calls the object as a function.

        Delegates straight to the original handler, bypassing the wrapped
        (dependency-injected) call.
        """
        return self._original_call(*args, **kwargs)

    async def call_wrapped(
        self,
        message: "StreamMessage[Any]",
    ) -> Any:
        """Calls the wrapped function with the given message.

        In test mode the decoded message is recorded on ``mock`` before the
        wrapped call runs.

        Args:
            message: The incoming message to process.

        Returns:
            Whatever the wrapped call returns.
        """
        assert self._wrapped_call, "You should use `set_wrapped` first"
        if self.is_test:
            self.mock(await message.decode())
        return await self._wrapped_call(message)

    def set_wrapped(
        self,
        *,
        dependencies: Sequence["Dependant"],
        _call_decorators: Reversible["Decorator"],
        config: "FastDependsConfig",
    ) -> "CallModel":
        """Build the wrapped call for the handler via ``config.build_call``.

        Stores the build result's ``original_call`` and ``wrapped_call`` on
        this wrapper.

        Args:
            dependencies: Extra dependencies passed to ``build_call``.
            _call_decorators: Decorators passed to ``build_call`` as
                ``call_decorators``.
            config: The config whose ``build_call`` performs the build.

        Returns:
            The ``dependent`` attribute of the build result.
        """
        dependent = config.build_call(
            self._original_call,
            dependencies=dependencies,
            call_decorators=_call_decorators,
        )
        self._original_call = dependent.original_call
        self._wrapped_call = dependent.wrapped_call
        return dependent.dependent

    async def wait_call(self, timeout: float | None = None) -> None:
        """Waits for a call with an optional timeout.

        Awaits ``future`` inside ``anyio.fail_after(timeout)``; ``None``
        means wait without a deadline.
        """
        assert self.future is not None, "You can use this method only with TestClient"
        with anyio.fail_after(timeout):
            await self.future

    def set_test(self) -> None:
        """Enter test mode with a cleared mock.

        A fresh future is created only when an event loop is running (see
        ``refresh``).
        """
        self.is_test = True
        # refresh(with_mock=True) already resets the mock, so a separate
        # reset_mock() call beforehand would be redundant.
        self.refresh(with_mock=True)

    def reset_test(self) -> None:
        """Leave test mode, clear the mock and drop the future."""
        self.is_test = False
        self.mock.reset_mock()
        self.future = None

    def trigger(
        self,
        result: Any = None,
        error: BaseException | None = None,
    ) -> None:
        """Complete the test future with ``result`` or ``error``.

        Does nothing outside test mode. If the current future is already
        done, a new one replaces it first, so repeated triggers do not hit
        ``InvalidStateError``.

        Args:
            result: Value set on the future when no error is given.
            error: Exception set on the future instead of a result.

        Raises:
            SetupError: In test mode when no future has been created.
        """
        if not self.is_test:
            return

        if self.future is None:
            msg = "You can use this method only with TestClient"
            raise SetupError(msg)

        if self.future.done():
            self.future = asyncio.Future()

        # Explicit None check: an exception class that defines __bool__ or
        # __len__ can be falsy, and `if error:` would then report it as a
        # successful `None` result.
        if error is not None:
            self.future.set_exception(error)
        else:
            self.future.set_result(result)

    def refresh(self, with_mock: bool = False) -> None:
        """Create a new future if an event loop is running.

        Outside a running loop the existing ``future`` is left untouched.

        Args:
            with_mock: Also reset ``mock`` when true.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: nothing to bind a future to.
            loop = None

        if loop is not None:
            self.future = loop.create_future()

        if with_mock and self.mock is not None:
            self.mock.reset_mock()