Coverage for fastapi/openapi/utils.py: 100%

251 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-08 03:53 +0000

1import http.client 1abcde

2import inspect 1abcde

3import warnings 1abcde

4from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast 1abcde

5 

6from fastapi import routing 1abcde

7from fastapi._compat import ( 1abcde

8 GenerateJsonSchema, 

9 JsonSchemaValue, 

10 ModelField, 

11 Undefined, 

12 get_compat_model_name_map, 

13 get_definitions, 

14 get_schema_from_model_field, 

15 lenient_issubclass, 

16) 

17from fastapi.datastructures import DefaultPlaceholder 1abcde

18from fastapi.dependencies.models import Dependant 1abcde

19from fastapi.dependencies.utils import get_flat_dependant, get_flat_params 1abcde

20from fastapi.encoders import jsonable_encoder 1abcde

21from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE 1abcde

22from fastapi.openapi.models import OpenAPI 1abcde

23from fastapi.params import Body, Param 1abcde

24from fastapi.responses import Response 1abcde

25from fastapi.types import ModelNameMap 1abcde

26from fastapi.utils import ( 1abcde

27 deep_dict_update, 

28 generate_operation_id_for_path, 

29 is_body_allowed_for_status_code, 

30) 

31from starlette.responses import JSONResponse 1abcde

32from starlette.routing import BaseRoute 1abcde

33from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY 1abcde

34from typing_extensions import Literal 1abcde

35 

# JSON Schema of one validation error item: an object with a required "loc"
# (array whose items are strings or integers), "msg" and "type" (both strings).
validation_error_definition = dict(
    title="ValidationError",
    type="object",
    properties=dict(
        loc=dict(
            title="Location",
            type="array",
            items={"anyOf": [{"type": "string"}, {"type": "integer"}]},
        ),
        msg=dict(title="Message", type="string"),
        type=dict(title="Error Type", type="string"),
    ),
    required=["loc", "msg", "type"],
)

50 

# JSON Schema of the body referenced by the generated 422 responses: an object
# whose "detail" property is an array of ValidationError items, referenced
# through the components schema prefix.
validation_error_response_definition = dict(
    title="HTTPValidationError",
    type="object",
    properties=dict(
        detail=dict(
            title="Detail",
            type="array",
            items={"$ref": REF_PREFIX + "ValidationError"},
        ),
    ),
)

62 

# Human-readable fallback descriptions for the OpenAPI status-code range keys
# ("1XX" ... "5XX") and for the "DEFAULT" response key.
status_code_ranges: Dict[str, str] = {
    **{
        f"{leading_digit}XX": label
        for leading_digit, label in enumerate(
            ("Information", "Success", "Redirection", "Client Error", "Server Error"),
            start=1,
        )
    },
    "DEFAULT": "Default Response",
}

71 

72 

def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Collect security scheme definitions and requirements from a dependant.

    For each entry of ``flat_dependant.security_requirements`` the scheme's
    ``model`` is serialized with ``jsonable_encoder`` (by alias, ``None`` values
    excluded) and stored under the scheme's ``scheme_name``; a
    ``{scheme_name: scopes}`` requirement is appended for the operation.

    Returns:
        ``(security_definitions, operation_security)`` -- the mapping of scheme
        name to serialized scheme, and the list of requirement objects in
        iteration order.
    """
    definitions_by_name: Dict[str, Any] = {}
    requirements: List[Dict[str, Any]] = []
    for requirement in flat_dependant.security_requirements:
        scheme = requirement.security_scheme
        encoded_scheme = jsonable_encoder(
            scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        name = scheme.scheme_name
        definitions_by_name[name] = encoded_scheme
        requirements.append({name: requirement.scopes})
    return definitions_by_name, requirements

88 

89 

def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
    """Build OpenAPI Parameter Objects for the given route parameters.

    Parameters whose ``field_info.include_in_schema`` is falsy are skipped.
    Every other parameter yields a dict with ``name`` (the field alias), ``in``,
    ``required`` and ``schema``, plus, when set on the field info:
    ``description``, ``examples`` (taking precedence) or ``example``, and
    ``deprecated``.
    """
    result: List[Dict[str, Any]] = []
    for route_param in all_route_params:
        info = cast(Param, route_param.field_info)
        if not info.include_in_schema:
            continue
        schema = get_schema_from_model_field(
            field=route_param,
            schema_generator=schema_generator,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        entry: Dict[str, Any] = {
            "name": route_param.alias,
            "in": info.in_.value,
            "required": route_param.required,
            "schema": schema,
        }
        if info.description:
            entry["description"] = info.description
        # ``openapi_examples`` wins over the single legacy ``example``.
        if info.openapi_examples:
            entry["examples"] = jsonable_encoder(info.openapi_examples)
        elif info.example != Undefined:
            entry["example"] = jsonable_encoder(info.example)
        if info.deprecated:
            entry["deprecated"] = True
        result.append(entry)
    return result

129 

130 

def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Optional[Dict[str, Any]]:
    """Build the OpenAPI Request Body Object for ``body_field``.

    Returns ``None`` when ``body_field`` is falsy. Otherwise returns a dict
    whose ``content`` maps the body's ``media_type`` to the body schema plus
    ``examples`` (taking precedence) or ``example`` when configured; the
    ``required`` key is only included when the body field is required.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema = get_schema_from_model_field(
        field=body_field,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    info = cast(Body, body_field.field_info)
    media_type = info.media_type
    is_required = body_field.required

    request_body: Dict[str, Any] = {}
    if is_required:
        request_body["required"] = is_required

    media_content: Dict[str, Any] = {"schema": schema}
    if info.openapi_examples:
        media_content["examples"] = jsonable_encoder(info.openapi_examples)
    elif info.example != Undefined:
        media_content["example"] = jsonable_encoder(info.example)

    request_body["content"] = {media_type: media_content}
    return request_body

166 

167 

def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:  # pragma: nocover
    """Deprecated: return an operation ID for ``route`` and ``method``.

    Emits a ``DeprecationWarning`` on every call. Returns ``route.operation_id``
    when it is truthy, otherwise one generated from the route's name,
    ``path_format`` and ``method`` by ``generate_operation_id_for_path``.
    """
    warnings.warn(
        "fastapi.openapi.utils.generate_operation_id() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    return route.operation_id or generate_operation_id_for_path(
        name=route.name, path=route.path_format, method=method
    )

181 

182 

def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the OpenAPI summary for ``route``.

    Uses ``route.summary`` when it is truthy; otherwise title-cases the route
    name with underscores turned into spaces (``"read_items"`` ->
    ``"Read Items"``). ``method`` is accepted but not used.
    """
    return route.summary or route.name.replace("_", " ").title()

187 

188 

def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
    """Build the descriptive fields of an OpenAPI Operation Object.

    Always sets ``summary`` and ``operationId`` (``route.operation_id`` or, when
    unset, ``route.unique_id``); sets ``tags``, ``description`` and
    ``deprecated`` only when the route defines them.

    ``operation_ids`` holds the IDs emitted so far. It is updated in place, and
    a warning is issued when this operation's ID was already present.
    """
    operation: Dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags
    operation["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        operation["description"] = route.description
    operation_id = route.operation_id or route.unique_id
    if operation_id in operation_ids:
        endpoint = route.endpoint
        # Not every callable has ``__name__`` (e.g. instances of classes that
        # define ``__call__``, or ``functools.partial`` objects). Fall back to
        # the class name, like Starlette's ``get_name``, instead of raising
        # AttributeError while merely reporting the duplicate.
        endpoint_name = getattr(endpoint, "__name__", endpoint.__class__.__name__)
        message = (
            f"Duplicate Operation ID {operation_id} for function {endpoint_name}"
        )
        file_name = getattr(endpoint, "__globals__", {}).get("__file__")
        if file_name:
            message += f" at {file_name}"
        warnings.warn(message, stacklevel=1)
    operation_ids.add(operation_id)
    operation["operationId"] = operation_id
    if route.deprecated:
        operation["deprecated"] = route.deprecated
    return operation

213 

214 

def get_openapi_path(
    *,
    route: routing.APIRoute,
    operation_ids: Set[str],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the OpenAPI Path Item for ``route``.

    Returns ``(path, security_schemes, definitions)``:

    * ``path`` maps each lower-cased HTTP method of the route to its Operation
      Object; it is empty when ``route.include_in_schema`` is falsy.
    * ``security_schemes`` holds the security scheme definitions referenced by
      the operations, including those referenced by the route's callbacks.
    * ``definitions`` holds extra component schemas referenced by the
      operations (the ``ValidationError``/``HTTPValidationError`` pair when a
      422 response is documented), including those needed by callbacks.

    ``operation_ids`` is shared between calls and updated in place so that
    duplicate operation IDs can be detected.
    """
    path: Dict[str, Any] = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        # The flattened dependencies and parameters depend only on the route,
        # not on the HTTP method, so compute them once for all methods.
        flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
        all_route_params = get_flat_params(route.dependant)
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if parameters:
                # Deduplicate by (location, name).
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # Make sure required definitions of the same parameter take precedence
                # over non-required definitions
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field,
                    schema_generator=schema_generator,
                    model_name_map=model_name_map,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            operation_ids=operation_ids,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                        # Callback operations can reference security schemes and
                        # schemas (e.g. HTTPValidationError for their 422
                        # response) that must be declared under components too;
                        # dropping them would leave dangling references.
                        security_schemes.update(cb_security_schemes)
                        definitions.update(cb_definitions)
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                # NOTE(review): if the response class's __init__ has no int
                # ``status_code`` default, ``status_code`` is never assigned and
                # the statement below raises UnboundLocalError.
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if route_response_media_type and is_body_allowed_for_status_code(
                route.status_code
            ):
                # Non-JSON responses are documented as plain strings; JSON
                # responses use the response model's schema, or an empty
                # (unconstrained) schema when there is no response model.
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema = get_schema_from_model_field(
                            field=route.response_field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                # Merge the additional responses declared on the route.
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    process_response = additional_response.copy()
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(
                        process_response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema = get_schema_from_model_field(
                            field=field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            # Document the automatic 422 validation error when the operation
            # takes parameters or a body and no 422/4XX/default response exists.
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions

405 

406 

def get_fields_from_routes(
    routes: Sequence[BaseRoute],
) -> List[ModelField]:
    """Collect the model fields used by schema-visible API routes.

    Only ``routing.APIRoute`` instances with a truthy ``include_in_schema`` are
    considered. The returned list is ordered as: fields gathered recursively
    from callbacks, then request body fields, then response fields (each
    route's main response field followed by its additional ``response_fields``),
    then the flattened dependency parameters.
    """
    callback_fields: List[ModelField] = []
    body_fields: List[ModelField] = []
    response_fields: List[ModelField] = []
    param_fields: List[ModelField] = []
    for route in routes:
        if not (
            getattr(route, "include_in_schema", None)
            and isinstance(route, routing.APIRoute)
        ):
            continue
        body = route.body_field
        if body:
            assert isinstance(
                body, ModelField
            ), "A request body must be a Pydantic Field"
            body_fields.append(body)
        if route.response_field:
            response_fields.append(route.response_field)
        if route.response_fields:
            response_fields.extend(route.response_fields.values())
        if route.callbacks:
            callback_fields.extend(get_fields_from_routes(route.callbacks))
        param_fields.extend(get_flat_params(route.dependant))
    return [*callback_fields, *body_fields, *response_fields, *param_fields]

436 

437 

def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.1.0",
    summary: Optional[str] = None,
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    webhooks: Optional[Sequence[BaseRoute]] = None,
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Generate the OpenAPI document for ``routes`` and optional ``webhooks``.

    Optional metadata arguments are only emitted when truthy. Model schemas for
    every field used by the routes and webhooks are generated up front; each
    ``routing.APIRoute`` is then rendered into ``paths`` (or ``webhooks``),
    with the security schemes and schemas it needs merged into
    ``components``. Component schemas are sorted by name. The result is
    validated through the ``OpenAPI`` model and encoded with
    ``jsonable_encoder`` (by alias, ``None`` values excluded).
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    optional_info_fields = (
        ("summary", summary),
        ("description", description),
        ("termsOfService", terms_of_service),
        ("contact", contact),
        ("license", license_info),
    )
    for info_key, info_value in optional_info_fields:
        if info_value:
            info[info_key] = info_value

    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers

    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    webhook_paths: Dict[str, Dict[str, Any]] = {}
    # Shared by routes and webhooks so duplicate IDs are detected across both.
    operation_ids: Set[str] = set()
    all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
    model_name_map = get_compat_model_name_map(all_fields)
    schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )

    def render_into(
        source: Optional[Sequence[BaseRoute]],
        target: Dict[str, Dict[str, Any]],
    ) -> None:
        # Render each APIRoute of ``source`` into ``target`` (keyed by path
        # format) and merge what it needs into the shared components.
        for api_route in source or []:
            if not isinstance(api_route, routing.APIRoute):
                continue
            rendered_path, route_security, route_definitions = get_openapi_path(
                route=api_route,
                operation_ids=operation_ids,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if rendered_path:
                target.setdefault(api_route.path_format, {}).update(rendered_path)
            if route_security:
                components.setdefault("securitySchemes", {}).update(route_security)
            if route_definitions:
                definitions.update(route_definitions)

    render_into(routes, paths)
    render_into(webhooks, webhook_paths)

    if definitions:
        components["schemas"] = dict(sorted(definitions.items()))
    if components:
        output["components"] = components
    output["paths"] = paths
    if webhook_paths:
        output["webhooks"] = webhook_paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore