Coverage for faststream / _internal / fastapi / get_dependant.py: 92%
29 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import inspect
2from collections.abc import Callable, Iterable
3from typing import TYPE_CHECKING, Annotated, Any, cast, get_args, get_origin
5from fast_depends.library.serializer import OptionItem
6from fast_depends.utils import get_typed_annotation
7from fastapi import params
8from fastapi.dependencies.utils import (
9 get_dependant,
10 get_parameterless_sub_dependant,
11 get_typed_signature,
12)
14from faststream._internal._compat import PYDANTIC_V2
16if TYPE_CHECKING:
17 from fastapi.dependencies import models
def get_fastapi_dependant(
    orig_call: Callable[..., Any],
    dependencies: Iterable["params.Depends"],
) -> Any:
    """Build a FastAPI ``Dependant`` for ``orig_call`` that FastStream can consume.

    The native FastAPI dependant (with ``dependencies`` placed in front of the
    call's own dependencies) is produced first, then patched in place with the
    extra attributes FastStream relies on.

    Args:
        orig_call: the handler callable to analyze.
        dependencies: extra FastAPI ``Depends`` to run before the handler's own.

    Returns:
        The patched dependant object.
    """
    return _patch_fastapi_dependent(
        get_fastapi_native_dependant(orig_call=orig_call, dependencies=dependencies),
    )
def get_fastapi_native_dependant(
    orig_call: Callable[..., Any],
    dependencies: Iterable["params.Depends"],
) -> Any:
    """Create a plain FastAPI ``Dependant`` for ``orig_call``.

    Each entry of ``dependencies`` is converted into a parameterless
    sub-dependant and placed, in the order given, before the dependencies
    FastAPI discovered from ``orig_call``'s own signature.

    Args:
        orig_call: the handler callable to analyze.
        dependencies: extra FastAPI ``Depends`` to prepend.

    Returns:
        The unpatched FastAPI dependant.
    """
    dependant = get_dependant(path="", call=orig_call)

    # Convert last-to-first (the same call order the original insert loop used),
    # then flip back so the spliced-in sub-dependants keep the declared order.
    prefix = [
        get_parameterless_sub_dependant(depends=item, path="")
        for item in reversed(list(dependencies))
    ]
    prefix.reverse()
    dependant.dependencies[:0] = prefix

    return dependant
def _patch_fastapi_dependent(dependant: "models.Dependant") -> "models.Dependant":
    """Patch FastAPI by adding fields for AsyncAPI schema generation.

    Query and body params of ``dependant`` and of its *direct* sub-dependencies
    (deeper levels are not visited) are collected; for each parameter name only
    the first occurrence is kept. Each kept parameter is turned into a pydantic
    ``Field`` that carries over the original field metadata (with pydantic v2, a
    ``FieldInfo`` already present in the field's metadata is reused instead).
    Its annotation is passed through fast_depends' ``get_typed_annotation``
    using the ``__globals__`` of the (unwrapped) call.

    The dependant is mutated in place and receives:

    * ``model`` -- a pydantic model named after the call;
    * ``custom_fields`` -- an empty dict;
    * ``flat_params`` -- one ``OptionItem`` per unique parameter.

    Args:
        dependant: the FastAPI dependant to patch.

    Returns:
        The same ``dependant`` object.
    """
    from pydantic import Field, create_model  # FastAPI always has pydantic

    from faststream._internal._compat import PydanticUndefined

    # Import the version-specific pydantic class once, not once per parameter.
    if PYDANTIC_V2:
        from pydantic.fields import FieldInfo
    else:
        from pydantic.fields import ModelField  # type: ignore[attr-defined]

    # ``+`` builds a new list, so the dependant's own param lists stay untouched.
    # Deliberately not named ``params``: that would shadow the module-level
    # ``fastapi.params`` import.
    collected = dependant.query_params + dependant.body_params
    for sub_dependant in dependant.dependencies:
        collected.extend(sub_dependant.query_params + sub_dependant.body_params)

    call = dependant.call
    if is_faststream_decorated(call):
        # Look through the FastStream wrapper to the user's original function.
        call = getattr(call, "__wrapped__", call)
    globalns = getattr(call, "__globals__", {})

    params_unique: dict[str, tuple[Any, Any]] = {}
    for p in collected:
        if p.name in params_unique:
            # The first occurrence of a name wins; later duplicates are ignored.
            continue

        info: Any = p.field_info if PYDANTIC_V2 else p

        field_data: dict[str, Any] = {
            # ``...`` tells pydantic's ``Field`` that the value is required.
            "default": ... if info.default is PydanticUndefined else info.default,
            "default_factory": info.default_factory,
            "alias": info.alias,
        }

        if PYDANTIC_V2:
            info = cast("FieldInfo", info)

            field_data.update(
                {
                    "title": info.title,
                    "alias_priority": info.alias_priority,
                    "validation_alias": info.validation_alias,
                    "serialization_alias": info.serialization_alias,
                    "description": info.description,
                    "discriminator": info.discriminator,
                    "examples": info.examples,
                    "exclude": info.exclude,
                    "json_schema_extra": info.json_schema_extra,
                },
            )

            # Prefer a ``FieldInfo`` attached through the field's metadata;
            # otherwise fall back to one rebuilt from the collected data.
            f = next(
                filter(
                    lambda x: isinstance(x, FieldInfo),
                    p.field_info.metadata or (),
                ),
                Field(**field_data),  # type: ignore[pydantic-field,unused-ignore]
            )

        else:
            info = cast("ModelField", info)

            field_data.update(
                {
                    "title": info.field_info.title,
                    "description": info.field_info.description,
                    "discriminator": info.field_info.discriminator,
                    "exclude": info.field_info.exclude,
                    "gt": info.field_info.gt,
                    "ge": info.field_info.ge,
                    "lt": info.field_info.lt,
                    "le": info.field_info.le,
                },
            )
            f = Field(**field_data)  # type: ignore[pydantic-field,unused-ignore]

        params_unique[p.name] = (
            get_typed_annotation(info.annotation, globalns, {}),
            f,
        )

    # NOTE(review): the model is created without any fields; the collected
    # parameters are only exposed through ``flat_params`` -- confirm intended.
    dependant.model = create_model(  # type: ignore[attr-defined]
        getattr(call, "__name__", type(call).__name__),
    )

    dependant.custom_fields = {}  # type: ignore[attr-defined]
    dependant.flat_params = [  # type: ignore[attr-defined]
        OptionItem(field_name=name, field_type=type_, default_value=default)
        for name, (type_, default) in params_unique.items()
    ]

    return dependant
def has_forbidden_types(
    orig_call: Callable[..., Any],
    forbidden_types: tuple[Any, ...],
) -> set[Any]:
    """Check if faststream.Depends is used in the handler.

    Every parameter of ``orig_call`` (as reported by FastAPI's
    ``get_typed_signature``) is inspected. A parameter's *markers* are the
    extra ``Annotated[...]`` arguments (when its annotation is ``Annotated``)
    plus its default value.

    A parameter that carries a FastAPI ``params.Depends`` among its markers is
    ignored entirely. For every other parameter, each type in
    ``forbidden_types`` that one of its markers is an instance of is added to
    the result.

    Args:
        orig_call: the handler callable to inspect.
        forbidden_types: marker types that must not be used.

    Returns:
        The subset of ``forbidden_types`` that was found (empty if none).
    """
    found: set[Any] = set()

    for parameter in get_typed_signature(orig_call).parameters.values():
        annotation = parameter.annotation

        markers: list[Any] = []
        if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated:
            # The first Annotated argument is the real type, not metadata.
            markers.extend(get_args(annotation)[1:])
        markers.append(parameter.default)

        # A FastAPI dependency on the parameter wins: it never counts as forbidden.
        if any(isinstance(marker, params.Depends) for marker in markers):
            continue

        found.update(
            forbidden
            for marker in markers
            for forbidden in forbidden_types
            if isinstance(marker, forbidden)
        )

    return found
# Name of the attribute used to flag a callable as FastStream-decorated:
# ``mark_faststream_decorated`` sets it, ``is_faststream_decorated`` reads it.
FASTSTREAM_FASTAPI_PLUGIN_DECORATOR_MARKER = "__faststream_consumer__"


def is_faststream_decorated(func: object) -> bool:
    """Return whether ``func`` carries the FastStream decorator marker.

    The attribute value is coerced with ``bool()`` so the result always matches
    the annotated return type, even if the attribute holds a non-bool value.
    Objects without the attribute are reported as not decorated.
    """
    return bool(getattr(func, FASTSTREAM_FASTAPI_PLUGIN_DECORATOR_MARKER, False))


def mark_faststream_decorated(func: object) -> None:
    """Set the FastStream decorator marker on ``func`` to ``True``.

    Any error raised by ``setattr`` (e.g. for objects that do not accept new
    attributes) propagates to the caller.
    """
    setattr(func, FASTSTREAM_FASTAPI_PLUGIN_DECORATOR_MARKER, True)