Coverage for fastagency/api/openapi/patch_fastapi_code_generator.py: 98%

34 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-19 12:16 +0000

1import json 1ahfegibcd

2import re 1ahfegibcd

3from functools import cached_property, wraps 1ahfegibcd

4from typing import Any 1ahfegibcd

5 

6import stringcase 1ahfegibcd

7from fastapi_code_generator import __main__ as fastapi_code_generator_main 1ahfegibcd

8from fastapi_code_generator.parser import OpenAPIParser, Operation 1ahfegibcd

9 

10from ...logging import get_logger 1ahfegibcd

11 

# Module-level logger; the patch functions below use it to record each
# third-party attribute they replace.
logger = get_logger(__name__)

13 

14 

def patch_parse_schema() -> None:
    """Wrap ``OpenAPIParser.parse_schema`` so references adopt their duplicate name.

    The currently installed ``OpenAPIParser.parse_schema`` is wrapped. After it
    returns a data type, the wrapper checks two things: the data type has a
    truthy ``reference``, and that reference has a truthy ``duplicate_name``.
    When both hold, ``reference.name`` is overwritten with ``duplicate_name``.
    The (possibly modified) data type is then returned.

    The wrapper replaces the class attribute in place, and the patch is logged.
    """
    wrapped_parse_schema = OpenAPIParser.parse_schema

    @wraps(wrapped_parse_schema)
    def parse_schema_preferring_duplicate(*args: Any, **kwargs: Any) -> Any:
        result = wrapped_parse_schema(*args, **kwargs)
        ref = result.reference
        # Nothing to rename without a reference or without an alternative
        # name recorded on it.
        if not ref or not ref.duplicate_name:
            return result
        ref.name = ref.duplicate_name
        return result

    OpenAPIParser.parse_schema = parse_schema_preferring_duplicate
    logger.info("Patched OpenAPIParser.parse_schema")

27 

28 

def patch_function_name_parsing() -> None:
    """Replace ``Operation.function_name`` with a slash-tolerant implementation.

    With an ``operationId``, the name is that id with every ``/`` turned into
    ``_``.

    Without one, the name is ``self.type`` followed by ``snake_case_path``.
    In the path, ``/{`` and ``/`` become ``_`` and ``}`` is removed.

    In both cases the result is passed through ``stringcase.snakecase``. The
    new implementation is installed as a ``cached_property`` on ``Operation``.
    """

    def function_name(self: Operation) -> str:
        operation_id = self.operationId
        if not operation_id:
            # No operationId: derive the name from the type and the path.
            path_part = re.sub(r"/{|/", "_", self.snake_case_path).replace("}", "")
            return stringcase.snakecase(f"{self.type}{path_part}")  # type: ignore[no-any-return]
        # Slashes in an operationId are replaced before snake-casing.
        return stringcase.snakecase(operation_id.replace("/", "_"))  # type: ignore[no-any-return]

    # Assigning a descriptor to an already-created class does not trigger
    # ``__set_name__``, so it is called explicitly to tell the
    # cached_property which attribute name it is bound to.
    Operation.function_name = cached_property(function_name)
    Operation.function_name.__set_name__(Operation, "function_name")

    logger.info("Patched Operation.function_name")

42 

43 

44def patch_generate_code() -> None: 1ahfegibcd

45 # Save reference to the original generate_code function 

46 org_generate_code = fastapi_code_generator_main.generate_code 1ahfegibcd

47 

48 @wraps(org_generate_code) 1ahfegibcd

49 def patched_generate_code(*args: Any, **kwargs: Any) -> Any: 1ahfegibcd

50 try: 1afegbcd

51 input_text: str = kwargs["input_text"] 1afegbcd

52 

53 json_spec = json.loads(input_text) 1afegbcd

54 

55 schemas_with_dots = sorted( 1afegbcd

56 [ 

57 name 

58 for name in json_spec.get("components", {}).get("schemas", {}) 

59 if "." in name 

60 ], 

61 key=len, 

62 reverse=True, # Sort by length in descending order 

63 ) 

64 

65 for schema_name in schemas_with_dots: 1afegbcd

66 new_schema_name = schema_name.replace(".", "_") 1aebcd

67 input_text = input_text.replace(schema_name, new_schema_name) 1aebcd

68 

69 kwargs["input_text"] = input_text 1afegbcd

70 

71 except Exception as e: 

72 logger.info( 

73 f"Patched fastapi_code_generator.__main__.generate_code raised: {e}, passing untouched arguments to original generate_code" 

74 ) 

75 

76 return org_generate_code(*args, **kwargs) 1afegbcd

77 

78 fastapi_code_generator_main.generate_code = patched_generate_code 1ahfegibcd

79 

80 logger.info("Patched fastapi_code_generator.__main__.generate_code") 1ahfegibcd