Coverage for fastagency/api/openapi/patch_fastapi_code_generator.py: 98%
34 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
1import json 1ahfegibcd
2import re 1ahfegibcd
3from functools import cached_property, wraps 1ahfegibcd
4from typing import Any 1ahfegibcd
6import stringcase 1ahfegibcd
7from fastapi_code_generator import __main__ as fastapi_code_generator_main 1ahfegibcd
8from fastapi_code_generator.parser import OpenAPIParser, Operation 1ahfegibcd
10from ...logging import get_logger 1ahfegibcd
# Module-level logger; each patch_* function in this module reports the patch it applied through it.
logger = get_logger(__name__)
def patch_parse_schema() -> None:
    """Monkey-patch ``OpenAPIParser.parse_schema`` to adopt ``duplicate_name``.

    The installed replacement forwards every call to the original
    ``parse_schema``. If the data type it returns has a truthy ``reference``
    whose ``duplicate_name`` is also truthy, ``reference.name`` is overwritten
    with ``duplicate_name`` before the data type is returned.
    """
    original = OpenAPIParser.parse_schema

    @wraps(original)
    def patched(*args: Any, **kwargs: Any) -> Any:
        result = original(*args, **kwargs)
        ref = result.reference
        if ref and ref.duplicate_name:
            ref.name = ref.duplicate_name
        return result

    OpenAPIParser.parse_schema = patched
    logger.info("Patched OpenAPIParser.parse_schema")
def patch_function_name_parsing() -> None:
    """Install a custom cached ``Operation.function_name`` property.

    The computed name is:

    * with an ``operationId``: the ``operationId`` with ``/`` replaced by ``_``;
    * without one: the operation ``type`` followed by ``snake_case_path`` in
      which every ``/{`` or ``/`` becomes ``_`` and every ``}`` is dropped.

    Either way the result goes through ``stringcase.snakecase``.
    """

    def function_name(self: Operation) -> str:
        if not self.operationId:
            # e.g. "/items/{item_id}" -> "_items_item_id"
            suffix = re.sub(r"/{|/", "_", self.snake_case_path).replace("}", "")
            return stringcase.snakecase(f"{self.type}{suffix}")  # type: ignore[no-any-return]
        return stringcase.snakecase(self.operationId.replace("/", "_"))  # type: ignore[no-any-return]

    Operation.function_name = cached_property(function_name)
    # The descriptor is attached after class creation, so Python never calls
    # __set_name__ automatically; cached_property needs it to know its attr name.
    Operation.function_name.__set_name__(Operation, "function_name")

    logger.info("Patched Operation.function_name")
def _replace_dots_in_schema_names(input_text: str) -> str:
    """Return *input_text* with ``.`` replaced by ``_`` in dotted schema names.

    Only names declared under ``components.schemas`` are renamed, and only
    where they appear as a bounded token: the name must be preceded by ``"``
    or ``/`` and followed by ``"``, ``/`` or ``\\``. This rewrites:

    * the schema keys themselves;
    * every ``#/components/schemas/<name>`` reference, including deep
      references (``.../<name>/properties/x``) and JSON-escaped ``\\/`` slashes.

    Unrelated text that merely *contains* a schema name as a substring is left
    untouched: server URLs, descriptions, enum values and examples. A plain
    ``str.replace`` would corrupt that text.

    Args:
        input_text: The OpenAPI specification as JSON text.

    Returns:
        The rewritten text, or *input_text* itself when no schema name
        contains a dot.

    Raises:
        json.JSONDecodeError: If *input_text* is not valid JSON (e.g. YAML).
        Exception: Other errors (e.g. ``AttributeError``) if the document does
            not have the expected mapping structure.
    """
    json_spec = json.loads(input_text)
    schemas = json_spec.get("components", {}).get("schemas", {})

    # Longest names first: regex alternation tries alternatives in order, so a
    # name is preferred over any shorter dotted schema name that prefixes it.
    dotted_names = sorted(
        (name for name in schemas if "." in name), key=len, reverse=True
    )
    if not dotted_names:
        return input_text

    pattern = re.compile(
        r'(?<=["/])(?:' + "|".join(map(re.escape, dotted_names)) + r')(?=["/\\])'
    )
    # Single pass over the text instead of one full str.replace per schema.
    return pattern.sub(lambda match: match.group(0).replace(".", "_"), input_text)


def patch_generate_code() -> None:
    """Wrap ``fastapi_code_generator.__main__.generate_code`` to de-dot schemas.

    The wrapper rewrites the ``input_text`` keyword argument with
    ``_replace_dots_in_schema_names`` before delegating to the original
    function. The rewrite is best effort: if it raises for any reason, the
    error is logged and the original arguments are passed through unchanged.
    Typical reasons are non-JSON input such as YAML, ``input_text`` not being
    passed by keyword, or an unexpected document shape.
    """
    # Save reference to the original generate_code function
    org_generate_code = fastapi_code_generator_main.generate_code

    @wraps(org_generate_code)
    def patched_generate_code(*args: Any, **kwargs: Any) -> Any:
        try:
            kwargs["input_text"] = _replace_dots_in_schema_names(
                kwargs["input_text"]
            )
        except Exception as e:
            # Deliberate best-effort fallback: never block code generation.
            logger.info(
                f"Patched fastapi_code_generator.__main__.generate_code raised: {e}, passing untouched arguments to original generate_code"
            )

        return org_generate_code(*args, **kwargs)

    fastapi_code_generator_main.generate_code = patched_generate_code

    logger.info("Patched fastapi_code_generator.__main__.generate_code")