Coverage for fastagency/api/openapi/patch_datamodel_code_generator.py: 89%

48 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-19 12:16 +0000

1from typing import Union 1ahefgibcd

2 

3from datamodel_code_generator.imports import ( 1ahefgibcd

4 IMPORT_LITERAL, 

5 IMPORT_LITERAL_BACKPORT, 

6 Imports, 

7) 

8from datamodel_code_generator.model import pydantic as pydantic_model 1ahefgibcd

9from datamodel_code_generator.model import pydantic_v2 as pydantic_model_v2 1ahefgibcd

10from datamodel_code_generator.model.base import ( 1ahefgibcd

11 DataModel, 

12) 

13from datamodel_code_generator.reference import Reference 1ahefgibcd

14 

15# from datamodel_code_generator.parser import base 

16from fastapi_code_generator.parser import OpenAPIParser 1ahefgibcd

17 

18from ...logging import get_logger 1ahefgibcd

19 

logger = get_logger(__name__)

# Snapshot of the unpatched ``Parser.__apply_discriminator_type`` (addressed by
# its mangled name), taken when this module is imported -- i.e. before
# ``patch_apply_discriminator_type`` below can replace it.
# NOTE(review): nothing in this module reads this name; presumably it is kept so
# callers/tests can restore or delegate to the original -- confirm before removing.
original_apply_discriminator_type = OpenAPIParser._Parser__apply_discriminator_type

24 

25 

def patch_apply_discriminator_type() -> None:  # noqa: C901
    """Install a replacement discriminator pass on ``OpenAPIParser``.

    ``Parser.__apply_discriminator_type`` is name-mangled, so from outside the
    class it is addressed as ``_Parser__apply_discriminator_type``. This
    function assigns a new implementation to that attribute on
    ``OpenAPIParser`` and logs that the patch was applied.

    For every model field whose ``extras`` hold a ``discriminator`` dict with a
    ``propertyName`` (and an optional ``mapping``), the replacement turns that
    property on each referenced pydantic model into a required literal of the
    matching type names. Mapping values are matched against model paths by
    ``_mapping_entry_matches``, which also compares against the file part of a
    path with leading "." / "/" characters stripped -- apparently to support
    relative (``../``) cross-file references; TODO confirm.
    """

    def _mapping_entry_matches(target_path: str, mapped_path: str) -> bool:
        """Return whether a discriminator ``mapping`` value refers to ``target_path``."""
        # Identical text after the last "#/" (whole string when there is none).
        if target_path.split("#/")[-1] == mapped_path.split("#/")[-1]:
            return True
        # A mapping value without the local "#/" prefix whose last "/" segment
        # equals the target path minus its final character.
        if (
            not mapped_path.startswith("#/")
            and target_path[:-1] == mapped_path.split("/")[-1]
        ):
            return True
        # Fallback: the mapping value after its first "/" (whole value if none)
        # versus the file portion of the target path -- text before "#", or all
        # but the last character when there is no "#" -- with leading "." and
        # "/" characters stripped, compared both as-is and minus its first
        # path segment.
        mapped_tail = mapped_path[str(mapped_path).find("/") + 1 :]
        file_part = target_path[: str(target_path).find("#")].lstrip(  # noqa: B005
            "../"
        )
        return mapped_tail in (file_part, "/".join(file_part.split("/")[1:]))

    def _matching_type_names(
        target: Union[
            pydantic_model.BaseModel,
            pydantic_model_v2.BaseModel,
            Reference,
        ],
        mapping: dict[str, str],
    ) -> list[str]:
        """Return, in mapping order, the mapping keys whose value refers to ``target``."""
        return [
            type_name
            for type_name, mapped_path in mapping.items()
            if _mapping_entry_matches(target.path, mapped_path)
        ]

    def _collect_type_names(
        discriminator_model: Union[
            pydantic_model.BaseModel,
            pydantic_model_v2.BaseModel,
        ],
        mapping: dict[str, str],
    ) -> list[str]:
        """Return the discriminator values that select ``discriminator_model``.

        Without a mapping, the last "/" segment of the model's path is used.
        With one, the model itself and then every base class that has a
        reference with a path are matched against it; results are
        concatenated in that order (duplicates are not removed).
        """
        if not mapping:
            return [discriminator_model.path.split("/")[-1]]
        type_names = _matching_type_names(discriminator_model, mapping)
        for base_class in discriminator_model.base_classes:
            base_reference = base_class.reference
            if base_reference and base_reference.path:
                type_names.extend(_matching_type_names(base_reference, mapping))
        return type_names

    def _apply_literal_discriminator(
        parser: OpenAPIParser,
        discriminator_model: Union[
            pydantic_model.BaseModel,
            pydantic_model_v2.BaseModel,
        ],
        property_name: str,
        type_names: list[str],
        imports: Imports,
    ) -> None:
        """Make ``property_name`` on ``discriminator_model`` a required literal field.

        A field (matched by ``original_name``, falling back to ``name``) whose
        literals are exactly one value equal to ``type_names[0]`` is left
        untouched; any other matching field gets a new literal data type built
        from ``type_names``. If no field matches, a new required literal field
        is appended. ``type_names`` must be non-empty.
        """
        property_found = False
        for discriminator_field in discriminator_model.fields:
            field_name = discriminator_field.original_name or discriminator_field.name
            if field_name != property_name:
                continue
            property_found = True
            literals = discriminator_field.data_type.literals
            if len(literals) == 1 and literals[0] == type_names[0]:
                continue
            # Drop references held by the type that is about to be replaced.
            for nested_type in discriminator_field.data_type.all_data_types:
                if nested_type.reference:  # pragma: no cover
                    nested_type.remove_reference()
            discriminator_field.data_type = parser.data_type(literals=type_names)
            # Re-read the attribute rather than keeping a local: the assignment
            # above may go through model machinery on ``discriminator_field``.
            discriminator_field.data_type.parent = discriminator_field
            discriminator_field.required = True
            imports.append(discriminator_field.imports)

        if not property_found:
            discriminator_model.fields.append(
                parser.data_model_field_type(
                    name=property_name,
                    data_type=parser.data_type(literals=type_names),
                    required=True,
                )
            )

        literal_import = (
            IMPORT_LITERAL
            if parser.target_python_version.has_literal_type
            else IMPORT_LITERAL_BACKPORT
        )
        # NOTE(review): the import is appended only when an equal entry is
        # already present, which reads as inverted; the ``comparison-overlap``
        # ignore suggests this comparison may never be true. Kept unchanged --
        # confirm how ``Imports`` iterates before altering it.
        already_imported = any(
            literal_import == existing  # type: ignore [comparison-overlap]
            for existing in imports
        )
        if already_imported:  # pragma: no cover
            imports.append(literal_import)

    def __apply_discriminator_type_patched(  # noqa: C901
        self: OpenAPIParser,
        models: list[DataModel],
        imports: Imports,
    ) -> None:
        """Replacement for ``Parser.__apply_discriminator_type`` (see outer docstring)."""
        for model in models:
            for field in model.fields:
                discriminator = field.extras.get("discriminator")
                if not (discriminator and isinstance(discriminator, dict)):
                    continue
                property_name = discriminator.get("propertyName")
                if not property_name:  # pragma: no cover
                    continue
                mapping = discriminator.get("mapping", {})

                for data_type in field.data_type.data_types:
                    if not data_type.reference:  # pragma: no cover
                        continue
                    discriminator_model = data_type.reference.source
                    if not isinstance(  # pragma: no cover
                        discriminator_model,
                        (pydantic_model.BaseModel, pydantic_model_v2.BaseModel),
                    ):
                        continue  # pragma: no cover

                    type_names = _collect_type_names(discriminator_model, mapping)
                    if not type_names:  # pragma: no cover
                        raise RuntimeError(
                            f"Discriminator type is not found. {data_type.reference.path}"
                        )
                    _apply_literal_discriminator(
                        self,
                        discriminator_model,
                        property_name,
                        type_names,
                        imports,
                    )

    # Assign through the mangled name: inside ``Parser`` the call
    # ``self.__apply_discriminator_type`` compiles to
    # ``self._Parser__apply_discriminator_type``, so (presumably via
    # OpenAPIParser subclassing Parser) this attribute is what gets called.
    OpenAPIParser._Parser__apply_discriminator_type = __apply_discriminator_type_patched

    # NOTE(review): the log text ends with a stray comma; left unchanged.
    logger.info("Patched Parser.__apply_discriminator_type,")