Coverage for faststream / _internal / endpoint / subscriber / usecase.py: 94%
126 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1from abc import abstractmethod
2from collections.abc import AsyncIterator, Callable, Iterable, Sequence
3from contextlib import AbstractContextManager, AsyncExitStack
4from itertools import chain
5from typing import (
6 TYPE_CHECKING,
7 Any,
8 Generic,
9 NamedTuple,
10 Optional,
11 Union,
12)
14from typing_extensions import Self, overload, override
16from faststream._internal.endpoint.usecase import Endpoint
17from faststream._internal.endpoint.utils import ParserComposition
18from faststream._internal.types import (
19 AsyncCallable,
20 MsgType,
21 P_HandlerParams,
22 T_HandlerReturn,
23)
24from faststream._internal.utils.functions import FakeContext, to_async
25from faststream.exceptions import StopConsume, SubscriberNotFound
26from faststream.middlewares import AcknowledgementMiddleware
27from faststream.middlewares.logging import CriticalLogMiddleware
28from faststream.response import ensure_response
30from .call_item import (
31 CallsCollection,
32 HandlerItem,
33)
34from .utils import MultiLock, default_filter
36if TYPE_CHECKING:
37 from fast_depends.dependencies import Dependant
39 from faststream._internal.basic_types import Decorator
40 from faststream._internal.configs import SubscriberUsecaseConfig
41 from faststream._internal.endpoint.call_wrapper import HandlerCallWrapper
42 from faststream._internal.endpoint.publisher import PublisherProto
43 from faststream._internal.parser import CodecProto
44 from faststream._internal.types import (
45 AsyncFilter,
46 BrokerMiddleware,
47 CustomCallable,
48 Filter,
49 )
50 from faststream.message import StreamMessage
51 from faststream.middlewares import BaseMiddleware
52 from faststream.response import Response
53 from faststream.specification.schema import SubscriberSpec
55 from .specification import SubscriberSpecification
58class _CallOptions(NamedTuple):
59 parser: Optional["CustomCallable"]
60 decoder: Optional["CustomCallable"]
61 dependencies: Iterable["Dependant"]
62 codec: Optional["CodecProto"] = None
class SubscriberUsecase(Endpoint, Generic[MsgType]):
    """A class representing an asynchronous handler.

    Base use case shared by broker subscribers. It:

    - owns the collection of handler calls registered through ``__call__``;
    - resolves each call's parser/decoder chain;
    - drives every incoming message through the middleware stack, the first
      suitable handler, and any reply publishers.

    Broker-specific subclasses implement ``stop``, ``get_one``,
    ``__aiter__`` and ``_make_response_publisher``.
    """

    # Guards in-flight message processing. It starts as a no-op FakeContext;
    # ``start`` swaps in a MultiLock, which ``stop`` waits on.
    lock: "AbstractContextManager[Any]"
    # Extra options forwarded to AcknowledgementMiddleware.
    extra_watcher_options: dict[str, Any]
    graceful_timeout: float | None

    def __init__(
        self,
        config: "SubscriberUsecaseConfig",
        specification: "SubscriberSpecification",
        calls: "CallsCollection[MsgType]",
    ) -> None:
        """Initialize a new instance of the class.

        Args:
            config: Subscriber configuration. Parser, decoder, ack policy and
                reply settings are read from it; its ``_outer_config`` is
                passed to ``Endpoint``.
            specification: Object used to build this subscriber's schema.
            calls: Collection that registered handler items are added to.
        """
        super().__init__(config._outer_config)

        self.calls = calls
        self.specification = specification

        self._no_reply = config.no_reply
        self._parser = config.parser
        self._decoder = config.decoder

        self.ack_policy = config.ack_policy
        self.__auto_ack_disabled = config.auto_ack_disabled

        # Subscriber-wide options; replaced later by ``add_call``.
        self._call_options = _CallOptions(
            parser=None,
            decoder=None,
            dependencies=(),
            codec=None,
        )

        self._call_decorators: tuple[Decorator, ...] = ()

        self.running = False
        # No-op lock until ``start`` installs a real MultiLock.
        self.lock = FakeContext()

        self.extra_watcher_options = {}

    @property
    def _broker_middlewares(self) -> Sequence["BrokerMiddleware[MsgType]"]:
        """Middlewares configured on the owning broker."""
        return self._outer_config.broker_middlewares

    async def start(self) -> None:
        """Private method to start subscriber by broker.

        Installs the MultiLock used for graceful shutdown, builds the
        FastDepends models of all registered calls, and logs that the
        subscriber is waiting for messages.
        """
        self.lock = MultiLock()

        self._build_fastdepends_model()

        self._outer_config.logger.log(
            f"`{self.specification.call_name}` waiting for messages",
            extra=self.get_log_context(None),
        )

    def _get_parser_and_decoder(
        self,
        item_parser: Optional["CustomCallable"] = None,
        item_decoder: Optional["CustomCallable"] = None,
    ) -> tuple[AsyncCallable, AsyncCallable]:
        """Method to resolve parsers with priority.

        First priority
        >>> sub = broker.subscriber()
        >>>
        >>> @sub(parser=P0_parser)
        >>> async def handler(): ...

        Second priority
        >>> sub = broker.subscriber(parser=P1_parser)

        Third priority
        >>> Broker(parser=P2_parser)

        Default parser is `self._parser`.
        So, the final parser object is
        >>> ParserComposition(P0_parser or P1_parser or P2_parser, self._parser)

        The decoder is resolved with the same priority. A codec (subscriber
        level, then broker level) replaces the decoder entirely.

        Returns:
            The ``(parser, decoder)`` async callables to use for a call.

        Raises:
            ValueError: If a codec and a decoder are both configured.
        """
        if parser := (
            item_parser or self._call_options.parser or self._outer_config.broker_parser
        ):
            async_parser: AsyncCallable = ParserComposition(parser, self._parser)
        else:
            async_parser = self._parser

        # A codec replaces the legacy decoder. Configuring both (at any
        # level) is ambiguous, so it is rejected rather than silently resolved.
        codec = self._call_options.codec or self._outer_config.broker_codec
        decoder = (
            item_decoder
            or self._call_options.decoder
            or self._outer_config.broker_decoder
        )

        if codec and decoder:
            msg = "Cannot use both 'codec' and 'decoder' — 'codec' replaces 'decoder'."
            raise ValueError(msg)

        if codec:
            async_decoder: AsyncCallable = codec.decode
        elif decoder:
            async_decoder = ParserComposition(decoder, self._decoder)
        else:
            async_decoder = self._decoder

        return async_parser, async_decoder

    def _build_fastdepends_model(self) -> None:
        """Set up every registered call with its resolved parser/decoder."""
        for call in self.calls:
            async_parser, async_decoder = self._get_parser_and_decoder(
                call.item_parser, call.item_decoder
            )

            call._setup(
                parser=async_parser,
                decoder=async_decoder,
                config=self._outer_config.fd_config,
                broker_dependencies=self._outer_config.broker_dependencies,
                _call_decorators=self._call_decorators,
            )

            call.handler.refresh(with_mock=False)

    def _post_start(self) -> None:
        """Mark the subscriber as running so ``consume`` accepts messages."""
        self.running = True

    @abstractmethod
    async def stop(self) -> None:
        """Stop message consuming.

        Waits (asynchronously) for in-flight messages to be released, bounded
        by ``graceful_timeout``. Subclasses extend this and call ``super()``.
        """
        # set running false before releasing to stop new message reading
        self.running = False

        # Wait for already consumed messages to be processed
        if isinstance(self.lock, MultiLock):
            await self.lock.wait_release(self._outer_config.graceful_timeout)

    def add_call(
        self,
        *,
        parser_: Optional["CustomCallable"],
        decoder_: Optional["CustomCallable"],
        dependencies_: Iterable["Dependant"],
        codec_: Optional["CodecProto"] = None,
    ) -> Self:
        """Replace the subscriber-wide call options and return ``self``."""
        self._call_options = _CallOptions(
            parser=parser_,
            decoder=decoder_,
            dependencies=dependencies_,
            codec=codec_,
        )
        return self

    @overload
    def __call__(
        self,
        func: Callable[P_HandlerParams, T_HandlerReturn],
        *,
        filter: "Filter[Any]" = default_filter,
        parser: Optional["CustomCallable"] = None,
        decoder: Optional["CustomCallable"] = None,
        dependencies: Iterable["Dependant"] = (),
    ) -> "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]": ...

    @overload
    def __call__(
        self,
        func: None = None,
        *,
        filter: "Filter[Any]" = default_filter,
        parser: Optional["CustomCallable"] = None,
        decoder: Optional["CustomCallable"] = None,
        dependencies: Iterable["Dependant"] = (),
    ) -> Callable[
        [Callable[P_HandlerParams, T_HandlerReturn]],
        "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]",
    ]: ...

    @override
    def __call__(
        self,
        func: Callable[P_HandlerParams, T_HandlerReturn] | None = None,
        *,
        filter: "Filter[Any]" = default_filter,
        parser: Optional["CustomCallable"] = None,
        decoder: Optional["CustomCallable"] = None,
        dependencies: Iterable["Dependant"] = (),
    ) -> Union[
        "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]",
        Callable[
            [Callable[P_HandlerParams, T_HandlerReturn]],
            "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]",
        ],
    ]:
        """Register ``func`` as a handler of this subscriber.

        Works both as ``@sub`` and as ``@sub(filter=..., parser=...)``.
        Subscriber-wide dependencies are placed before the per-handler ones.
        """
        total_deps = (*self._call_options.dependencies, *dependencies)
        async_filter: AsyncFilter[StreamMessage[MsgType]] = to_async(filter)

        def real_wrapper(
            func: Callable[P_HandlerParams, T_HandlerReturn],
        ) -> "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]":
            handler = super(SubscriberUsecase, self).__call__(func)
            handler._subscribers.append(self)

            self.calls.add_call(
                HandlerItem[MsgType](
                    handler=handler,
                    filter=async_filter,
                    item_parser=parser,
                    item_decoder=decoder,
                    dependencies=total_deps,
                ),
            )

            return handler

        if func is None:
            return real_wrapper

        return real_wrapper(func)

    async def consume(self, msg: MsgType) -> Any:
        """Consume a message asynchronously.

        Returns:
            The ``Response`` produced by ``process_message``. Returns ``None``
            if the subscriber is not running or processing raised.
        """
        if not self.running:
            return None

        try:
            return await self.process_message(msg)

        except StopConsume:
            # Stop handler at StopConsume exception
            await self.stop()

        except SystemExit:
            # Stop handler at `exit()` call
            await self.stop()

            if app := self._outer_config.fd_config.context.get("app"):
                app.exit()

        except Exception:  # nosec B110
            # All other exceptions were logged by CriticalLogMiddleware,
            # which is always part of the middleware stack.
            pass

    async def process_message(self, msg: MsgType) -> "Response":
        """Execute all message processing stages.

        The stages are: take the lock, enter context scopes, enter the
        middlewares, find the first suitable handler, call it, and publish
        the result to the reply/handler publishers.

        Returns:
            The handler's result wrapped by ``ensure_response``. If a
            middleware suppresses an error, the result is
            ``ensure_response(None)``.

        Raises:
            Exception: The parsing/decoding error raised by a handler's
                ``is_suitable``, or ``SubscriberNotFound`` when no handler
                matches (unless a middleware suppresses it).
        """
        context = self._outer_config.fd_config.context
        logger_state = self._outer_config.logger

        async with AsyncExitStack() as stack:
            stack.enter_context(self.lock)

            # Enter context before middlewares
            stack.enter_context(context.scope("handler_", self))
            stack.enter_context(context.scope("logger", logger_state.logger.logger))
            for k, v in self._outer_config.extra_context.items():
                stack.enter_context(context.scope(k, v))

            # Enter all middlewares. Their exits are registered on ``stack``
            # only later, so that they run before the message scopes are
            # released. If one middleware fails to enter, the ones already
            # entered must be unwound here (in reverse order); otherwise
            # their ``__aexit__`` would never be called.
            middlewares: list[BaseMiddleware] = []
            for base_m in self.__build__middlewares_stack():
                middleware = base_m(msg, context=context)
                try:
                    await middleware.__aenter__()
                except BaseException as enter_exc:
                    for entered in reversed(middlewares):
                        await entered.__aexit__(
                            type(enter_exc), enter_exc, enter_exc.__traceback__
                        )
                    raise
                middlewares.append(middleware)

            cache: dict[Any, Any] = {}
            parsing_error: Exception | None = None
            for h in self.calls:
                try:
                    message = await h.is_suitable(msg, cache)
                except Exception as e:
                    parsing_error = e
                    break

                if message is not None:
                    stack.enter_context(
                        context.scope("log_context", self.get_log_context(message)),
                    )
                    stack.enter_context(context.scope("message", message))

                    # Middlewares should be exited before scope release
                    for m in middlewares:
                        stack.push_async_exit(m.__aexit__)

                    result_msg = ensure_response(
                        await h.call(
                            message=message,
                            # consumer middlewares
                            _extra_middlewares=(
                                m.consume_scope for m in middlewares[::-1]
                            ),
                        ),
                    )

                    if not result_msg.correlation_id:
                        result_msg.correlation_id = message.correlation_id

                    for p in chain(
                        self.__get_response_publisher(message),
                        h.handler._publishers,
                    ):
                        await p._publish(
                            result_msg.as_publish_command(),
                            _extra_middlewares=(
                                m.publish_scope for m in middlewares[::-1]
                            ),
                        )

                    # Return data for tests
                    return result_msg

            # Suitable handler was not found or
            # parsing/decoding exception occurred
            for m in middlewares:
                stack.push_async_exit(m.__aexit__)

            # Reraise it to catch in tests. Compare against None: an
            # exception instance is not guaranteed to be truthy.
            if parsing_error is not None:
                raise parsing_error

            error_msg = f"There is no suitable handler for {msg=}"
            raise SubscriberNotFound(error_msg)

        # An error was raised and processed by some middleware
        return ensure_response(None)

    def __build__middlewares_stack(self) -> tuple["BrokerMiddleware[MsgType]", ...]:
        """Build the middleware stack, outermost first.

        The stack is: optional AcknowledgementMiddleware (omitted when
        auto-ack is disabled), then CriticalLogMiddleware, then the broker
        middlewares.
        """
        logger_state = self._outer_config.logger

        if self.__auto_ack_disabled:
            broker_middlewares = (
                CriticalLogMiddleware(logger_state),
                *self._broker_middlewares,
            )

        else:
            broker_middlewares = (
                AcknowledgementMiddleware(
                    logger=logger_state,
                    ack_policy=self.ack_policy,
                    extra_options=self.extra_watcher_options,
                ),
                CriticalLogMiddleware(logger_state),
                *self._broker_middlewares,
            )

        return broker_middlewares

    def __get_response_publisher(
        self,
        message: "StreamMessage[MsgType]",
    ) -> Iterable["PublisherProto"]:
        """Return reply publishers, or ``()`` if there is no reply target or replies are disabled."""
        if not message.reply_to or self._no_reply:
            return ()

        return self._make_response_publisher(message)

    @abstractmethod
    def _make_response_publisher(
        self,
        message: "StreamMessage[MsgType]",
    ) -> Iterable["PublisherProto"]:
        """Create the broker-specific publisher(s) used to answer ``message``."""
        raise NotImplementedError

    @abstractmethod
    async def get_one(self, *, timeout: float = 5) -> Optional["StreamMessage[MsgType]"]:
        """Fetch a single message; implemented by broker subclasses."""
        raise NotImplementedError

    @abstractmethod
    async def __aiter__(self) -> AsyncIterator["StreamMessage[MsgType]"]:
        """Iterate over incoming messages; implemented by broker subclasses."""
        raise NotImplementedError

    def get_log_context(
        self,
        message: Optional["StreamMessage[MsgType]"],
    ) -> dict[str, str]:
        """Generate log context.

        The ``message_id`` value is ``""`` when ``message`` is ``None`` or
        has no ``message_id`` attribute.
        """
        return {
            "message_id": getattr(message, "message_id", ""),
        }

    def _log(
        self,
        log_level: int | None,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: Exception | None = None,
    ) -> None:
        """Forward a log record to the broker's logger state."""
        self._outer_config.logger.log(
            message,
            log_level,
            extra=extra,
            exc_info=exc_info,
        )

    def schema(self) -> dict[str, "SubscriberSpec"]:
        """Build the handler models (so they are up to date) and return the schema."""
        self._build_fastdepends_model()
        return self.specification.get_schema()