Coverage for faststream / _internal / application.py: 89%

110 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import logging 

2from abc import abstractmethod 

3from collections.abc import AsyncIterator, Callable, Sequence 

4from contextlib import asynccontextmanager 

5from typing import TYPE_CHECKING, Any, Optional, TypeVar 

6 

7from typing_extensions import ParamSpec 

8 

9from faststream._internal.di import FastDependsConfig 

10from faststream._internal.logger import logger 

11from faststream._internal.utils import apply_types 

12from faststream._internal.utils.functions import fake_context, to_async 

13from faststream.exceptions import SetupError 

14from faststream.specification import AsyncAPI 

15 

16if TYPE_CHECKING: 

17 from faststream._internal.basic_types import ( 

18 AnyCallable, 

19 AsyncFunc, 

20 Lifespan, 

21 LoggerProto, 

22 SettingField, 

23 ) 

24 from faststream._internal.broker import BrokerUsecase 

25 from faststream._internal.context import ContextRepo 

26 from faststream.specification.base import SpecificationFactory 

27 

28 

try:
    from pydantic import ValidationError as PValidation

    from faststream.exceptions import StartupValidationError

    @asynccontextmanager
    async def catch_startup_validation_error() -> AsyncIterator[None]:
        """Re-raise pydantic validation failures as ``StartupValidationError``.

        For every error reported by the caught ``PValidation``, the first
        element of its ``loc`` (stringified) is collected as a *missed* field
        when the error ``type`` is ``"missing"`` and as an *invalid* field
        otherwise. The original exception is chained as ``__cause__``.
        """
        try:
            yield
        except PValidation as exc:
            missing: list[str] = []
            invalid: list[str] = []
            for error in exc.errors():
                field_name = str(error["loc"][0])
                bucket = missing if error["type"] == "missing" else invalid
                bucket.append(field_name)

            raise StartupValidationError(
                missed_fields=missing,
                invalid_fields=invalid,
            ) from exc

except ImportError:
    # pydantic (or the exception import) is unavailable: fall back to
    # `fake_context` from faststream utils — presumably a no-op context
    # manager; TODO confirm.
    catch_startup_validation_error = fake_context


# Type parameters for the hook-registration decorators, so a decorated
# callable keeps its own parameter and return types.
P_HookParams = ParamSpec("P_HookParams")
T_HookReturn = TypeVar("T_HookReturn")

59 

60 

class StartAbleApplication:
    """Base application holding brokers, a DI config and a specification.

    On construction the instance registers itself in the DI context under
    the global name ``"app"``; ``_start_broker`` starts every registered
    broker in order.
    """

    def __init__(
        self,
        broker: Optional["BrokerUsecase[Any, Any]"] = None,
        /,
        specification: Optional["SpecificationFactory"] = None,
        config: Optional["FastDependsConfig"] = None,
    ) -> None:
        self._init_setupable_(
            broker,
            specification=specification,
            config=config,
        )

    @property
    def context(self) -> "ContextRepo":
        """DI context repository taken from ``self.config``."""
        return self.config.context

    def _init_setupable_(  # noqa: PLW3201
        self,
        broker: Optional["BrokerUsecase[Any, Any]"] = None,
        /,
        specification: Optional["SpecificationFactory"] = None,
        config: Optional["FastDependsConfig"] = None,
    ) -> None:
        """Set up config, broker list and specification.

        A falsy ``config`` is replaced with a new ``FastDependsConfig`` and a
        falsy ``specification`` with ``AsyncAPI()``. A truthy ``broker`` is
        registered via ``_add_broker``.
        """
        if not config:
            config = FastDependsConfig()
        self.config = config
        self.config.context.set_global("app", self)
        self.brokers: list[BrokerUsecase[Any, Any]] = []

        self.schema: SpecificationFactory = (
            specification if specification else AsyncAPI()
        )

        if not broker:
            return
        self._add_broker(broker)

    async def _start_broker(self) -> None:
        """Start every registered broker, in registration order."""
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is
        # because callers may depend on the AssertionError.
        assert self.brokers, "You should setup a broker"
        for registered in self.brokers:
            await registered.start()

    @property
    def broker(self) -> Optional["BrokerUsecase[Any, Any]"]:
        """The first registered broker, or ``None`` if there is none."""
        if not self.brokers:
            return None
        return self.brokers[0]

    def set_broker(self, broker: "BrokerUsecase[Any, Any]") -> None:
        """Attach a broker to an application that does not have one yet.

        Useful when the broker is created/initialised inside an
        ``on_startup`` hook.

        Raises:
            SetupError: if the application already has a broker.
        """
        if not self.brokers:
            self._add_broker(broker)
            return
        msg = f"`{self}` already has a broker. You can't use multiple brokers until 1.0.0 release."
        raise SetupError(msg)

    def _add_broker(self, broker: "BrokerUsecase[Any, Any]") -> None:
        """Register ``broker`` with the app, its schema and the app's DI config.

        Raises:
            SetupError: if ``broker`` is already registered.
        """
        already_registered = broker in self.brokers
        if already_registered:
            msg = f"Broker {broker} is already added"
            raise SetupError(msg)
        self.brokers.append(broker)
        self.schema.add_broker(broker)
        broker._update_fd_config(self.config)

121 

122 

class Application(StartAbleApplication):
    """Application adding lifecycle hooks, a lifespan context and logging.

    Every hook callable is converted with ``to_async`` and wrapped with
    ``apply_types`` (using the config serializer and the app context) before
    being stored, so hooks are awaited with dependency injection applied.
    """

    def __init__(
        self,
        broker: Optional["BrokerUsecase[Any, Any]"] = None,
        /,
        config: Optional["FastDependsConfig"] = None,
        logger: Optional["LoggerProto"] = logger,
        lifespan: Optional["Lifespan"] = None,
        on_startup: Sequence["AnyCallable"] = (),
        after_startup: Sequence["AnyCallable"] = (),
        on_shutdown: Sequence["AnyCallable"] = (),
        after_shutdown: Sequence["AnyCallable"] = (),
        specification: Optional["SpecificationFactory"] = None,
    ) -> None:
        """Create the application.

        Args:
            broker: optional broker registered at construction time.
            config: DI config; a new ``FastDependsConfig`` is used if falsy.
            logger: logger for lifecycle messages; ``None`` disables them.
            lifespan: optional lifespan callable, wrapped with ``apply_types``
                (``cast_result=False``); ``fake_context`` is used otherwise.
            on_startup, after_startup, on_shutdown, after_shutdown:
                initial hook callables, called in the given order.
            specification: schema factory; ``AsyncAPI()`` is used if falsy.
        """
        self.logger = logger
        # Initialised here so `running` is readable before the app has ever
        # been started; `_startup` / `_shutdown` flip it afterwards.
        self.running = False

        super().__init__(broker, config=config, specification=specification)

        self._on_startup_calling: list[AsyncFunc] = [
            self._wrap_hook(x) for x in on_startup
        ]
        self._after_startup_calling: list[AsyncFunc] = [
            self._wrap_hook(x) for x in after_startup
        ]
        self._on_shutdown_calling: list[AsyncFunc] = [
            self._wrap_hook(x) for x in on_shutdown
        ]
        self._after_shutdown_calling: list[AsyncFunc] = [
            self._wrap_hook(x) for x in after_shutdown
        ]

        if lifespan:
            # Unlike hooks, the lifespan is not passed through `to_async`
            # and its result is not cast.
            self.lifespan_context = apply_types(
                func=lifespan,
                serializer_cls=self.config._serializer,
                cast_result=False,
                context__=self.context,
            )
        else:
            self.lifespan_context = fake_context

    def _wrap_hook(self, func: "AnyCallable") -> "AsyncFunc":
        """Convert ``func`` to async and apply DI with this app's config/context."""
        return apply_types(
            to_async(func),
            serializer_cls=self.config._serializer,
            context__=self.context,
        )

    @abstractmethod
    def exit(self) -> None:
        """Stop application manually."""
        ...

    @abstractmethod
    async def run(
        self,
        log_level: int,
        run_extra_options: dict[str, "SettingField"] | None = None,
    ) -> None: ...

    # Startup

    async def _startup(
        self,
        log_level: int = logging.INFO,
        run_extra_options: dict[str, "SettingField"] | None = None,
    ) -> None:
        """Call `start` wrapped in startup logging, then mark the app running."""
        async with self._startup_logging(log_level=log_level):
            await self.start(**(run_extra_options or {}))

        self.running = True

    async def start(
        self,
        **run_extra_options: "SettingField",
    ) -> None:
        """Run on_startup hooks, start the brokers, then run after_startup hooks.

        ``run_extra_options`` are forwarded as keyword arguments to the
        on_startup hooks only.
        """
        async with self._start_hooks_context(**run_extra_options):
            await self._start_broker()

    @asynccontextmanager
    async def _start_hooks_context(
        self,
        **run_extra_options: "SettingField",
    ) -> AsyncIterator[None]:
        """Run on_startup hooks before the body and after_startup hooks after it.

        The whole sequence, including the body, runs inside
        ``catch_startup_validation_error``. If the body raises, the
        after_startup hooks are skipped.
        """
        async with catch_startup_validation_error():
            for func in self._on_startup_calling:
                await func(**run_extra_options)

            yield

            for func in self._after_startup_calling:
                await func()

    @asynccontextmanager
    async def _startup_logging(
        self,
        log_level: int = logging.INFO,
    ) -> AsyncIterator[None]:
        """Log around startup; the success message is skipped if the body raises."""
        self._log(
            log_level,
            "FastStream app starting...",
        )

        yield

        self._log(
            log_level,
            "FastStream app started successfully! To exit, press CTRL+C",
        )

    # Shutdown

    async def _shutdown(self, log_level: int = logging.INFO) -> None:
        """Call `stop` wrapped in shutdown logging, then mark the app stopped."""
        async with self._shutdown_logging(log_level=log_level):
            await self.stop()

        self.running = False

    async def stop(self) -> None:
        """Run on_shutdown hooks, stop every broker, then run after_shutdown hooks."""
        async with self._shutdown_hooks_context():
            for broker in self.brokers:
                await broker.stop()

    @asynccontextmanager
    async def _shutdown_hooks_context(self) -> AsyncIterator[None]:
        """Run on_shutdown hooks before the body and after_shutdown hooks after it.

        If the body raises, the after_shutdown hooks are skipped.
        """
        for func in self._on_shutdown_calling:
            await func()

        yield

        for func in self._after_shutdown_calling:
            await func()

    @asynccontextmanager
    async def _shutdown_logging(
        self,
        log_level: int = logging.INFO,
    ) -> AsyncIterator[None]:
        """Separated shutdown logging."""
        self._log(log_level, "FastStream app shutting down...")

        yield

        self._log(log_level, "FastStream app shut down gracefully.")

    # Service methods

    def _log(self, level: int, message: str) -> None:
        """Log ``message`` at ``level`` unless logging is disabled (``logger is None``)."""
        if self.logger is not None:
            self.logger.log(level, message)

    # Hooks

    def on_startup(
        self,
        func: Callable[P_HookParams, T_HookReturn],
    ) -> Callable[P_HookParams, T_HookReturn]:
        """Add hook running BEFORE broker connected.

        This hook also receives the extra CLI options as keyword arguments.
        Returns ``func`` unchanged, so it can be used as a decorator.
        """
        self._on_startup_calling.append(self._wrap_hook(func))
        return func

    def on_shutdown(
        self,
        func: Callable[P_HookParams, T_HookReturn],
    ) -> Callable[P_HookParams, T_HookReturn]:
        """Add hook running BEFORE broker disconnected (called with no arguments)."""
        self._on_shutdown_calling.append(self._wrap_hook(func))
        return func

    def after_startup(
        self,
        func: Callable[P_HookParams, T_HookReturn],
    ) -> Callable[P_HookParams, T_HookReturn]:
        """Add hook running AFTER broker connected (called with no arguments)."""
        self._after_startup_calling.append(self._wrap_hook(func))
        return func

    def after_shutdown(
        self,
        func: Callable[P_HookParams, T_HookReturn],
    ) -> Callable[P_HookParams, T_HookReturn]:
        """Add hook running AFTER broker disconnected (called with no arguments)."""
        self._after_shutdown_calling.append(self._wrap_hook(func))
        return func