Coverage for faststream / _internal / endpoint / subscriber / usecase.py: 94%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from abc import abstractmethod 

2from collections.abc import AsyncIterator, Callable, Iterable, Sequence 

3from contextlib import AbstractContextManager, AsyncExitStack 

4from itertools import chain 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 Generic, 

9 NamedTuple, 

10 Optional, 

11 Union, 

12) 

13 

14from typing_extensions import Self, overload, override 

15 

16from faststream._internal.endpoint.usecase import Endpoint 

17from faststream._internal.endpoint.utils import ParserComposition 

18from faststream._internal.types import ( 

19 AsyncCallable, 

20 MsgType, 

21 P_HandlerParams, 

22 T_HandlerReturn, 

23) 

24from faststream._internal.utils.functions import FakeContext, to_async 

25from faststream.exceptions import StopConsume, SubscriberNotFound 

26from faststream.middlewares import AcknowledgementMiddleware 

27from faststream.middlewares.logging import CriticalLogMiddleware 

28from faststream.response import ensure_response 

29 

30from .call_item import ( 

31 CallsCollection, 

32 HandlerItem, 

33) 

34from .utils import MultiLock, default_filter 

35 

36if TYPE_CHECKING: 

37 from fast_depends.dependencies import Dependant 

38 

39 from faststream._internal.basic_types import Decorator 

40 from faststream._internal.configs import SubscriberUsecaseConfig 

41 from faststream._internal.endpoint.call_wrapper import HandlerCallWrapper 

42 from faststream._internal.endpoint.publisher import PublisherProto 

43 from faststream._internal.parser import CodecProto 

44 from faststream._internal.types import ( 

45 AsyncFilter, 

46 BrokerMiddleware, 

47 CustomCallable, 

48 Filter, 

49 ) 

50 from faststream.message import StreamMessage 

51 from faststream.middlewares import BaseMiddleware 

52 from faststream.response import Response 

53 from faststream.specification.schema import SubscriberSpec 

54 

55 from .specification import SubscriberSpecification 

56 

57 

class _CallOptions(NamedTuple):
    """Subscriber-level defaults stored by `SubscriberUsecase.add_call`.

    The subscriber consults these values when a handler does not supply
    its own parser/decoder, and prepends `dependencies` to the
    dependencies passed at handler registration.
    """

    # Custom parser used when a handler does not provide one.
    parser: Optional["CustomCallable"]
    # Custom (legacy) decoder used when a handler does not provide one.
    decoder: Optional["CustomCallable"]
    # Dependencies shared by every handler registered on the subscriber.
    dependencies: Iterable["Dependant"]
    # Optional codec; the subscriber uses its `decode` as the decoder.
    codec: Optional["CodecProto"] = None

63 

64 

class SubscriberUsecase(Endpoint, Generic[MsgType]):
    """Base subscriber: keeps registered handlers and runs the message pipeline.

    Concrete subscribers must implement `stop`, `_make_response_publisher`,
    `get_one` and `__aiter__`.
    """

    lock: "AbstractContextManager[Any]"
    extra_watcher_options: dict[str, Any]
    graceful_timeout: float | None

    def __init__(
        self,
        config: "SubscriberUsecaseConfig",
        specification: "SubscriberSpecification",
        calls: "CallsCollection[MsgType]",
    ) -> None:
        """Store the config-derived settings and reset the runtime state.

        Args:
            config: Source of the parser, decoder, reply and acknowledgement
                settings; its `_outer_config` is forwarded to the base class.
            specification: Object used for the startup log line and `schema()`.
            calls: Collection that holds the registered handler items.
        """
        super().__init__(config._outer_config)

        self.calls = calls
        self.specification = specification

        self._no_reply = config.no_reply
        self._parser = config.parser
        self._decoder = config.decoder

        self.ack_policy = config.ack_policy
        self.__auto_ack_disabled = config.auto_ack_disabled

        # Empty defaults until `add_call` replaces them.
        self._call_options = _CallOptions(parser=None, decoder=None, dependencies=())

        self._call_decorators: tuple[Decorator, ...] = ()

        self.running = False
        # Placeholder context manager; `start` swaps in a `MultiLock`.
        self.lock = FakeContext()

        self.extra_watcher_options = {}

    @property
    def _broker_middlewares(self) -> Sequence["BrokerMiddleware[MsgType]"]:
        """Return the broker middlewares stored on the outer config."""
        outer_config = self._outer_config
        return outer_config.broker_middlewares

    async def start(self) -> None:
        """Prepare the subscriber for consuming (called by the broker).

        Installs a fresh `MultiLock`, sets up every registered handler and
        logs that the subscriber is waiting for messages.
        """
        self.lock = MultiLock()
        self._build_fastdepends_model()

        waiting_message = f"`{self.specification.call_name}` waiting for messages"
        self._outer_config.logger.log(
            waiting_message,
            extra=self.get_log_context(None),
        )

    def _get_parser_and_decoder(
        self,
        item_parser: Optional["CustomCallable"] = None,
        item_decoder: Optional["CustomCallable"] = None,
    ) -> tuple[AsyncCallable, AsyncCallable]:
        """Resolve the effective parser and decoder for one handler.

        A custom parser is taken from the first truthy value of:

        1. `item_parser` (handler level),
        2. `self._call_options.parser` (subscriber level),
        3. `self._outer_config.broker_parser` (broker level).

        When one is found it is combined with the default `self._parser`
        through `ParserComposition`; otherwise `self._parser` is used as is.
        A custom decoder is resolved the same way against `self._decoder`.

        A codec (subscriber level first, then broker level) replaces the
        decoder entirely: its `decode` attribute is returned as the decoder.

        Raises:
            ValueError: If both a codec and a custom decoder are configured.
        """
        outer_config = self._outer_config
        call_options = self._call_options

        custom_parser = item_parser or call_options.parser or outer_config.broker_parser
        if custom_parser:
            async_parser: AsyncCallable = ParserComposition(custom_parser, self._parser)
        else:
            async_parser = self._parser

        # Codec takes priority over the legacy decoder; configuring both is
        # rejected because it would be ambiguous which one takes effect.
        codec = call_options.codec or outer_config.broker_codec
        custom_decoder = (
            item_decoder or call_options.decoder or outer_config.broker_decoder
        )

        if codec and custom_decoder:
            msg = "Cannot use both 'codec' and 'decoder' — 'codec' replaces 'decoder'."
            raise ValueError(msg)

        if codec:
            return async_parser, codec.decode

        if custom_decoder:
            return async_parser, ParserComposition(custom_decoder, self._decoder)

        return async_parser, self._decoder

    def _build_fastdepends_model(self) -> None:
        """Set up every registered handler item and refresh its handler."""
        for item in self.calls:
            parser, decoder = self._get_parser_and_decoder(
                item.item_parser,
                item.item_decoder,
            )

            item._setup(
                parser=parser,
                decoder=decoder,
                config=self._outer_config.fd_config,
                broker_dependencies=self._outer_config.broker_dependencies,
                _call_decorators=self._call_decorators,
            )

            item.handler.refresh(with_mock=False)

    def _post_start(self) -> None:
        """Mark the subscriber as running so `consume` accepts messages."""
        self.running = True

    @abstractmethod
    async def stop(self) -> None:
        """Stop message consuming.

        Clears the running flag, then waits on the lock's `wait_release`
        with the configured graceful timeout.
        """
        # Flip the flag before waiting so that no new message is accepted.
        self.running = False

        # Only a real `MultiLock` (installed by `start`) can be awaited;
        # wait for already consumed messages to be processed.
        current_lock = self.lock
        if not isinstance(current_lock, MultiLock):
            return

        await current_lock.wait_release(self._outer_config.graceful_timeout)

    def add_call(
        self,
        *,
        parser_: Optional["CustomCallable"],
        decoder_: Optional["CustomCallable"],
        dependencies_: Iterable["Dependant"],
        codec_: Optional["CodecProto"] = None,
    ) -> Self:
        """Replace the subscriber-level call options and return `self`."""
        self._call_options = _CallOptions(parser_, decoder_, dependencies_, codec_)
        return self

    @overload
    def __call__(
        self,
        func: Callable[P_HandlerParams, T_HandlerReturn],
        *,
        filter: "Filter[Any]" = default_filter,
        parser: Optional["CustomCallable"] = None,
        decoder: Optional["CustomCallable"] = None,
        dependencies: Iterable["Dependant"] = (),
    ) -> "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]": ...

    @overload
    def __call__(
        self,
        func: None = None,
        *,
        filter: "Filter[Any]" = default_filter,
        parser: Optional["CustomCallable"] = None,
        decoder: Optional["CustomCallable"] = None,
        dependencies: Iterable["Dependant"] = (),
    ) -> Callable[
        [Callable[P_HandlerParams, T_HandlerReturn]],
        "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]",
    ]: ...

    @override
    def __call__(
        self,
        func: Callable[P_HandlerParams, T_HandlerReturn] | None = None,
        *,
        filter: "Filter[Any]" = default_filter,
        parser: Optional["CustomCallable"] = None,
        decoder: Optional["CustomCallable"] = None,
        dependencies: Iterable["Dependant"] = (),
    ) -> Union[
        "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]",
        Callable[
            [Callable[P_HandlerParams, T_HandlerReturn]],
            "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]",
        ],
    ]:
        """Register a handler; usable both as `@sub` and as `@sub(...)`.

        Args:
            func: Handler function, or None when called with options only.
            filter: Filter stored with the handler item (converted with
                `to_async`).
            parser: Handler-level custom parser.
            decoder: Handler-level custom decoder.
            dependencies: Extra dependencies appended after the
                subscriber-level ones.

        Returns:
            The wrapped handler when `func` is given, otherwise a decorator
            that performs the registration.
        """
        # Subscriber-level dependencies go first, handler-level ones after.
        combined_dependencies = (*self._call_options.dependencies, *dependencies)
        async_filter: AsyncFilter[StreamMessage[MsgType]] = to_async(filter)

        def real_wrapper(
            func: Callable[P_HandlerParams, T_HandlerReturn],
        ) -> "HandlerCallWrapper[P_HandlerParams, T_HandlerReturn]":
            wrapped = super(SubscriberUsecase, self).__call__(func)
            wrapped._subscribers.append(self)

            handler_item = HandlerItem[MsgType](
                handler=wrapped,
                filter=async_filter,
                item_parser=parser,
                item_decoder=decoder,
                dependencies=combined_dependencies,
            )
            self.calls.add_call(handler_item)

            return wrapped

        return real_wrapper if func is None else real_wrapper(func)

    async def consume(self, msg: MsgType) -> Any:
        """Process one message, returning None on any handled failure.

        Messages are ignored while the subscriber is not running.
        `StopConsume` and `SystemExit` stop the subscriber; on `SystemExit`
        the `app` object from the context (if any) is also asked to exit.
        """
        if not self.running:
            return None

        try:
            return await self.process_message(msg)

        except (StopConsume, SystemExit) as stop_signal:
            await self.stop()

            # `StopConsume` only stops this subscriber; anything else caught
            # here is an `exit()` call, which also shuts the application down.
            if not isinstance(stop_signal, StopConsume):
                app = self._outer_config.fd_config.context.get("app")
                if app:
                    app.exit()

            return None

        except Exception:  # nosec B110
            # All other exceptions were logged by CriticalLogMiddleware
            return None

    async def process_message(self, msg: MsgType) -> "Response":
        """Execute all message processing stages.

        Enters the lock and context scopes, enters every middleware, then
        calls the first handler item whose `is_suitable` returns a message
        and publishes its response.

        Raises:
            SubscriberNotFound: If no handler item accepted the message.
            Exception: Whatever `is_suitable` raised, re-raised after the
                middleware exits have been registered.
        """
        outer_config = self._outer_config
        context = outer_config.fd_config.context
        logger_state = outer_config.logger

        async with AsyncExitStack() as stack:
            stack.enter_context(self.lock)

            # Scopes must be opened before any middleware is entered.
            stack.enter_context(context.scope("handler_", self))
            stack.enter_context(context.scope("logger", logger_state.logger.logger))
            for scope_name, scope_value in outer_config.extra_context.items():
                stack.enter_context(context.scope(scope_name, scope_value))

            # Instantiate and enter all middlewares.
            middlewares: list[BaseMiddleware] = []
            for middleware_factory in self.__build__middlewares_stack():
                entered = middleware_factory(msg, context=context)
                middlewares.append(entered)
                await entered.__aenter__()

            # Scope wrappers are applied in reverse order of entering.
            reversed_middlewares = tuple(middlewares[::-1])

            cache: dict[Any, Any] = {}
            parsing_error: Exception | None = None
            for item in self.calls:
                try:
                    message = await item.is_suitable(msg, cache)
                except Exception as exc:
                    parsing_error = exc
                    break

                if message is None:
                    continue

                stack.enter_context(
                    context.scope("log_context", self.get_log_context(message)),
                )
                stack.enter_context(context.scope("message", message))

                # Middlewares should be exited before scope release, so
                # their exits are registered after the message scopes.
                for entered in middlewares:
                    stack.push_async_exit(entered.__aexit__)

                handler_result = await item.call(
                    message=message,
                    # consumer middlewares
                    _extra_middlewares=(m.consume_scope for m in reversed_middlewares),
                )
                result_msg = ensure_response(handler_result)

                if not result_msg.correlation_id:
                    result_msg.correlation_id = message.correlation_id

                publishers = chain(
                    self.__get_response_publisher(message),
                    item.handler._publishers,
                )
                for publisher in publishers:
                    await publisher._publish(
                        result_msg.as_publish_command(),
                        _extra_middlewares=(
                            m.publish_scope for m in reversed_middlewares
                        ),
                    )

                # Return data for tests
                return result_msg

            # Suitable handler was not found or
            # parsing/decoding exception occurred
            for entered in middlewares:
                stack.push_async_exit(entered.__aexit__)

            # Reraise it to catch in tests
            if parsing_error:
                raise parsing_error

            error_msg = f"There is no suitable handler for {msg=}"
            raise SubscriberNotFound(error_msg)

        # An error was raised and processed by some middleware
        return ensure_response(None)

    def __build__middlewares_stack(self) -> tuple["BrokerMiddleware[MsgType]", ...]:
        """Return built-in middlewares followed by the broker middlewares.

        `AcknowledgementMiddleware` comes first unless auto-ack is disabled;
        `CriticalLogMiddleware` always precedes the broker middlewares.
        """
        logger_state = self._outer_config.logger

        builtin: list[Any] = []
        if not self.__auto_ack_disabled:
            builtin.append(
                AcknowledgementMiddleware(
                    logger=logger_state,
                    ack_policy=self.ack_policy,
                    extra_options=self.extra_watcher_options,
                ),
            )
        builtin.append(CriticalLogMiddleware(logger_state))

        return (*builtin, *self._broker_middlewares)

    def __get_response_publisher(
        self,
        message: "StreamMessage[MsgType]",
    ) -> Iterable["PublisherProto"]:
        """Return reply publishers, or nothing when replies do not apply."""
        if message.reply_to and not self._no_reply:
            return self._make_response_publisher(message)

        return ()

    @abstractmethod
    def _make_response_publisher(
        self,
        message: "StreamMessage[MsgType]",
    ) -> Iterable["PublisherProto"]:
        """Build the publishers used to reply to `message` (subclass hook)."""
        raise NotImplementedError

    @abstractmethod
    async def get_one(self, *, timeout: float = 5) -> Optional["StreamMessage[MsgType]"]:
        """Return a message or None; behavior is defined by subclasses."""
        raise NotImplementedError

    @abstractmethod
    async def __aiter__(self) -> AsyncIterator["StreamMessage[MsgType]"]:
        """Iterate over messages asynchronously; defined by subclasses."""
        raise NotImplementedError

    def get_log_context(
        self,
        message: Optional["StreamMessage[MsgType]"],
    ) -> dict[str, str]:
        """Build the log context; `message_id` is empty when unavailable."""
        message_id = getattr(message, "message_id", "")
        return {"message_id": message_id}

    def _log(
        self,
        log_level: int | None,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: Exception | None = None,
    ) -> None:
        """Forward a log record to the logger held by the outer config."""
        logger_state = self._outer_config.logger
        logger_state.log(
            message,
            log_level,
            extra=extra,
            exc_info=exc_info,
        )

    def schema(self) -> dict[str, "SubscriberSpec"]:
        """Set up the handlers, then return the specification schema."""
        self._build_fastdepends_model()
        specification = self.specification
        return specification.get_schema()