Coverage for faststream / middlewares / exception.py: 96%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from collections.abc import Awaitable, Callable 

2from typing import ( 

3 TYPE_CHECKING, 

4 Any, 

5 Literal, 

6 NoReturn, 

7 Optional, 

8 TypeAlias, 

9 cast, 

10 overload, 

11) 

12 

13from faststream._internal.middlewares import BaseMiddleware 

14from faststream._internal.utils import apply_types 

15from faststream._internal.utils.functions import FakeContext, to_async 

16from faststream.exceptions import IgnoredException 

17 

18if TYPE_CHECKING: 

19 from contextlib import AbstractContextManager 

20 from types import TracebackType 

21 

22 from faststream._internal.basic_types import AsyncFuncAny 

23 from faststream._internal.context.repository import ContextRepo 

24 from faststream.message import StreamMessage 

25 

26 

27GeneralExceptionHandler: TypeAlias = Callable[..., None] | Callable[..., Awaitable[None]] 

28PublishingExceptionHandler: TypeAlias = Callable[..., Any] 

29 

30CastedGeneralExceptionHandler: TypeAlias = Callable[..., Awaitable[None]] 

31CastedPublishingExceptionHandler: TypeAlias = Callable[..., Awaitable[Any]] 

32CastedHandlers: TypeAlias = dict[ 

33 type[Exception], 

34 CastedGeneralExceptionHandler, 

35] 

36CastedPublishingHandlers: TypeAlias = dict[ 

37 type[Exception], 

38 CastedPublishingExceptionHandler, 

39] 

40 

41 

class ExceptionMiddleware:
    """Collect exception handlers and build a runtime middleware per message.

    Two registries map exception classes to wrapped handlers:

    * ``_handlers`` -- awaited by the runtime middleware's ``after_processed``
      hook with the raised exception.
    * ``_publish_handlers`` -- awaited by the runtime middleware's
      ``consume_scope``; their return value is returned from that scope.

    Every registered callable goes through ``to_async`` and then
    ``apply_types(..., serializer_cls=None)`` before being stored. Both
    registries start out with ``IgnoredException -> ignore_handler`` (which just
    re-raises); a user-supplied entry for the same key replaces it.
    """

    __slots__ = ("_handlers", "_publish_handlers")

    _handlers: CastedHandlers
    _publish_handlers: CastedPublishingHandlers

    def __init__(
        self,
        handlers: dict[type[Exception], GeneralExceptionHandler] | None = None,
        publish_handlers: dict[type[Exception], PublishingExceptionHandler] | None = None,
    ) -> None:
        """Build both registries from optional ``{exception class: handler}`` maps."""
        general: CastedHandlers = {IgnoredException: ignore_handler}
        for exc_type, handler in (handlers or {}).items():
            general[exc_type] = self._wrap(handler)
        self._handlers = general

        publishing: CastedPublishingHandlers = {IgnoredException: ignore_handler}
        for exc_type, handler in (publish_handlers or {}).items():
            publishing[exc_type] = self._wrap(handler)
        self._publish_handlers = publishing

    @overload
    def add_handler(
        self,
        exc: type[Exception],
        publish: Literal[False] = False,
    ) -> Callable[[GeneralExceptionHandler], GeneralExceptionHandler]: ...

    @overload
    def add_handler(
        self,
        exc: type[Exception],
        publish: Literal[True] = ...,
    ) -> Callable[[PublishingExceptionHandler], PublishingExceptionHandler]: ...

    def add_handler(
        self,
        exc: type[Exception],
        publish: bool = False,
    ) -> (
        Callable[[GeneralExceptionHandler], GeneralExceptionHandler]
        | Callable[[PublishingExceptionHandler], PublishingExceptionHandler]
    ):
        """Return a decorator that registers the decorated function for ``exc``.

        With ``publish=True`` the handler is stored in the publish registry,
        otherwise in the general one. The decorated function is returned
        unchanged, so it remains directly callable.
        """

        def register(func: Callable[..., Any]) -> Callable[..., Any]:
            # The registry is chosen when the decorator runs, not when it is created.
            registry = self._publish_handlers if publish else self._handlers
            registry[exc] = self._wrap(func)
            return func

        return register

    def __call__(
        self,
        msg: Any | None,
        /,
        *,
        context: "ContextRepo",
    ) -> "_BaseExceptionMiddleware":
        """Real middleware runtime constructor.

        The registries are handed over by reference (not copied), so handlers
        added later through ``add_handler`` are visible to existing instances.
        """
        runtime = _BaseExceptionMiddleware(
            msg=msg,
            context=context,
            handlers=self._handlers,
            publish_handlers=self._publish_handlers,
        )
        return runtime

    @staticmethod
    def _wrap(handler: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
        """Apply ``to_async`` and then ``apply_types(serializer_cls=None)`` to ``handler``."""
        return apply_types(
            cast("Callable[..., Awaitable[Any]]", to_async(handler)),
            serializer_cls=None,
        )

64 self._publish_handlers: CastedPublishingHandlers = { 

65 IgnoredException: ignore_handler, 

66 **{ 

67 exc_type: apply_types(to_async(handler), serializer_cls=None) 

68 for exc_type, handler in (publish_handlers or {}).items() 

69 }, 

70 } 

71 

72 @overload 

73 def add_handler( 

74 self, 

75 exc: type[Exception], 

76 publish: Literal[False] = False, 

77 ) -> Callable[[GeneralExceptionHandler], GeneralExceptionHandler]: ... 

78 

79 @overload 

80 def add_handler( 

81 self, 

82 exc: type[Exception], 

83 publish: Literal[True] = ..., 

84 ) -> Callable[[PublishingExceptionHandler], PublishingExceptionHandler]: ... 

85 

86 def add_handler( 

87 self, 

88 exc: type[Exception], 

89 publish: bool = False, 

90 ) -> ( 

91 Callable[[GeneralExceptionHandler], GeneralExceptionHandler] 

92 | Callable[[PublishingExceptionHandler], PublishingExceptionHandler] 

93 ): 

94 if publish: 

95 

96 def pub_wrapper( 

97 func: PublishingExceptionHandler, 

98 ) -> PublishingExceptionHandler: 

99 self._publish_handlers[exc] = apply_types( 

100 to_async(func), 

101 serializer_cls=None, 

102 ) 

103 return func 

104 

105 return pub_wrapper 

106 

107 def default_wrapper( 

108 func: GeneralExceptionHandler, 

109 ) -> GeneralExceptionHandler: 

110 self._handlers[exc] = apply_types( 

111 to_async(func), 

112 serializer_cls=None, 

113 ) 

114 return func 

115 

116 return default_wrapper 

117 

118 def __call__( 

119 self, 

120 msg: Any | None, 

121 /, 

122 *, 

123 context: "ContextRepo", 

124 ) -> "_BaseExceptionMiddleware": 

125 """Real middleware runtime constructor.""" 

126 return _BaseExceptionMiddleware( 

127 handlers=self._handlers, 

128 publish_handlers=self._publish_handlers, 

129 context=context, 

130 msg=msg, 

131 ) 

132 

133 

class _BaseExceptionMiddleware(BaseMiddleware):
    """Per-message middleware that routes exceptions to registered handlers.

    Handler lookup walks the raised exception type's ``__mro__`` in order, so
    the most specific registered class wins.
    """

    def __init__(
        self,
        *,
        handlers: CastedHandlers,
        publish_handlers: CastedPublishingHandlers,
        context: "ContextRepo",
        msg: Any | None,
    ) -> None:
        super().__init__(msg, context=context)
        self._publish_handlers = publish_handlers
        self._handlers = handlers

    @staticmethod
    def _nearest_registered(
        registry: dict[type[Exception], Any],
        exc_cls: type[BaseException],
    ) -> type | None:
        """Return the first class in ``exc_cls.__mro__`` that has a handler, else ``None``."""
        return next((klass for klass in exc_cls.__mro__ if klass in registry), None)

    async def consume_scope(
        self,
        call_next: "AsyncFuncAny",
        msg: "StreamMessage[Any]",
    ) -> Any:
        """Run ``call_next``; on ``Exception``, delegate to a matching publish handler.

        The handler's return value is returned instead. Without a match the
        original exception is re-raised.
        """
        try:
            return await call_next(msg)
        except Exception as exc:
            matched = self._nearest_registered(self._publish_handlers, type(exc))
            if matched is None:
                raise
            return await self._publish_handlers[matched](exc, context__=self.context)

    async def after_processed(
        self,
        exc_type: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: Optional["TracebackType"] = None,
    ) -> bool | None:
        """Dispatch ``exc_val`` to the nearest general handler for ``exc_type``.

        Returns:
            ``None`` when no exception was raised, ``False`` when no handler
            matches, and ``True`` after a matching handler has been awaited.
        """
        if not exc_type:
            return None

        matched = self._nearest_registered(self._handlers, exc_type)
        if matched is None:
            return False

        handler = self._handlers[matched]
        # TODO: remove it after context will be moved to middleware
        # In case parser/decoder error occurred, no "message" is in the local
        # context yet, so expose self.msg under that name while the handler runs.
        scope: AbstractContextManager[Any] = (
            FakeContext()
            if self.context.get_local("message")
            else self.context.scope("message", self.msg)
        )
        with scope:
            await handler(exc_val, context__=self.context)
        return True

188 

189 

async def ignore_handler(
    exception: IgnoredException,
    **_context: Any,  # swallow context__ and any other keyword arguments
) -> NoReturn:
    """Re-raise ``exception`` unchanged.

    Registered for ``IgnoredException`` in both registries so that such
    exceptions keep propagating instead of being handled. Keyword arguments
    (the middleware passes ``context__``) are accepted and ignored.
    """
    raise exception