Coverage for faststream / _internal / fastapi / get_dependant.py: 92%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import inspect 

2from collections.abc import Callable, Iterable 

3from typing import TYPE_CHECKING, Annotated, Any, cast, get_args, get_origin 

4 

5from fast_depends.library.serializer import OptionItem 

6from fast_depends.utils import get_typed_annotation 

7from fastapi import params 

8from fastapi.dependencies.utils import ( 

9 get_dependant, 

10 get_parameterless_sub_dependant, 

11 get_typed_signature, 

12) 

13 

14from faststream._internal._compat import PYDANTIC_V2 

15 

16if TYPE_CHECKING: 

17 from fastapi.dependencies import models 

18 

19 

def get_fastapi_dependant(
    orig_call: Callable[..., Any],
    dependencies: Iterable["params.Depends"],
) -> Any:
    """Build a FastStream-compatible FastAPI ``Dependant`` for ``orig_call``.

    The native FastAPI dependant (with ``dependencies`` prepended) is created
    first and then patched in place with the extra attributes used for
    AsyncAPI schema generation.

    Args:
        orig_call: The handler callable to analyse.
        dependencies: Extra FastAPI ``Depends`` markers to attach.

    Returns:
        The patched dependant object.
    """
    native = get_fastapi_native_dependant(orig_call=orig_call, dependencies=dependencies)
    return _patch_fastapi_dependent(native)

31 

32 

def get_fastapi_native_dependant(
    orig_call: Callable[..., Any],
    dependencies: Iterable["params.Depends"],
) -> Any:
    """Generate a native FastAPI ``Dependant`` for ``orig_call``.

    Args:
        orig_call: The handler whose signature FastAPI analyses.
        dependencies: Extra parameterless dependencies. They are placed ahead
            of the dependencies discovered from ``orig_call``'s signature,
            in the order in which they were given.

    Returns:
        The ``Dependant`` produced by ``fastapi.dependencies.utils.get_dependant``,
        with the extra dependencies prepended to its ``dependencies`` list.
    """
    dependent = get_dependant(path="", call=orig_call)

    # Prepend all extra dependencies with one slice assignment. This replaces
    # repeated insert(0, ...) over a reversed copy, which was quadratic and
    # hid the intent. The resulting order is identical: given order first,
    # then the signature-derived ones.
    dependent.dependencies[:0] = [
        get_parameterless_sub_dependant(depends=depends, path="")
        for depends in dependencies
    ]

    return dependent

50 

51 

def _patch_fastapi_dependent(dependant: "models.Dependant") -> "models.Dependant":
    """Attach FastStream-specific attributes to a FastAPI ``Dependant``.

    The query and body parameters of ``dependant`` and of each of its direct
    sub-dependencies are gathered. For every distinct parameter name (the
    first occurrence wins), a pydantic field is built and its annotation is
    resolved against the handler's module globals. The results are stored on
    the object as:

    * ``model`` -- a field-less pydantic model named after the handler;
    * ``custom_fields`` -- an empty dict;
    * ``flat_params`` -- one ``OptionItem`` per distinct parameter.

    Args:
        dependant: The native FastAPI dependant to patch (mutated in place).

    Returns:
        The same ``dependant`` object.
    """
    from pydantic import Field, create_model  # FastAPI always has pydantic

    from faststream._internal._compat import PydanticUndefined

    # ``+`` builds a fresh list, so the dependant's own param lists stay untouched.
    # NOTE(review): only direct sub-dependencies are visited; params of deeper
    # nested dependencies are not collected -- confirm this is intended.
    collected = dependant.query_params + dependant.body_params
    for sub_dependant in dependant.dependencies:
        collected.extend(sub_dependant.query_params + sub_dependant.body_params)

    # Unwrap FastStream-decorated handlers via ``__wrapped__`` (when present) so
    # the original function's globals and name are used below.
    target = dependant.call
    if is_faststream_decorated(target):
        target = getattr(target, "__wrapped__", target)
    module_globals = getattr(target, "__globals__", {})

    unique_fields = {}
    for param in collected:
        if param.name in unique_fields:
            continue  # keep the first definition of each parameter name

        info: Any = param.field_info if PYDANTIC_V2 else param

        field_kwargs: dict[str, Any] = {
            # Ellipsis marks the field as required for pydantic.
            "default": ... if info.default is PydanticUndefined else info.default,
            "default_factory": info.default_factory,
            "alias": info.alias,
        }

        if PYDANTIC_V2:
            from pydantic.fields import FieldInfo

            info = cast("FieldInfo", info)
            field_kwargs |= {
                "title": info.title,
                "alias_priority": info.alias_priority,
                "validation_alias": info.validation_alias,
                "serialization_alias": info.serialization_alias,
                "description": info.description,
                "discriminator": info.discriminator,
                "examples": info.examples,
                "exclude": info.exclude,
                "json_schema_extra": info.json_schema_extra,
            }

            # The fallback is built unconditionally, then superseded by a
            # FieldInfo found in the parameter's metadata, if any.
            fallback = Field(**field_kwargs)  # type: ignore[pydantic-field,unused-ignore]
            metadata = param.field_info.metadata or ()
            field = next(
                (item for item in metadata if isinstance(item, FieldInfo)),
                fallback,
            )
        else:
            from pydantic.fields import ModelField  # type: ignore[attr-defined]

            info = cast("ModelField", info)
            legacy_info = info.field_info
            field_kwargs |= {
                "title": legacy_info.title,
                "description": legacy_info.description,
                "discriminator": legacy_info.discriminator,
                "exclude": legacy_info.exclude,
                "gt": legacy_info.gt,
                "ge": legacy_info.ge,
                "lt": legacy_info.lt,
                "le": legacy_info.le,
            }
            field = Field(**field_kwargs)  # type: ignore[pydantic-field,unused-ignore]

        resolved_type = get_typed_annotation(info.annotation, module_globals, {})
        unique_fields[param.name] = (resolved_type, field)

    # NOTE(review): the collected fields are not passed to create_model, so the
    # model carries only the handler's name -- confirm this is intended.
    model_name = getattr(target, "__name__", type(target).__name__)
    dependant.model = create_model(model_name)  # type: ignore[attr-defined]

    dependant.custom_fields = {}  # type: ignore[attr-defined]
    dependant.flat_params = [  # type: ignore[attr-defined]
        OptionItem(field_name=name, field_type=annotation, default_value=value)
        for name, (annotation, value) in unique_fields.items()
    ]

    return dependant

142 

143 

def has_forbidden_types(
    orig_call: Callable[..., Any],
    forbidden_types: tuple[Any, ...],
) -> set[Any]:
    """Collect forbidden marker types used in ``orig_call``'s parameters.

    For each parameter, the candidate markers are its ``Annotated`` metadata
    (if the annotation is ``Annotated[...]``) followed by its default value.
    If any marker is a FastAPI ``params.Depends``, the parameter is ignored.
    Otherwise every type from ``forbidden_types`` that one of its markers is
    an instance of is recorded.

    Args:
        orig_call: The handler whose signature is inspected.
        forbidden_types: Types whose instances must not be used as markers.

    Returns:
        The subset of ``forbidden_types`` found (empty if none).
    """
    found: set[Any] = set()

    for parameter in get_typed_signature(orig_call).parameters.values():
        annotation = parameter.annotation

        markers = (
            list(get_args(annotation)[1:])
            if annotation is not inspect.Signature.empty
            and get_origin(annotation) is Annotated
            else []
        )
        markers.append(parameter.default)

        matched: set[Any] = set()
        uses_fastapi_depends = False
        for marker in markers:
            if isinstance(marker, params.Depends):
                uses_fastapi_depends = True
                continue
            matched.update(kind for kind in forbidden_types if isinstance(marker, kind))

        # A FastAPI Depends on the same parameter suppresses its findings.
        if not uses_fastapi_depends:
            found |= matched

    return found

183 

184 

# Name of the attribute used to flag callables as FastStream-decorated
# (presumably set by the FastAPI plugin's subscriber decorator -- verify at callers).
FASTSTREAM_FASTAPI_PLUGIN_DECORATOR_MARKER = "__faststream_consumer__"


def is_faststream_decorated(func: object) -> bool:
    """Return whether ``func`` carries the FastStream decorator marker.

    The marker value is coerced with ``bool()`` so the result always matches
    the declared ``bool`` return type, even if the attribute holds another
    truthy or falsy object. Objects without the attribute yield ``False``.
    """
    return bool(getattr(func, FASTSTREAM_FASTAPI_PLUGIN_DECORATOR_MARKER, False))


def mark_faststream_decorated(func: object) -> None:
    """Set the FastStream decorator marker attribute on ``func`` to ``True``.

    Uses ``setattr``, so it may raise ``AttributeError`` for objects that do
    not accept new attributes.
    """
    setattr(func, FASTSTREAM_FASTAPI_PLUGIN_DECORATOR_MARKER, True)