Coverage for faststream / middlewares / exception.py: 96%
45 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1from collections.abc import Awaitable, Callable
2from typing import (
3 TYPE_CHECKING,
4 Any,
5 Literal,
6 NoReturn,
7 Optional,
8 TypeAlias,
9 cast,
10 overload,
11)
13from faststream._internal.middlewares import BaseMiddleware
14from faststream._internal.utils import apply_types
15from faststream._internal.utils.functions import FakeContext, to_async
16from faststream.exceptions import IgnoredException
18if TYPE_CHECKING:
19 from contextlib import AbstractContextManager
20 from types import TracebackType
22 from faststream._internal.basic_types import AsyncFuncAny
23 from faststream._internal.context.repository import ContextRepo
24 from faststream.message import StreamMessage
27GeneralExceptionHandler: TypeAlias = Callable[..., None] | Callable[..., Awaitable[None]]
28PublishingExceptionHandler: TypeAlias = Callable[..., Any]
30CastedGeneralExceptionHandler: TypeAlias = Callable[..., Awaitable[None]]
31CastedPublishingExceptionHandler: TypeAlias = Callable[..., Awaitable[Any]]
32CastedHandlers: TypeAlias = dict[
33 type[Exception],
34 CastedGeneralExceptionHandler,
35]
36CastedPublishingHandlers: TypeAlias = dict[
37 type[Exception],
38 CastedPublishingExceptionHandler,
39]
class ExceptionMiddleware:
    """Registry of user exception handlers that builds per-message middlewares.

    Two independent registries are kept, each keyed by exception class:

    * ``_handlers`` -- "general" handlers, awaited from the runtime
      middleware's ``after_processed`` hook; their return value is discarded.
    * ``_publish_handlers`` -- handlers awaited from ``consume_scope``; their
      return value becomes the result of consuming the message.

    Both registries start with ``ignore_handler`` registered for
    ``IgnoredException`` (which re-raises it), unless the caller supplies
    their own handler for that class.
    """

    __slots__ = ("_handlers", "_publish_handlers")

    _handlers: CastedHandlers
    _publish_handlers: CastedPublishingHandlers

    def __init__(
        self,
        handlers: dict[type[Exception], GeneralExceptionHandler] | None = None,
        publish_handlers: dict[type[Exception], PublishingExceptionHandler] | None = None,
    ) -> None:
        """Build both registries from the optional user-supplied mappings.

        Every user handler is wrapped with ``to_async`` and then
        ``apply_types(..., serializer_cls=None)`` before being stored.

        Args:
            handlers: exception class -> general handler (sync or async).
            publish_handlers: exception class -> publishing handler
                (sync or async).
        """
        general: CastedHandlers = {IgnoredException: ignore_handler}
        for exc_type, handler in (handlers or {}).items():
            general[exc_type] = apply_types(
                cast("Callable[..., Awaitable[None]]", to_async(handler)),
                serializer_cls=None,
            )
        self._handlers = general

        publishing: CastedPublishingHandlers = {IgnoredException: ignore_handler}
        for exc_type, handler in (publish_handlers or {}).items():
            publishing[exc_type] = apply_types(to_async(handler), serializer_cls=None)
        self._publish_handlers = publishing

    @overload
    def add_handler(
        self,
        exc: type[Exception],
        publish: Literal[False] = False,
    ) -> Callable[[GeneralExceptionHandler], GeneralExceptionHandler]: ...

    @overload
    def add_handler(
        self,
        exc: type[Exception],
        publish: Literal[True] = ...,
    ) -> Callable[[PublishingExceptionHandler], PublishingExceptionHandler]: ...

    def add_handler(
        self,
        exc: type[Exception],
        publish: bool = False,
    ) -> (
        Callable[[GeneralExceptionHandler], GeneralExceptionHandler]
        | Callable[[PublishingExceptionHandler], PublishingExceptionHandler]
    ):
        """Return a decorator that registers a handler for ``exc``.

        Args:
            exc: exception class the decorated handler is responsible for.
            publish: when true, register into the publishing registry;
                otherwise into the general one.

        Returns:
            A decorator that stores ``apply_types(to_async(func),
            serializer_cls=None)`` under ``exc`` (replacing any existing
            entry) and hands back the original, unwrapped ``func``.
        """
        if not publish:

            def register_general(
                func: GeneralExceptionHandler,
            ) -> GeneralExceptionHandler:
                self._handlers[exc] = apply_types(to_async(func), serializer_cls=None)
                return func

            return register_general

        def register_publishing(
            func: PublishingExceptionHandler,
        ) -> PublishingExceptionHandler:
            self._publish_handlers[exc] = apply_types(to_async(func), serializer_cls=None)
            return func

        return register_publishing

    def __call__(
        self,
        msg: Any | None,
        /,
        *,
        context: "ContextRepo",
    ) -> "_BaseExceptionMiddleware":
        """Real middleware runtime constructor.

        The runtime middleware receives references to this object's registry
        dicts (not copies), so later ``add_handler`` registrations are visible
        to it.
        """
        return _BaseExceptionMiddleware(
            msg=msg,
            context=context,
            handlers=self._handlers,
            publish_handlers=self._publish_handlers,
        )
class _BaseExceptionMiddleware(BaseMiddleware):
    """Per-message runtime that routes exceptions to registered handlers.

    For both registries, lookup walks the exception class's ``__mro__`` and
    uses the first class that has an entry, so a handler registered for a
    subclass takes precedence over one registered for its base.
    """

    def __init__(
        self,
        *,
        handlers: CastedHandlers,
        publish_handlers: CastedPublishingHandlers,
        context: "ContextRepo",
        msg: Any | None,
    ) -> None:
        """Keep references to the shared handler registries.

        Args:
            handlers: general handlers, consulted in ``after_processed``.
            publish_handlers: publishing handlers, consulted in
                ``consume_scope``.
            context: forwarded to ``BaseMiddleware.__init__``.
            msg: forwarded to ``BaseMiddleware.__init__`` (may be ``None``).
        """
        super().__init__(msg, context=context)
        self._handlers = handlers
        self._publish_handlers = publish_handlers

    async def consume_scope(
        self,
        call_next: "AsyncFuncAny",
        msg: "StreamMessage[Any]",
    ) -> Any:
        """Run ``call_next`` and let a publishing handler replace a failure.

        Returns:
            The result of ``call_next(msg)``. If that raised an ``Exception``
            with a matching publishing handler, the handler's result is
            returned instead.

        Raises:
            The original exception when no publishing handler matches, or
            whatever the matched handler raises (``ignore_handler`` re-raises).
        """
        try:
            return await call_next(msg)
        except Exception as exc:
            registry = self._publish_handlers
            matched = next(
                (klass for klass in type(exc).__mro__ if klass in registry),
                None,
            )
            if matched is None:
                raise
            return await registry[matched](exc, context__=self.context)

    async def after_processed(
        self,
        exc_type: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: Optional["TracebackType"] = None,
    ) -> bool | None:
        """Dispatch a processing exception to the matching general handler.

        Handlers are called as ``handler(exc_val, context__=self.context)``.

        Returns:
            ``None`` when no exception occurred, ``False`` when no general
            handler matches ``exc_type``, and ``True`` after a matching
            handler was awaited. NOTE(review): presumably ``True`` tells
            ``BaseMiddleware`` the exception was handled -- confirm against
            its contract.
        """
        if not exc_type:
            return None

        registry = self._handlers
        matched = next(
            (klass for klass in exc_type.__mro__ if klass in registry),
            None,
        )
        if matched is None:
            return False

        handler = registry[matched]

        # TODO: remove it after context will be moved to middleware
        # If no "message" local is set (per the original note, this is the
        # parser/decoder-error case), open a "message" scope bound to
        # ``self.msg`` around the handler; otherwise use ``FakeContext``.
        scope: "AbstractContextManager[Any]"
        if self.context.get_local("message"):
            scope = FakeContext()
        else:
            scope = self.context.scope("message", self.msg)

        with scope:
            await handler(exc_val, context__=self.context)

        return True
async def ignore_handler(
    exception: IgnoredException,
    **_kwargs: Any,  # accepted (e.g. ``context__``) and ignored
) -> NoReturn:
    """Re-raise ``exception`` instead of handling it.

    This is the default handler stored for ``IgnoredException`` in both
    exception-middleware registries. Extra keyword arguments are accepted
    only so the call shape ``handler(exc, context__=...)`` used for every
    handler also works here; they are ignored.
    """
    raise exception