# Coverage for fastapi/openapi/utils.py: 100%
# 251 statements
# coverage.py v7.6.1, created at 2024-08-08 03:53 +0000
import copy
import http.client
import inspect
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast

from fastapi import routing
from fastapi._compat import (
    GenerateJsonSchema,
    JsonSchemaValue,
    ModelField,
    Undefined,
    get_compat_model_name_map,
    get_definitions,
    get_schema_from_model_field,
    lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.types import ModelNameMap
from fastapi.utils import (
    deep_dict_update,
    generate_operation_id_for_path,
    is_body_allowed_for_status_code,
)
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from typing_extensions import Literal
# JSON Schema for a single item of the ``detail`` list in a request
# validation error. get_openapi_path registers it under the component name
# "ValidationError" for routes that have parameters or a request body.
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {
            "title": "Location",
            "type": "array",
            # Each path segment is either a field name or a list index.
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}

# JSON Schema for the body of the automatic 422 response. Its ``detail``
# items point at the "ValidationError" component defined just above.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": REF_PREFIX + "ValidationError"},
        }
    },
}

# Fallback descriptions for status-code range keys and the "default" key.
# get_openapi_path looks them up by the upper-cased key of each entry in
# ``route.responses`` when no explicit description is given.
status_code_ranges: Dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}
def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Collect security schemes and requirements from a flattened dependant.

    For each entry of ``flat_dependant.security_requirements`` the scheme's
    model is JSON-encoded (by alias, ``None`` values dropped) and stored under
    the scheme's ``scheme_name``. A ``{scheme_name: scopes}`` requirement is
    appended for the operation.

    Returns:
        A ``(security_definitions, operation_security)`` pair:
        - ``security_definitions`` maps each scheme name to its encoded
          scheme object. A repeated name overwrites the earlier entry.
        - ``operation_security`` holds one requirement per entry, in
          iteration order.
    """
    definitions_by_name: Dict[str, Any] = {}
    requirements: List[Dict[str, Any]] = []
    for requirement in flat_dependant.security_requirements:
        scheme = requirement.security_scheme
        # The right-hand side is evaluated first, so the model is encoded
        # before scheme_name is read, exactly as a two-step assignment would.
        definitions_by_name[scheme.scheme_name] = jsonable_encoder(
            scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        requirements.append({scheme.scheme_name: requirement.scopes})
    return definitions_by_name, requirements
def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
    """Build the OpenAPI ``parameters`` list for a route's non-body params.

    Each field's ``field_info`` is treated as a ``Param``. Fields whose
    ``include_in_schema`` is false are skipped. When both are set,
    ``openapi_examples`` takes precedence over the single ``example`` value.

    Returns:
        One parameter object per included field, in input order.
    """
    parameters: List[Dict[str, Any]] = []
    for param in all_route_params:
        field_info = cast(Param, param.field_info)
        if not field_info.include_in_schema:
            continue
        param_schema = get_schema_from_model_field(
            field=param,
            schema_generator=schema_generator,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        parameter: Dict[str, Any] = {
            "name": param.alias,
            "in": field_info.in_.value,
            "required": param.required,
            "schema": param_schema,
        }
        if field_info.description:
            parameter["description"] = field_info.description
        # ``Undefined`` marks "no example given" and is a singleton sentinel,
        # so test by identity. ``!=`` would call the example's own __ne__,
        # which for arrays/DataFrames/ORM expressions returns a non-bool
        # whose truth test raises.
        if field_info.openapi_examples:
            parameter["examples"] = jsonable_encoder(field_info.openapi_examples)
        elif field_info.example is not Undefined:
            parameter["example"] = jsonable_encoder(field_info.example)
        if field_info.deprecated:
            parameter["deprecated"] = True
        parameters.append(parameter)
    return parameters
def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Optional[Dict[str, Any]]:
    """Build the OpenAPI ``requestBody`` object for a route, if it has one.

    Returns ``None`` when *body_field* is falsy. Otherwise returns a dict
    with the following keys:
    - ``required``: only included when the body is required.
    - ``content``: a single entry keyed by the body's ``media_type``. It
      holds the schema plus either ``examples`` (preferred) or ``example``.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    body_schema = get_schema_from_model_field(
        field=body_field,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    field_info = cast(Body, body_field.field_info)
    request_media_type = field_info.media_type
    required = body_field.required
    request_body_oai: Dict[str, Any] = {}
    if required:
        request_body_oai["required"] = required
    request_media_content: Dict[str, Any] = {"schema": body_schema}
    # Identity check against the ``Undefined`` sentinel: ``!=`` would defer to
    # the example's own __ne__, which may not return a plain bool.
    if field_info.openapi_examples:
        request_media_content["examples"] = jsonable_encoder(
            field_info.openapi_examples
        )
    elif field_info.example is not Undefined:
        request_media_content["example"] = jsonable_encoder(field_info.example)
    request_body_oai["content"] = {request_media_type: request_media_content}
    return request_body_oai
def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:  # pragma: nocover
    """Deprecated: return the OpenAPI operation ID for *route* and *method*.

    Emits a ``DeprecationWarning`` on every call. Returns
    ``route.operation_id`` when it is truthy. Otherwise it derives an ID from
    the route name, its ``path_format`` and *method* via
    ``generate_operation_id_for_path``.
    """
    warnings.warn(
        "fastapi.openapi.utils.generate_operation_id() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    return route.operation_id or generate_operation_id_for_path(
        name=route.name, path=route.path_format, method=method
    )
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the OpenAPI ``summary`` for *route*.

    An explicit, truthy ``route.summary`` is used as-is. Otherwise the route
    name is converted to title case with underscores turned into spaces.
    *method* is accepted for signature compatibility but is not used.
    """
    return route.summary or route.name.replace("_", " ").title()
def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
    """Build the descriptive part of an OpenAPI operation object.

    Fills, in this key order, ``tags`` (if any), ``summary``, ``description``
    (if any), ``operationId`` and ``deprecated`` (only when truthy).

    The operation ID is ``route.operation_id`` or, failing that,
    ``route.unique_id``. If it was already in *operation_ids*, a warning
    naming the endpoint function (and its source file when known) is
    emitted. The ID is then added to *operation_ids*, which is mutated.
    """
    metadata: Dict[str, Any] = {}
    if route.tags:
        metadata["tags"] = route.tags
    metadata["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        metadata["description"] = route.description

    op_id = route.operation_id or route.unique_id
    if op_id in operation_ids:
        endpoint = route.endpoint
        warning_text = (
            f"Duplicate Operation ID {op_id} for function {endpoint.__name__}"
        )
        source_file = getattr(endpoint, "__globals__", {}).get("__file__")
        if source_file:
            warning_text = f"{warning_text} at {source_file}"
        warnings.warn(warning_text, stacklevel=1)
    operation_ids.add(op_id)
    metadata["operationId"] = op_id

    if route.deprecated:
        metadata["deprecated"] = route.deprecated
    return metadata
def get_openapi_path(
    *,
    route: routing.APIRoute,
    operation_ids: Set[str],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the OpenAPI path item for a single ``APIRoute``.

    One operation object is produced per method in ``route.methods``, keyed by
    the lower-cased method name. Nothing is produced when
    ``route.include_in_schema`` is false.

    Args:
        route: The route to document.
        operation_ids: Operation IDs seen so far. Updated in place and used
            to warn about duplicates.
        schema_generator: Passed through to the schema helpers.
        model_name_map: Passed through to the schema helpers.
        field_mapping: Passed through to the schema helpers.
        separate_input_output_schemas: Passed through to the schema helpers.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple:
        - ``path``: the operations keyed by method.
        - ``security_schemes``: the security schemes the route uses.
        - ``definitions``: extra component schemas (the validation-error
          schemas) needed by the automatic 422 response.
    """
    path: Dict[str, Any] = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_flat_params(route.dependant)
            parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if parameters:
                # De-duplicate by (location, name); the last definition wins...
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # ...except that required definitions of the same parameter
                # take precedence over non-required definitions.
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field,
                    schema_generator=schema_generator,
                    model_name_map=model_name_map,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks: Dict[str, Any] = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        # NOTE(review): only the callback's path item is used;
                        # its security schemes and definitions are discarded.
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            operation_ids=operation_ids,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
                # NOTE(review): if the response class's __init__ has no int
                # ``status_code`` default, ``status_code`` stays unbound and
                # the next statement raises UnboundLocalError.
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if route_response_media_type and is_body_allowed_for_status_code(
                route.status_code
            ):
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema = get_schema_from_model_field(
                            field=route.response_field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    assert isinstance(
                        additional_response, dict
                    ), "An additional response must be a dict"
                    # Deep-copy (without the "model" class) so the setdefault
                    # chain and deep_dict_update calls below never write into
                    # the caller's ``route.responses``. A shallow copy shares
                    # the nested "content" dicts with the user's data, so
                    # every schema generation would mutate them.
                    process_response = copy.deepcopy(
                        {
                            key: value
                            for key, value in additional_response.items()
                            if key != "model"
                        }
                    )
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    field = route.response_fields.get(additional_status_code)
                    if field:
                        additional_field_schema = get_schema_from_model_field(
                            field=field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    # Description precedence: explicit > already present >
                    # range/HTTP reason phrase > generic fallback.
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            # Routes that validate input get an automatic 422 response unless
            # the user already documented 422, 4XX or default.
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions
def get_fields_from_routes(
    routes: Sequence[BaseRoute],
) -> List[ModelField]:
    """Collect every ``ModelField`` the OpenAPI schema needs for *routes*.

    Only ``APIRoute`` instances with a truthy ``include_in_schema`` count.
    Callback routes are walked recursively.

    Returns:
        The fields in this order:
        1. callback fields;
        2. request bodies;
        3. response fields (main response, then additional responses);
        4. flattened request parameters.
    """
    body_fields: List[ModelField] = []
    response_fields: List[ModelField] = []
    request_fields: List[ModelField] = []
    callback_fields: List[ModelField] = []
    for route in routes:
        if not getattr(route, "include_in_schema", None):
            continue
        if not isinstance(route, routing.APIRoute):
            continue
        if route.body_field:
            assert isinstance(
                route.body_field, ModelField
            ), "A request body must be a Pydantic Field"
            body_fields.append(route.body_field)
        if route.response_field:
            response_fields.append(route.response_field)
        if route.response_fields:
            response_fields.extend(route.response_fields.values())
        if route.callbacks:
            callback_fields.extend(get_fields_from_routes(route.callbacks))
        request_fields.extend(get_flat_params(route.dependant))
    return [*callback_fields, *body_fields, *response_fields, *request_fields]
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.1.0",
    summary: Optional[str] = None,
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    webhooks: Optional[Sequence[BaseRoute]] = None,
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Generate the complete OpenAPI document for *routes* and *webhooks*.

    Optional ``info`` fields, ``servers`` and ``tags`` are only emitted when
    truthy. Component schemas are emitted sorted by name. ``webhooks`` is only
    present when at least one webhook produced a path item.

    Returns:
        The document as a JSON-compatible dict. It is validated through the
        ``OpenAPI`` model and encoded by alias with ``None`` values dropped.
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    for info_key, info_value in (
        ("summary", summary),
        ("description", description),
        ("termsOfService", terms_of_service),
        ("contact", contact),
        ("license", license_info),
    ):
        if info_value:
            info[info_key] = info_value
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers

    # Schemas for every model used by routes *and* webhooks are generated up
    # front so both share one name map and one set of definitions.
    all_fields = get_fields_from_routes([*(routes or []), *(webhooks or [])])
    model_name_map = get_compat_model_name_map(all_fields)
    schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )

    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    webhook_paths: Dict[str, Dict[str, Any]] = {}
    operation_ids: Set[str] = set()
    # Regular routes are processed before webhooks. Both share operation_ids
    # (for duplicate warnings), securitySchemes and definitions; only the
    # destination path map differs.
    route_groups: Sequence[
        Tuple[Optional[Sequence[BaseRoute]], Dict[str, Dict[str, Any]]]
    ] = ((routes, paths), (webhooks, webhook_paths))
    for group, group_paths in route_groups:
        for api_route in group or []:
            if not isinstance(api_route, routing.APIRoute):
                continue
            path_item, security_schemes, path_definitions = get_openapi_path(
                route=api_route,
                operation_ids=operation_ids,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if path_item:
                group_paths.setdefault(api_route.path_format, {}).update(path_item)
            if security_schemes:
                components.setdefault("securitySchemes", {}).update(
                    security_schemes
                )
            if path_definitions:
                definitions.update(path_definitions)

    if definitions:
        components["schemas"] = {key: definitions[key] for key in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    if webhook_paths:
        output["webhooks"] = webhook_paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore