Coverage for fastagency/api/openapi/patch_datamodel_code_generator.py: 89%
48 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
1from typing import Union 1ahefgibcd
3from datamodel_code_generator.imports import ( 1ahefgibcd
4 IMPORT_LITERAL,
5 IMPORT_LITERAL_BACKPORT,
6 Imports,
7)
8from datamodel_code_generator.model import pydantic as pydantic_model 1ahefgibcd
9from datamodel_code_generator.model import pydantic_v2 as pydantic_model_v2 1ahefgibcd
10from datamodel_code_generator.model.base import ( 1ahefgibcd
11 DataModel,
12)
13from datamodel_code_generator.reference import Reference 1ahefgibcd
15# from datamodel_code_generator.parser import base
16from fastapi_code_generator.parser import OpenAPIParser 1ahefgibcd
18from ...logging import get_logger 1ahefgibcd
# Logger named after this module.
logger = get_logger(__name__)

# Handle on the unpatched ``Parser.__apply_discriminator_type`` (reached through
# its name-mangled attribute), captured at import time before
# ``patch_apply_discriminator_type`` rebinds that attribute.
original_apply_discriminator_type = (
    OpenAPIParser._Parser__apply_discriminator_type
)
def _mapping_names_for(
    model: "Union[pydantic_model.BaseModel, pydantic_model_v2.BaseModel, Reference]",
    mapping: dict[str, str],
) -> list[str]:
    """Return the keys of ``mapping`` whose target path refers to ``model``.

    A ``name -> path`` entry of a discriminator ``mapping`` is accepted when
    any of the following holds, comparing against ``model.path``:

    * the text after the last ``"#/"`` (the whole string if absent) is the
      same in both paths;
    * ``path`` does not start with ``"#/"`` and its last ``"/"`` segment
      equals ``model.path`` without its final character;
    * ``path`` with everything up to and including its first ``"/"`` removed
      equals the file part of ``model.path`` (the text before ``"#"`` with
      leading ``.``/``/`` characters stripped), either whole or with that
      file part's first segment dropped.

    Args:
        model: Candidate model or reference; only its ``path`` is read.
        mapping: Discriminator ``mapping`` (tag value -> reference path).

    Returns:
        The matching keys, in ``mapping`` iteration order.
    """
    model_path = model.path
    # Pieces of ``model_path`` that do not depend on the mapping entry.
    model_fragment = model_path.split("#/")[-1]
    model_without_last_char = model_path[:-1]
    # ``lstrip`` takes a character set, not a prefix: every leading "." and
    # "/" is removed. If ``model_path`` has no "#", ``find`` returns -1 and
    # the slice drops the final character (kept for behavioral parity).
    model_file = model_path[: model_path.find("#")].lstrip("../")  # noqa: B005
    model_file_tail = "/".join(model_file.split("/")[1:])

    names: list[str] = []
    for name, path in mapping.items():
        same_fragment = model_fragment == path.split("#/")[-1]
        same_external_name = (
            not path.startswith("#/")
            and model_without_last_char == path.split("/")[-1]
        )
        if not (same_fragment or same_external_name):
            # Fall back to comparing the file part of both references; a path
            # without "/" is compared whole (``find`` returns -1).
            target = path[path.find("/") + 1 :]
            if target not in (model_file, model_file_tail):
                continue
        names.append(name)
    return names


def _discriminator_type_names(
    discriminator_model: "Union[pydantic_model.BaseModel, pydantic_model_v2.BaseModel]",
    mapping: dict[str, str],
) -> list[str]:
    """Return the literal tag values that identify ``discriminator_model``.

    With a non-empty ``mapping`` these are the mapping keys matching the
    model itself, followed by those matching each base class that has a
    reference with a path. Without a mapping, the single tag is the last
    ``"/"`` segment of the model's path.

    Args:
        discriminator_model: Model referenced by the discriminated field.
        mapping: Discriminator ``mapping`` (may be empty).

    Returns:
        A new list of tag values; may be empty if nothing in ``mapping``
        matched.
    """
    if not mapping:
        return [discriminator_model.path.split("/")[-1]]

    type_names = _mapping_names_for(discriminator_model, mapping)
    # Mapping entries may also point at one of the model's base classes.
    for base_class in discriminator_model.base_classes:
        if base_class.reference and base_class.reference.path:
            type_names.extend(_mapping_names_for(base_class.reference, mapping))
    return type_names


def _ensure_discriminator_literal(
    parser: "OpenAPIParser",
    discriminator_model: "Union[pydantic_model.BaseModel, pydantic_model_v2.BaseModel]",
    property_name: str,
    type_names: list[str],
    imports: "Imports",
) -> None:
    """Make ``property_name`` on ``discriminator_model`` a required literal field.

    Each field whose ``original_name`` (or ``name`` when that is unset) equals
    ``property_name`` is retyped to a literal of ``type_names`` and marked
    required. A field that already is a single-value literal equal to
    ``type_names[0]`` is left untouched. If no field matched, a new required
    field of that literal type is appended to the model.

    Args:
        parser: Parser whose ``data_type`` / ``data_model_field_type``
            factories build the new type and field.
        discriminator_model: Model referenced by the discriminated field;
            mutated in place.
        property_name: The discriminator's ``propertyName``.
        type_names: Literal tag values for this model.
        imports: Imports collection given to the patched method; appended to
            in place.
    """
    has_one_literal = False
    for discriminator_field in discriminator_model.fields:
        if (
            discriminator_field.original_name or discriminator_field.name
        ) != property_name:
            continue
        literals = discriminator_field.data_type.literals
        if type_names and len(literals) == 1 and literals[0] == type_names[0]:
            # Already the expected single-value literal.
            has_one_literal = True
            continue
        # Call ``remove_reference`` on any referenced data types of the old
        # type before it is replaced.
        for field_data_type in discriminator_field.data_type.all_data_types:
            if field_data_type.reference:  # pragma: no cover
                field_data_type.remove_reference()
        discriminator_field.data_type = parser.data_type(literals=type_names)
        discriminator_field.data_type.parent = discriminator_field
        discriminator_field.required = True
        imports.append(discriminator_field.imports)
        has_one_literal = True

    if not has_one_literal:
        discriminator_model.fields.append(
            parser.data_model_field_type(
                name=property_name,
                data_type=parser.data_type(literals=type_names),
                required=True,
            )
        )

    literal = (
        IMPORT_LITERAL
        if parser.target_python_version.has_literal_type
        else IMPORT_LITERAL_BACKPORT
    )
    # NOTE(review): ``literal`` is compared with the items produced by
    # iterating ``imports``; the type-ignore suggests they are not ``Import``
    # objects, in which case this is never true and the append below never
    # runs (it is also excluded from coverage). Kept as-is to preserve
    # behavior — confirm against ``datamodel_code_generator.imports.Imports``.
    has_imported_literal = any(
        literal == import_  # type: ignore [comparison-overlap]
        for import_ in imports
    )
    if has_imported_literal:  # pragma: no cover
        imports.append(literal)


def patch_apply_discriminator_type() -> None:  # noqa: C901
    """Install a replacement for ``Parser.__apply_discriminator_type``.

    The replacement is assigned to ``OpenAPIParser`` under the name-mangled
    attribute ``_Parser__apply_discriminator_type``, and an info message is
    logged. Each call defines a fresh replacement and rebinds the attribute.
    The unpatched method remains available as the module-level
    ``original_apply_discriminator_type``.
    """

    def __apply_discriminator_type_patched(  # noqa: C901
        self: OpenAPIParser,
        models: list[DataModel],
        imports: Imports,
    ) -> None:
        """Turn discriminator properties of referenced models into literals.

        For every field of ``models`` whose ``extras["discriminator"]`` is a
        dict with a ``propertyName``, each referenced ``pydantic`` /
        ``pydantic_v2`` ``BaseModel`` in ``field.data_type.data_types`` gets
        that property made a required literal of its tag values.

        Raises:
            RuntimeError: If no tag value can be derived for a referenced
                model.
        """
        for model in models:
            for field in model.fields:
                discriminator = field.extras.get("discriminator")
                if not discriminator or not isinstance(discriminator, dict):
                    continue
                property_name = discriminator.get("propertyName")
                if not property_name:  # pragma: no cover
                    continue
                mapping = discriminator.get("mapping", {})
                for data_type in field.data_type.data_types:
                    if not data_type.reference:  # pragma: no cover
                        continue
                    discriminator_model = data_type.reference.source
                    if not isinstance(  # pragma: no cover
                        discriminator_model,
                        (pydantic_model.BaseModel, pydantic_model_v2.BaseModel),
                    ):
                        continue  # pragma: no cover

                    type_names = _discriminator_type_names(
                        discriminator_model, mapping
                    )
                    if not type_names:  # pragma: no cover
                        raise RuntimeError(
                            f"Discriminator type is not found. {data_type.reference.path}"
                        )
                    _ensure_discriminator_literal(
                        self, discriminator_model, property_name, type_names, imports
                    )

    # Patch the method using the exact mangled name
    OpenAPIParser._Parser__apply_discriminator_type = __apply_discriminator_type_patched

    logger.info("Patched Parser.__apply_discriminator_type,")