Coverage for fastapi/dependencies/utils.py: 100%
408 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-08 03:53 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-08 03:53 +0000
1import inspect 1abcde
2from contextlib import AsyncExitStack, contextmanager 1abcde
3from copy import copy, deepcopy 1abcde
4from typing import ( 1abcde
5 Any,
6 Callable,
7 Coroutine,
8 Dict,
9 ForwardRef,
10 List,
11 Mapping,
12 Optional,
13 Sequence,
14 Tuple,
15 Type,
16 Union,
17 cast,
18)
20import anyio 1abcde
21from fastapi import params 1abcde
22from fastapi._compat import ( 1abcde
23 PYDANTIC_V2,
24 ErrorWrapper,
25 ModelField,
26 Required,
27 Undefined,
28 _regenerate_error_with_loc,
29 copy_field_info,
30 create_body_model,
31 evaluate_forwardref,
32 field_annotation_is_scalar,
33 get_annotation_from_field_info,
34 get_missing_field_error,
35 is_bytes_field,
36 is_bytes_sequence_field,
37 is_scalar_field,
38 is_scalar_sequence_field,
39 is_sequence_field,
40 is_uploadfile_or_nonable_uploadfile_annotation,
41 is_uploadfile_sequence_annotation,
42 lenient_issubclass,
43 sequence_types,
44 serialize_sequence_value,
45 value_is_sequence,
46)
47from fastapi.background import BackgroundTasks 1abcde
48from fastapi.concurrency import ( 1abcde
49 asynccontextmanager,
50 contextmanager_in_threadpool,
51)
52from fastapi.dependencies.models import Dependant, SecurityRequirement 1abcde
53from fastapi.logger import logger 1abcde
54from fastapi.security.base import SecurityBase 1abcde
55from fastapi.security.oauth2 import OAuth2, SecurityScopes 1abcde
56from fastapi.security.open_id_connect_url import OpenIdConnect 1abcde
57from fastapi.utils import create_response_field, get_path_param_names 1abcde
58from pydantic.fields import FieldInfo 1abcde
59from starlette.background import BackgroundTasks as StarletteBackgroundTasks 1abcde
60from starlette.concurrency import run_in_threadpool 1abcde
61from starlette.datastructures import FormData, Headers, QueryParams, UploadFile 1abcde
62from starlette.requests import HTTPConnection, Request 1abcde
63from starlette.responses import Response 1abcde
64from starlette.websockets import WebSocket 1abcde
65from typing_extensions import Annotated, get_args, get_origin 1abcde
# Message logged and raised by `check_file_field` when importing `multipart`
# fails entirely.
multipart_not_installed_error = (
    'Form data requires "python-multipart" to be installed. \n'
    'You can install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
# Message logged and raised by `check_file_field` when a `multipart` module can
# be imported but `multipart.multipart.parse_options_header` cannot.
multipart_incorrect_install_error = (
    'Form data requires "python-multipart" to be installed. '
    'It seems you installed "multipart" instead. \n'
    'You can remove "multipart" with: \n\n'
    "pip uninstall multipart\n\n"
    'And then install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
def check_file_field(field: ModelField) -> None:
    """Ensure the multipart parser is importable when *field* is form-based.

    Nothing happens unless ``field.field_info`` is a ``params.Form`` instance.
    Otherwise two imports are attempted; each failure is logged and re-raised
    as a ``RuntimeError`` carrying installation instructions.

    Raises:
        RuntimeError: if ``multipart`` cannot be imported, or if it can but
            ``multipart.multipart.parse_options_header`` cannot.
    """
    if not isinstance(field.field_info, params.Form):
        return

    try:
        # __version__ is available in both multiparts, and can be mocked
        from multipart import __version__  # type: ignore
    except ImportError:
        logger.error(multipart_not_installed_error)
        raise RuntimeError(multipart_not_installed_error) from None
    assert __version__

    try:
        # parse_options_header is only available in the right multipart
        from multipart.multipart import parse_options_header  # type: ignore
    except ImportError:
        logger.error(multipart_incorrect_install_error)
        raise RuntimeError(multipart_incorrect_install_error) from None
    assert parse_options_header
def get_param_sub_dependant(
    *,
    param_name: str,
    depends: params.Depends,
    path: str,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the sub-``Dependant`` for a dependency attached to a named parameter.

    Asserts that ``depends.dependency`` is truthy, then delegates to
    ``get_sub_dependant`` using the parameter name as the dependant name.
    """
    dependency = depends.dependency
    assert dependency
    return get_sub_dependant(
        depends=depends,
        dependency=dependency,
        path=path,
        name=param_name,
        security_scopes=security_scopes,
    )
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build the sub-``Dependant`` for a dependency not bound to any parameter.

    Asserts that ``depends.dependency`` is callable and delegates to
    ``get_sub_dependant`` without a name or security scopes.
    """
    dependency = depends.dependency
    assert callable(
        dependency
    ), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends, dependency=dependency, path=path)
def get_sub_dependant(
    *,
    depends: params.Depends,
    dependency: Callable[..., Any],
    path: str,
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Create the ``Dependant`` for one dependency, carrying security scopes down.

    Args:
        depends: The ``Depends``/``Security`` marker that declared the dependency.
        dependency: The callable (or security scheme instance) to analyse.
        path: Route path forwarded to ``get_dependant``.
        name: Parameter name the solved value is bound to, if any.
        security_scopes: Scopes inherited from the caller. This list is never
            modified; it is only read.

    Returns:
        The ``Dependant`` produced by ``get_dependant``. When ``dependency`` is a
        ``SecurityBase`` instance, a ``SecurityRequirement`` is appended to its
        ``security_requirements``.
    """
    # Work on a private copy. The caller's list is also stored on the parent
    # ``Dependant`` and handed to every sibling dependency, so extending it in
    # place would leak this dependency's scopes to the parent and to siblings.
    scopes: List[str] = list(security_scopes) if security_scopes else []
    if isinstance(depends, params.Security):
        scopes.extend(depends.scopes)

    security_requirement = None
    if isinstance(dependency, SecurityBase):
        # Only OAuth2 / OpenID Connect schemes get the accumulated scopes;
        # every other scheme is recorded with an empty scope list.
        use_scopes: List[str] = []
        if isinstance(dependency, (OAuth2, OpenIdConnect)):
            use_scopes = scopes
        security_requirement = SecurityRequirement(
            security_scheme=dependency, scopes=use_scopes
        )

    sub_dependant = get_dependant(
        path=path,
        call=dependency,
        name=name,
        security_scopes=scopes,
        use_cache=depends.use_cache,
    )
    if security_requirement:
        sub_dependant.security_requirements.append(security_requirement)
    return sub_dependant
# Type of the values stored in `visited` by `get_flat_dependant` (taken from
# `Dependant.cache_key`): an optional callable paired with a tuple of strings.
# NOTE(review): the strings are presumably the security scopes - confirm
# against `Dependant` in fastapi/dependencies/models.py.
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[List[CacheKey]] = None,
) -> Dependant:
    """Collapse a dependency tree into a single ``Dependant``.

    The result starts as copies of *dependant*'s own parameter lists and
    security requirements; the flattened contents of each sub-dependency are
    then appended recursively.

    Args:
        dependant: Root of the tree to flatten.
        skip_repeats: When true, a sub-dependency whose ``cache_key`` is
            already in *visited* is skipped.
        visited: Cache keys seen so far. It is appended to in place and shared
            with the recursive calls; a new list is created when omitted.
    """
    seen = [] if visited is None else visited
    seen.append(dependant.cache_key)

    merged = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        security_schemes=dependant.security_requirements.copy(),
        use_cache=dependant.use_cache,
        path=dependant.path,
    )
    merged_attributes = (
        "path_params",
        "query_params",
        "header_params",
        "cookie_params",
        "body_params",
        "security_requirements",
    )
    for child in dependant.dependencies:
        if skip_repeats and child.cache_key in seen:
            continue
        flat_child = get_flat_dependant(
            child, skip_repeats=skip_repeats, visited=seen
        )
        for attribute in merged_attributes:
            getattr(merged, attribute).extend(getattr(flat_child, attribute))
    return merged
def get_flat_params(dependant: Dependant) -> List[ModelField]:
    """Return every path, query, header and cookie field of the dependency tree.

    The tree is flattened with ``skip_repeats=True``; body parameters are not
    included. The four groups are concatenated in the order listed above.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    return [
        *flat.path_params,
        *flat.query_params,
        *flat.header_params,
        *flat.cookie_params,
    ]
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return *call*'s signature with each annotation passed through
    ``get_typed_annotation``.

    Annotations are resolved against ``call.__globals__`` (an empty dict when
    the attribute is missing). Parameter names, kinds and defaults are kept;
    the rebuilt signature has no return annotation.
    """
    namespace = getattr(call, "__globals__", {})
    resolved_params = []
    for parameter in inspect.signature(call).parameters.values():
        resolved_params.append(
            inspect.Parameter(
                name=parameter.name,
                kind=parameter.kind,
                default=parameter.default,
                annotation=get_typed_annotation(parameter.annotation, namespace),
            )
        )
    return inspect.Signature(resolved_params)
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Evaluate a string annotation; return any other annotation unchanged.

    A string is wrapped in ``ForwardRef`` and passed to ``evaluate_forwardref``
    with *globalns* used as both the global and the local namespace.
    """
    if not isinstance(annotation, str):
        return annotation
    forward_ref = ForwardRef(annotation)
    return evaluate_forwardref(forward_ref, globalns, globalns)
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return *call*'s return annotation passed through ``get_typed_annotation``.

    Returns ``None`` when the callable declares no return annotation.
    Otherwise the annotation is resolved against ``call.__globals__`` (an
    empty dict when the attribute is missing).
    """
    return_annotation = inspect.signature(call).return_annotation
    if return_annotation is not inspect.Signature.empty:
        namespace = getattr(call, "__globals__", {})
        return get_typed_annotation(return_annotation, namespace)
    return None
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Inspect *call*'s signature and build the ``Dependant`` describing it.

    Each parameter is run through ``analyze_param`` and then handled as one of:

    * a sub-dependency, appended to ``dependant.dependencies``;
    * a special injected type, recorded by
      ``add_non_field_param_to_dependency``;
    * a body parameter, appended to ``dependant.body_params``;
    * any other field, routed by ``add_param_to_fields``.
    """
    path_names = get_path_param_names(path)
    parameters = get_typed_signature(call).parameters
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        security_scopes=security_scopes,
        use_cache=use_cache,
    )
    for param_name, param in parameters.items():
        in_path = param_name in path_names
        type_annotation, depends, param_field = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=in_path,
        )
        if depends is not None:
            dependant.dependencies.append(
                get_param_sub_dependant(
                    param_name=param_name,
                    depends=depends,
                    path=path,
                    security_scopes=security_scopes,
                )
            )
        elif add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=type_annotation,
            dependant=dependant,
        ):
            assert (
                param_field is None
            ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
        else:
            assert param_field is not None
            if is_body_param(param_field=param_field, is_path_param=in_path):
                dependant.body_params.append(param_field)
            else:
                add_param_to_fields(field=param_field, dependant=dependant)
    return dependant
def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
    """Record *param_name* on *dependant* if its type is one of the injected types.

    The candidates are tested in a fixed order and the first match wins.

    Returns:
        ``True`` when a matching type was found and the corresponding
        ``*_param_name`` attribute was set, ``None`` otherwise.
    """
    # Order is significant: the first matching entry is the one recorded.
    injected_types = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (HTTPConnection, "http_connection_param_name"),
        (Response, "response_param_name"),
        (StarletteBackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )
    for injected_type, attribute in injected_types:
        if lenient_issubclass(type_annotation, injected_type):
            setattr(dependant, attribute, param_name)
            return True
    return None
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
    """Classify one function parameter from its annotation and default value.

    Args:
        param_name: Name of the parameter being analysed.
        annotation: The parameter's annotation, or ``inspect.Signature.empty``.
        value: The parameter's default, or ``inspect.Signature.empty``.
        is_path_param: Whether the name appears in the route path.

    Returns:
        A ``(type_annotation, depends, field)`` tuple: the annotation with any
        ``Annotated`` wrapper removed, the ``Depends`` found in the annotation
        or default (or ``None``), and the ``ModelField`` built for the
        parameter (or ``None`` when no field info applies).

    Raises:
        AssertionError: for the conflicting or unsupported declarations
            checked by the ``assert`` statements below.
    """
    field_info = None
    depends = None
    # Both fall back to `Any` when the parameter carries no annotation.
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        # First argument is the underlying type; the rest are metadata.
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(arg, (params.Param, params.Body, params.Depends))
        ]
        # When several FastAPI-specific markers are present, the last one wins.
        if fastapi_specific_annotations:
            fastapi_annotation: Union[
                FieldInfo, params.Depends, None
            ] = fastapi_specific_annotations[-1]
        else:
            fastapi_annotation = None
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            assert field_info.default is Undefined or field_info.default is Required, (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            # The default comes from the `=` value when given, else the field is required.
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = Required
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation

    # A `Depends` or `FieldInfo` given as the default value must not be
    # combined with one already taken from `Annotated`.
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation

    # A `Depends` without an explicit dependency uses the annotated type itself.
    if depends is not None and depends.dependency is None:
        # Copy `depends` before mutating it
        depends = copy(depends)
        depends.dependency = type_annotation

    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        # These types may carry neither `Depends` nor field info; no field is
        # built for them (`field_info` stays `None`).
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert (
            field_info is None
        ), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
    elif field_info is None and depends is None:
        # Nothing declared explicitly: choose Path / File / Body / Query from
        # the parameter's position in the path and its type.
        default_value = value if value is not inspect.Signature.empty else Required
        if is_path_param:
            # We might check here that `default_value is Required`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)

    field = None
    if field_info is not None:
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            # A bare `Param` with no location is treated as a query parameter.
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = get_annotation_from_field_info(
            use_annotation,
            field_info,
            param_name,
        )
        # With no explicit alias and a truthy `convert_underscores`, the alias
        # is the parameter name with "_" replaced by "-".
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_response_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (Required, Undefined),
            field_info=field_info,
        )

    return type_annotation, depends, field
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
    """Tell whether *param_field* has to be read from the request body.

    Path parameters, scalar fields, and ``Query``/``Header`` fields holding a
    sequence of scalars are not body parameters. Anything else must be
    declared with ``params.Body`` (asserted) and is a body parameter.
    """
    if is_path_param:
        assert is_scalar_field(
            field=param_field
        ), "Path params must be of one of the supported types"
        return False
    if is_scalar_field(field=param_field):
        return False
    info = param_field.field_info
    if isinstance(info, (params.Query, params.Header)) and is_scalar_sequence_field(
        param_field
    ):
        return False
    assert isinstance(
        info, params.Body
    ), f"Param: {param_field.name} can only be a request body, using Body()"
    return True
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append *field* to the *dependant* list matching its ``in_`` location.

    The location is read from ``field.field_info.in_`` (``None`` when absent).
    Anything that is not path, query or header is asserted to be a cookie
    parameter.
    """
    location = getattr(field.field_info, "in_", None)
    if location == params.ParamTypes.path:
        bucket = dependant.path_params
    elif location == params.ParamTypes.query:
        bucket = dependant.query_params
    elif location == params.ParamTypes.header:
        bucket = dependant.header_params
    else:
        assert (
            location == params.ParamTypes.cookie
        ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
        bucket = dependant.cookie_params
    bucket.append(field)
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Tell whether calling *call* yields a coroutine.

    Routines are checked directly and classes are never coroutine callables.
    For any other object the check is applied to its ``__call__`` attribute.
    """
    target = call
    if not inspect.isroutine(call):
        if inspect.isclass(call):
            return False
        target = getattr(call, "__call__", None)  # noqa: B004
    return inspect.iscoroutinefunction(target)
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Tell whether *call*, or failing that its ``__call__`` attribute, is an
    async generator function."""
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)  # noqa: B004
    )
def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Tell whether *call*, or failing that its ``__call__`` attribute, is a
    generator function."""
    return inspect.isgeneratorfunction(call) or inspect.isgeneratorfunction(
        getattr(call, "__call__", None)  # noqa: B004
    )
async def solve_generator(
    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
    """Enter a generator-based dependency on *stack* and return what it yields.

    A sync generator function is wrapped with ``contextmanager`` and then
    ``contextmanager_in_threadpool``; an async generator function is wrapped
    with ``asynccontextmanager``. The resulting context manager is entered on
    *stack*. The caller must pass one of those two kinds of callable.
    """
    if is_gen_callable(call):
        sync_cm = contextmanager(call)(**sub_values)
        cm = contextmanager_in_threadpool(sync_cm)
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(**sub_values)
    return await stack.enter_async_context(cm)
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
    async_exit_stack: AsyncExitStack,
) -> Tuple[
    Dict[str, Any],
    List[Any],
    Optional[StarletteBackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Resolve *dependant*'s sub-dependencies and parameters for one request.

    Sub-dependencies are solved recursively, in declaration order, and their
    results are stored under their parameter names. Path, query, header,
    cookie and body values are then extracted and validated, and the special
    injected objects (request, websocket, background tasks, response,
    security scopes) are added.

    Args:
        request: The incoming HTTP request or WebSocket.
        dependant: The dependency node to solve.
        body: Already-parsed request body, if any.
        background_tasks: Shared background task container; one is created
            only when a parameter asks for it and none exists yet.
        response: Shared response object; a fresh one is created when omitted.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            mapping replaces dependency callables.
        dependency_cache: Previously solved results keyed by ``cache_key``.
        async_exit_stack: Stack on which generator dependencies are entered.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        response = Response()
        del response.headers["content-length"]
        # NOTE(review): presumably cleared so callers can tell whether a
        # dependency set a status code - confirm against the routing code.
        response.status_code = None  # type: ignore
    # A falsy argument (`None` or an empty dict) is replaced by a new dict.
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # The two `cast` calls only narrow types for the checker.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Swap in the override (if one is registered for this callable)
            # and re-analyse it, since its signature may differ.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )

        # Recurse first: the sub-dependency's own arguments must be solved
        # before it can be called.
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = solved_result
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            # The dependency is not called when its arguments failed validation.
            errors.extend(sub_errors)
            continue
        # Reuse a cached result when allowed; otherwise call the dependency
        # according to its kind (generator, coroutine, or plain callable
        # executed through `run_in_threadpool`).
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            solved = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            solved = await run_in_threadpool(call, **sub_values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        # The first solved value for a key is kept; later ones do not replace it.
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Inject the special objects under whichever parameter names asked for them.
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
    """Extract and validate non-body parameters from a request mapping.

    Args:
        required_params: Fields to look up, by alias, in *received_params*.
        received_params: Path params, query params, headers or cookies.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        collected validation errors. A missing field contributes a "missing"
        error when it is required, or a deep copy of its default otherwise.
    """
    values = {}
    errors = []
    # Only these containers expose `getlist` for repeated keys.
    supports_getlist = isinstance(received_params, (QueryParams, Headers))
    for field in required_params:
        if supports_getlist and is_scalar_sequence_field(field):
            raw_value = received_params.getlist(field.alias) or field.default
        else:
            raw_value = received_params.get(field.alias)
        field_info = field.field_info
        assert isinstance(
            field_info, params.Param
        ), "Params must be subclasses of Param"
        loc = (field_info.in_.value, field.alias)
        if raw_value is None:
            if field.required:
                errors.append(get_missing_field_error(loc=loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue
        validated, field_errors = field.validate(raw_value, values, loc=loc)
        if isinstance(field_errors, ErrorWrapper):
            errors.append(field_errors)
        elif isinstance(field_errors, list):
            errors.extend(
                _regenerate_error_with_loc(errors=field_errors, loc_prefix=())
            )
        else:
            values[field.name] = validated
    return values, errors
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Extract and validate body parameters from the parsed request body.

    Args:
        required_params: Body fields to extract.
        received_body: Parsed JSON body or form data, or ``None``.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        collected validation errors. A missing or empty form value contributes
        a "missing" error when the field is required, or a deep copy of its
        default otherwise.
    """
    values: Dict[str, Any] = {}
    errors: List[Dict[str, Any]] = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    embed = getattr(first_field.field_info, "embed", None)
    # A single, non-embedded field receives the whole body as its value, so
    # the body is wrapped under that field's alias.
    field_alias_omitted = len(required_params) == 1 and not embed
    if field_alias_omitted:
        received_body = {first_field.alias: received_body}

    for field in required_params:
        # Use this field's own info. Reusing the first field's info would
        # apply its Form/File nature to every other body field.
        field_info = field.field_info
        loc: Tuple[str, ...]
        if field_alias_omitted:
            loc = ("body",)
        else:
            loc = ("body", field.alias)

        value: Optional[Any] = None
        if received_body is not None:
            if (is_sequence_field(field)) and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body is not a mapping (no `.get`): report as missing.
                    errors.append(get_missing_field_error(loc))
                    continue
        if (
            value is None
            or (isinstance(field_info, params.Form) and value == "")
            or (
                isinstance(field_info, params.Form)
                and is_sequence_field(field)
                and len(value) == 0
            )
        ):
            if field.required:
                errors.append(get_missing_field_error(loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue
        if (
            isinstance(field_info, params.File)
            and is_bytes_field(field)
            and isinstance(value, UploadFile)
        ):
            value = await value.read()
        elif (
            is_bytes_sequence_field(field)
            and isinstance(field_info, params.File)
            and value_is_sequence(value)
        ):
            # For types
            assert isinstance(value, sequence_types)  # type: ignore[arg-type]
            # One slot per upload, filled by index: tasks may finish in any
            # order, and the result must keep the order the files were sent in.
            results: List[Union[bytes, str]] = [b""] * len(value)

            async def process_fn(
                index: int,
                fn: Callable[[], Coroutine[Any, Any, Any]],
            ) -> None:
                results[index] = await fn()  # noqa: B023

            async with anyio.create_task_group() as tg:
                for index, sub_value in enumerate(value):
                    tg.start_soon(process_fn, index, sub_value.read)
            value = serialize_sequence_value(field=field, value=results)

        v_, errors_ = field.validate(value, values, loc=loc)

        if isinstance(errors_, list):
            errors.extend(errors_)
        elif errors_:
            errors.append(errors_)
        else:
            values[field.name] = v_
    return values, errors
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Build the single field that represents the whole request body.

    Returns ``None`` when the flattened dependency tree has no body
    parameters. When all body parameters share one name and the first is not
    marked ``embed``, that first parameter is returned as is. Otherwise every
    body parameter is marked ``embed`` and combined into a generated model
    named ``"Body_" + name``, wrapped in a ``File``, ``Form`` or ``Body``
    field info depending on the parameters present.

    ``check_file_field`` is called on the returned field.
    """
    body_params = get_flat_dependant(dependant).body_params
    if not body_params:
        return None

    first_param = body_params[0]
    embed = getattr(first_param.field_info, "embed", None)
    distinct_names = {param.name for param in body_params}
    if len(distinct_names) == 1 and not embed:
        check_file_field(first_param)
        return first_param

    # If one field requires to embed, all have to be embedded
    # in case a sub-dependency is evaluated with a single unique body field
    # That is combined (embedded) with other body fields
    for param in body_params:
        param.field_info.embed = True

    BodyModel = create_body_model(fields=body_params, model_name="Body_" + name)
    required = any(param.required for param in body_params)

    field_info_kwargs: Dict[str, Any] = {"annotation": BodyModel, "alias": "body"}
    if not required:
        field_info_kwargs["default"] = None

    # File takes precedence over Form, which takes precedence over Body.
    field_infos = [param.field_info for param in body_params]
    field_info_class: Type[params.Body] = params.Body
    if any(isinstance(info, params.File) for info in field_infos):
        field_info_class = params.File
    elif any(isinstance(info, params.Form) for info in field_infos):
        field_info_class = params.Form

    # A media type is only set when every Body-derived parameter agrees on it.
    media_types = [
        info.media_type for info in field_infos if isinstance(info, params.Body)
    ]
    if len(set(media_types)) == 1:
        field_info_kwargs["media_type"] = media_types[0]

    final_field = create_response_field(
        name="body",
        type_=BodyModel,
        required=required,
        alias="body",
        field_info=field_info_class(**field_info_kwargs),
    )
    check_file_field(final_field)
    return final_field