Coverage for fastapi/dependencies/utils.py: 100%

408 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-08 03:53 +0000

1import inspect 1abcde

2from contextlib import AsyncExitStack, contextmanager 1abcde

3from copy import copy, deepcopy 1abcde

4from typing import ( 1abcde

5 Any, 

6 Callable, 

7 Coroutine, 

8 Dict, 

9 ForwardRef, 

10 List, 

11 Mapping, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Type, 

16 Union, 

17 cast, 

18) 

19 

20import anyio 1abcde

21from fastapi import params 1abcde

22from fastapi._compat import ( 1abcde

23 PYDANTIC_V2, 

24 ErrorWrapper, 

25 ModelField, 

26 Required, 

27 Undefined, 

28 _regenerate_error_with_loc, 

29 copy_field_info, 

30 create_body_model, 

31 evaluate_forwardref, 

32 field_annotation_is_scalar, 

33 get_annotation_from_field_info, 

34 get_missing_field_error, 

35 is_bytes_field, 

36 is_bytes_sequence_field, 

37 is_scalar_field, 

38 is_scalar_sequence_field, 

39 is_sequence_field, 

40 is_uploadfile_or_nonable_uploadfile_annotation, 

41 is_uploadfile_sequence_annotation, 

42 lenient_issubclass, 

43 sequence_types, 

44 serialize_sequence_value, 

45 value_is_sequence, 

46) 

47from fastapi.background import BackgroundTasks 1abcde

48from fastapi.concurrency import ( 1abcde

49 asynccontextmanager, 

50 contextmanager_in_threadpool, 

51) 

52from fastapi.dependencies.models import Dependant, SecurityRequirement 1abcde

53from fastapi.logger import logger 1abcde

54from fastapi.security.base import SecurityBase 1abcde

55from fastapi.security.oauth2 import OAuth2, SecurityScopes 1abcde

56from fastapi.security.open_id_connect_url import OpenIdConnect 1abcde

57from fastapi.utils import create_response_field, get_path_param_names 1abcde

58from pydantic.fields import FieldInfo 1abcde

59from starlette.background import BackgroundTasks as StarletteBackgroundTasks 1abcde

60from starlette.concurrency import run_in_threadpool 1abcde

61from starlette.datastructures import FormData, Headers, QueryParams, UploadFile 1abcde

62from starlette.requests import HTTPConnection, Request 1abcde

63from starlette.responses import Response 1abcde

64from starlette.websockets import WebSocket 1abcde

65from typing_extensions import Annotated, get_args, get_origin 1abcde

66 

# Message logged and raised by check_file_field when the ``multipart`` module
# cannot be imported at all.
multipart_not_installed_error = "".join(
    [
        'Form data requires "python-multipart" to be installed. \n',
        'You can install "python-multipart" with: \n\n',
        "pip install python-multipart\n",
    ]
)
# Message logged and raised by check_file_field when ``multipart`` imports but
# ``multipart.multipart.parse_options_header`` does not (a different package
# installed under the same name).
multipart_incorrect_install_error = "".join(
    [
        'Form data requires "python-multipart" to be installed. ',
        'It seems you installed "multipart" instead. \n',
        'You can remove "multipart" with: \n\n',
        "pip uninstall multipart\n\n",
        'And then install "python-multipart" with: \n\n',
        "pip install python-multipart\n",
    ]
)

80 

81 

def check_file_field(field: ModelField) -> None:
    """Make sure the multipart parser is usable when *field* is a form field.

    Only fields whose ``field_info`` is a ``params.Form`` instance are checked;
    any other field returns immediately.

    Raises:
        RuntimeError: (after logging the same message) if ``multipart`` cannot
            be imported, or if the importable ``multipart`` lacks
            ``multipart.multipart.parse_options_header``, i.e. the wrong
            package is installed under that name.
    """
    if not isinstance(field.field_info, params.Form):
        return
    try:
        # ``__version__`` exists in both packages named "multipart" (and can be
        # mocked in tests), so this only proves *some* multipart is installed.
        from multipart import __version__  # type: ignore

        assert __version__
    except ImportError:
        logger.error(multipart_not_installed_error)
        raise RuntimeError(multipart_not_installed_error) from None
    try:
        # Only the python-multipart package provides this submodule function.
        from multipart.multipart import parse_options_header  # type: ignore

        assert parse_options_header
    except ImportError:
        logger.error(multipart_incorrect_install_error)
        raise RuntimeError(multipart_incorrect_install_error) from None

101 

102 

def get_param_sub_dependant(
    *,
    param_name: str,
    depends: params.Depends,
    path: str,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the sub-dependant for a parameter declared with a ``Depends`` marker.

    The parameter name becomes the sub-dependant's ``name``; ``solve_dependencies``
    stores the solved value under that key.

    Raises:
        AssertionError: if ``depends.dependency`` is falsy.
    """
    target = depends.dependency
    assert target
    return get_sub_dependant(
        depends=depends,
        dependency=target,
        path=path,
        name=param_name,
        security_scopes=security_scopes,
    )

118 

119 

def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build a sub-dependant that is not bound to any parameter.

    No ``name`` is set, so ``solve_dependencies`` runs it without storing its
    result in the values passed to the caller.

    Raises:
        AssertionError: if ``depends.dependency`` is not callable.
    """
    target = depends.dependency
    assert callable(
        target
    ), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends, dependency=target, path=path)

125 

126 

def get_sub_dependant(
    *,
    depends: params.Depends,
    dependency: Callable[..., Any],
    path: str,
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the Dependant for a ``Depends``/``Security`` declaration.

    Args:
        depends: the marker; a ``params.Security`` marker adds its ``scopes``
            to the accumulated scopes.
        dependency: the callable to analyse.
        path: route path, forwarded to ``get_dependant``.
        name: parameter name that receives the solved value, if any.
        security_scopes: scopes accumulated from enclosing ``Security``
            declarations. This list is never modified.

    Returns:
        The sub-dependant. When *dependency* is a ``SecurityBase`` instance, a
        ``SecurityRequirement`` is appended to its ``security_requirements``.
        OAuth2 / OpenID Connect schemes carry the accumulated scopes; other
        schemes get an empty scope list.
    """
    security_requirement = None
    # Work on a private copy. Extending the caller's list in place would leak
    # this dependency's scopes into the parent Dependant and into sibling
    # dependencies that share the same list. It would also make the list grow
    # each time get_dependant is re-run on a stored Dependant's
    # ``security_scopes`` (as solve_dependencies does when overrides are set).
    security_scopes = list(security_scopes or [])
    if isinstance(depends, params.Security):
        security_scopes.extend(depends.scopes)
    if isinstance(dependency, SecurityBase):
        use_scopes: List[str] = []
        if isinstance(dependency, (OAuth2, OpenIdConnect)):
            use_scopes = security_scopes
        security_requirement = SecurityRequirement(
            security_scheme=dependency, scopes=use_scopes
        )
    sub_dependant = get_dependant(
        path=path,
        call=dependency,
        name=name,
        security_scopes=security_scopes,
        use_cache=depends.use_cache,
    )
    if security_requirement:
        sub_dependant.security_requirements.append(security_requirement)
    return sub_dependant

157 

158 

# Type of the keys collected in get_flat_dependant's ``visited`` list (compared
# against ``Dependant.cache_key``): a dependency callable, possibly None,
# paired with a tuple of strings.
# NOTE(review): the string tuple is presumably the dependant's security scopes;
# confirm against ``Dependant.cache_key`` in fastapi.dependencies.models.
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]

160 

161 

def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[List[CacheKey]] = None,
) -> Dependant:
    """Collapse *dependant* and its whole sub-dependency tree into one Dependant.

    The result starts with copies of *dependant*'s own parameter and security
    lists. It is then extended, depth-first in declaration order, with the
    flattened lists of every sub-dependant. The input objects are not modified.

    Args:
        dependant: root of the tree to flatten.
        skip_repeats: if true, a sub-dependant whose ``cache_key`` was already
            seen during this traversal is skipped, together with its subtree.
        visited: cache keys seen so far. It is appended to and shared across
            the recursive calls; pass ``None`` at the top level.
    """
    seen = [] if visited is None else visited
    seen.append(dependant.cache_key)

    # NOTE(review): the constructor keyword is ``security_schemes`` while the
    # attribute read and extended is ``security_requirements``; presumably
    # Dependant maps one onto the other. Confirm in fastapi.dependencies.models.
    flat = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        security_schemes=dependant.security_requirements.copy(),
        use_cache=dependant.use_cache,
        path=dependant.path,
    )
    merged_attributes = (
        "path_params",
        "query_params",
        "header_params",
        "cookie_params",
        "body_params",
        "security_requirements",
    )
    for child in dependant.dependencies:
        if skip_repeats and child.cache_key in seen:
            continue
        flat_child = get_flat_dependant(
            child, skip_repeats=skip_repeats, visited=seen
        )
        for attribute in merged_attributes:
            getattr(flat, attribute).extend(getattr(flat_child, attribute))
    return flat

195 

196 

def get_flat_params(dependant: Dependant) -> List[ModelField]:
    """Return every non-body request parameter in *dependant*'s tree.

    Sub-dependencies with an already-seen ``cache_key`` are included only
    once. The result lists path, then query, then header, then cookie
    parameters.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    return [
        *flat.path_params,
        *flat.query_params,
        *flat.header_params,
        *flat.cookie_params,
    ]

205 

206 

def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return *call*'s parameter signature with string annotations resolved.

    String annotations are evaluated against ``call.__globals__`` (or an empty
    namespace when the callable has none). Only parameters are carried over;
    the returned signature has no return annotation.
    """
    signature = inspect.signature(call)
    globalns = getattr(call, "__globals__", {})
    return inspect.Signature(
        [
            param.replace(annotation=get_typed_annotation(param.annotation, globalns))
            for param in signature.parameters.values()
        ]
    )

221 

222 

def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Evaluate a string annotation in *globalns*; return any other value as-is."""
    if not isinstance(annotation, str):
        return annotation
    return evaluate_forwardref(ForwardRef(annotation), globalns, globalns)

228 

229 

def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return *call*'s return annotation with strings resolved.

    Returns ``None`` when the callable declares no return annotation.
    """
    return_annotation = inspect.signature(call).return_annotation
    if return_annotation is inspect.Signature.empty:
        return None
    return get_typed_annotation(return_annotation, getattr(call, "__globals__", {}))

239 

240 

def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Build the Dependant describing *call*'s parameters.

    Each parameter of *call*'s type-resolved signature goes to exactly one place:

    * a ``Depends`` marker becomes a sub-dependant (built recursively);
    * a framework type recognised by ``add_non_field_param_to_dependency``
      (request, websocket, connection, response, background tasks, security
      scopes) is recorded by name;
    * a body parameter is appended to ``body_params``;
    * anything else is added to the path/query/header/cookie lists.

    Args:
        path: route path. Parameters named in ``get_path_param_names(path)``
            are treated as path parameters.
        call: the endpoint or dependency callable to inspect.
        name: parameter name under which this dependant's value is injected.
        security_scopes: scopes accumulated from enclosing ``Security``
            declarations.
        use_cache: whether the solved value may be taken from the dependency
            cache.
    """
    path_param_names = get_path_param_names(path)
    signature_params = get_typed_signature(call).parameters
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        security_scopes=security_scopes,
        use_cache=use_cache,
    )
    for param_name, param in signature_params.items():
        is_path_param = param_name in path_param_names
        type_annotation, depends, param_field = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=is_path_param,
        )
        if depends is not None:
            dependant.dependencies.append(
                get_param_sub_dependant(
                    param_name=param_name,
                    depends=depends,
                    path=path,
                    security_scopes=security_scopes,
                )
            )
        elif add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=type_annotation,
            dependant=dependant,
        ):
            assert (
                param_field is None
            ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
        else:
            assert param_field is not None
            if is_body_param(param_field=param_field, is_path_param=is_path_param):
                dependant.body_params.append(param_field)
            else:
                add_param_to_fields(field=param_field, dependant=dependant)
    return dependant

291 

292 

def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
    """Record *param_name* on *dependant* if its type is a framework object.

    The types are tried in the order listed below with ``lenient_issubclass``.
    The first match stores *param_name* on the matching ``*_param_name``
    attribute of *dependant*.

    Returns:
        ``True`` if the parameter was recorded, otherwise ``None``.
    """
    # Order matters: Request and WebSocket are checked before HTTPConnection
    # (presumably because both are HTTPConnection subclasses).
    special_types = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (HTTPConnection, "http_connection_param_name"),
        (Response, "response_param_name"),
        (StarletteBackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )
    for special_type, attribute in special_types:
        if lenient_issubclass(type_annotation, special_type):
            setattr(dependant, attribute, param_name)
            return True
    return None

315 

316 

def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
    """Classify one parameter of an endpoint or dependency signature.

    Combines the parameter's annotation (including ``Annotated[...]`` metadata)
    and its default value into at most one ``Depends`` marker and at most one
    model field.

    Args:
        param_name: the parameter's name in the signature.
        annotation: the parameter's annotation, or ``inspect.Signature.empty``.
        value: the parameter's default, or ``inspect.Signature.empty``.
        is_path_param: whether the route path declares this name.

    Returns:
        ``(type_annotation, depends, field)``:

        * ``type_annotation`` is the bare type, with ``Annotated`` stripped.
        * ``depends`` is the ``Depends`` marker, or ``None``.
        * ``field`` is the created ``ModelField``, or ``None`` (for example for
          dependencies and for framework types such as ``Request``).

    Raises:
        AssertionError: for conflicting or invalid declarations; see the
            assertion messages below.
    """
    field_info = None
    depends = None
    # Without an annotation the parameter is treated as ``Any``.
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        # Only FastAPI-specific markers (Param/Body/Depends) count; when
        # several are given, the last one wins.
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(arg, (params.Param, params.Body, params.Depends))
        ]
        if fastapi_specific_annotations:
            fastapi_annotation: Union[
                FieldInfo, params.Depends, None
            ] = fastapi_specific_annotations[-1]
        else:
            fastapi_annotation = None
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            assert field_info.default is Undefined or field_info.default is Required, (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            # The default comes from the signature (`= value`); with none
            # given the field is required.
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = Required
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation

    # Markers may also be given as the default value (`= Depends(...)`,
    # `= Query(...)`, ...), but not in addition to an `Annotated` marker.
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation

    # `Depends()` without an explicit dependency uses the annotated type.
    if depends is not None and depends.dependency is None:
        # Copy `depends` before mutating it
        depends = copy(depends)
        depends.dependency = type_annotation

    # Framework-provided types get no field and may not carry markers.
    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert (
            field_info is None
        ), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
    # No explicit marker: infer the kind. A name in the route path gives Path,
    # UploadFile annotations give File, non-scalar types give Body, and
    # anything else gives Query.
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else Required
        if is_path_param:
            # We might check here that `default_value is Required`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)

    field = None
    if field_info is not None:
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        # A Param without a location defaults to a query parameter.
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = get_annotation_from_field_info(
            use_annotation,
            field_info,
            param_name,
        )
        # With no explicit alias and a truthy `convert_underscores` attribute,
        # the alias is the name with "_" replaced by "-"; otherwise it is the
        # explicit alias or the parameter name.
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        # The field is required when no real default was supplied.
        field = create_response_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (Required, Undefined),
            field_info=field_info,
        )

    return type_annotation, depends, field

454 

455 

def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
    """Decide whether *param_field* is read from the request body.

    Returns ``False`` in these cases:

    * path parameters;
    * scalar fields;
    * Query/Header fields holding a sequence of scalars.

    Every other field must be declared with ``Body`` (or a subclass), and
    ``True`` is returned.

    Raises:
        AssertionError: for a non-scalar path parameter, or for a non-scalar
            field whose ``field_info`` is not a ``params.Body``.
    """
    if is_path_param:
        assert is_scalar_field(
            field=param_field
        ), "Path params must be of one of the supported types"
        return False
    if is_scalar_field(field=param_field):
        return False
    field_info = param_field.field_info
    if isinstance(field_info, (params.Query, params.Header)) and is_scalar_sequence_field(
        param_field
    ):
        return False
    assert isinstance(
        field_info, params.Body
    ), f"Param: {param_field.name} can only be a request body, using Body()"
    return True

473 

474 

def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append a non-body *field* to the list matching its location.

    The location is ``field.field_info.in_``. Path, query and header are
    checked first; anything else must be a cookie parameter.

    Raises:
        AssertionError: if the location is not path, query, header or cookie.
    """
    location = getattr(field.field_info, "in_", None)
    for param_type, bucket in (
        (params.ParamTypes.path, dependant.path_params),
        (params.ParamTypes.query, dependant.query_params),
        (params.ParamTypes.header, dependant.header_params),
    ):
        if location == param_type:
            bucket.append(field)
            return
    assert (
        location == params.ParamTypes.cookie
    ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
    dependant.cookie_params.append(field)

489 

490 

def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Return whether calling *call* produces a coroutine that must be awaited.

    The following are recognised:

    * ``async def`` functions and methods;
    * ``functools.partial`` objects wrapping one;
    * instances whose ``__call__`` is an ``async def`` method.

    Classes always return ``False``, since calling a class builds an instance.
    """
    # ``inspect.iscoroutinefunction`` unwraps ``functools.partial`` (Python
    # 3.8+). Without this check a partial of an async function would fall
    # through to the end: it is not a routine, and its ``__call__`` is a C-level
    # wrapper. It would then be treated as sync and run in the threadpool,
    # returning an un-awaited coroutine.
    if inspect.iscoroutinefunction(call):
        return True
    if inspect.isroutine(call):
        return False
    if inspect.isclass(call):
        return False
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.iscoroutinefunction(dunder_call)

498 

499 

def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Return whether *call*, or else its ``__call__``, is an async generator function."""
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)  # noqa: B004
    )

505 

506 

def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Return whether *call*, or else its ``__call__``, is a generator function."""
    if not inspect.isgeneratorfunction(call):
        return inspect.isgeneratorfunction(getattr(call, "__call__", None))  # noqa: B004
    return True

512 

513 

async def solve_generator(
    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
    """Run a ``yield``-based dependency and return the value it yields.

    The generator is wrapped as a context manager and entered on *stack*, so
    the code after its ``yield`` runs when *stack* exits. Sync generators are
    wrapped with ``contextmanager`` and driven through
    ``contextmanager_in_threadpool``; async generators use
    ``asynccontextmanager``.

    Raises:
        TypeError: if *call* is neither a generator nor an async generator
            callable. Without this check the function would crash with an
            opaque ``UnboundLocalError``.
    """
    if is_gen_callable(call):
        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(**sub_values)
    else:
        raise TypeError(
            f"{call!r} is neither a generator nor an async generator callable"
        )
    return await stack.enter_async_context(cm)

522 

523 

async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
    async_exit_stack: AsyncExitStack,
) -> Tuple[
    Dict[str, Any],
    List[Any],
    Optional[StarletteBackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Resolve *dependant*'s sub-dependencies and request parameters.

    The work happens in three stages:

    1. Sub-dependencies are solved recursively, in declaration order.
    2. Path/query/header/cookie parameters and body parameters are extracted
       and validated.
    3. Framework objects (connection, request/websocket, background tasks,
       response, security scopes) are injected by parameter name.

    Args:
        request: the incoming HTTP request or WebSocket.
        dependant: what to solve.
        body: the already-parsed request body, if any.
        background_tasks: shared background-tasks object. A new
            ``BackgroundTasks`` is created only if a parameter asks for one
            and none was passed.
        response: shared response object. When omitted, a new ``Response`` is
            created with its content-length header removed and ``status_code``
            set to ``None``.
        dependency_overrides_provider: object whose ``dependency_overrides``
            mapping (original callable -> replacement) is consulted.
        dependency_cache: already-solved values keyed by ``cache_key``; merged
            across the recursion and returned.
        async_exit_stack: stack on which generator dependencies are entered.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``.
        If a sub-dependency's own sub-dependencies report errors, that
        dependency is not called. Solving still continues, so errors from all
        branches and parameters are collected.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        # While any override is registered, the (possibly replaced) callable
        # is re-analysed with get_dependant on every call to this function.
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )

        # Solve the sub-dependency's own parameters and dependencies first.
        # This happens even when its value is later taken from the cache.
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = solved_result
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            errors.extend(sub_errors)
            continue
        # Cache reads and writes use the original sub_dependant's
        # `use_cache`/`cache_key`, even when an override replaced the callable.
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            solved = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            solved = await run_in_threadpool(call, **sub_values)
        # Unnamed (parameter-less) dependencies run only for their side effects.
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Inject framework objects requested by parameter type (names recorded by
    # add_non_field_param_to_dependency).
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache

650 

651 

def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
    """Extract and validate non-body parameters from one request source.

    Args:
        required_params: fields to read. Each must be declared with a
            ``params.Param`` subclass.
        received_params: the source mapping (path params, query params,
            headers or cookies), looked up by each field's alias.

    Returns:
        ``(values, errors)``: validated values keyed by ``field.name``, and the
        collected errors.

    Lookup and missing-value rules:

    * A scalar-sequence field read from ``QueryParams``/``Headers`` collects
      all repeated values with ``getlist``. If there are none, it falls back
      to the field default.
    * A value of ``None`` yields a "missing" error for required fields, and a
      deep copy of the default otherwise.

    Each error location is ``(location, alias)``.
    """
    values = {}
    errors = []
    for field in required_params:
        if is_scalar_sequence_field(field) and isinstance(
            received_params, (QueryParams, Headers)
        ):
            raw_value = received_params.getlist(field.alias) or field.default
        else:
            raw_value = received_params.get(field.alias)
        field_info = field.field_info
        assert isinstance(
            field_info, params.Param
        ), "Params must be subclasses of Param"
        loc = (field_info.in_.value, field.alias)
        if raw_value is None:
            if field.required:
                errors.append(get_missing_field_error(loc=loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue
        validated, field_errors = field.validate(raw_value, values, loc=loc)
        if isinstance(field_errors, ErrorWrapper):
            errors.append(field_errors)
        elif isinstance(field_errors, list):
            errors.extend(
                _regenerate_error_with_loc(errors=field_errors, loc_prefix=())
            )
        else:
            values[field.name] = validated
    return values, errors

685 

686 

async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Extract and validate body parameters from an already-parsed request body.

    Args:
        required_params: the body fields to fill, in declaration order.
        received_body: the parsed body. Either a mapping, a ``FormData`` for
            form/multipart requests, or ``None``.

    Returns:
        ``(values, errors)``: validated values keyed by ``field.name``, and the
        collected errors.

    Body layout and lookup rules:

    * If there is exactly one body field and the first field's ``embed`` is
      falsy, the whole body is that field's value (error location
      ``("body",)``).
    * Otherwise each field is looked up by its alias inside the body (error
      location ``("body", alias)``).
    * ``None`` values, and for Form fields empty strings or empty lists, count
      as missing. Required fields then get a missing-field error; optional
      fields get a deep copy of their default.
    * ``bytes`` File fields have their ``UploadFile`` contents read before
      validation.
    """
    values: Dict[str, Any] = {}
    errors: List[Dict[str, Any]] = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    embed = getattr(first_field.field_info, "embed", None)
    field_alias_omitted = len(required_params) == 1 and not embed
    if field_alias_omitted:
        # Wrap the body under the field's alias so the per-alias lookup below
        # also handles the "whole body is the value" case.
        received_body = {first_field.alias: received_body}

    for field in required_params:
        # Use this field's own declaration for the Form/File checks, not the
        # first field's. Otherwise a `bytes` File declared after a Form field
        # would never be read.
        field_info = field.field_info
        loc: Tuple[str, ...] = (
            ("body",) if field_alias_omitted else ("body", field.alias)
        )

        value: Optional[Any] = None
        if received_body is not None:
            if is_sequence_field(field) and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body has no `.get` (it is not a mapping).
                    errors.append(get_missing_field_error(loc))
                    continue
        is_form_field = isinstance(field_info, params.Form)
        if (
            value is None
            or (is_form_field and value == "")
            or (is_form_field and is_sequence_field(field) and len(value) == 0)
        ):
            if field.required:
                errors.append(get_missing_field_error(loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue
        if (
            isinstance(field_info, params.File)
            and is_bytes_field(field)
            and isinstance(value, UploadFile)
        ):
            value = await value.read()
        elif (
            is_bytes_sequence_field(field)
            and isinstance(field_info, params.File)
            and value_is_sequence(value)
        ):
            assert isinstance(value, sequence_types)  # type: ignore[arg-type]
            # The reads run concurrently and may finish in any order. Store
            # each result at its file's index so the output keeps the upload
            # order instead of completion order.
            results: List[Union[bytes, str]] = [b""] * len(value)

            async def process_fn(
                index: int, fn: Callable[[], Coroutine[Any, Any, Any]]
            ) -> None:
                results[index] = await fn()  # noqa: B023

            async with anyio.create_task_group() as tg:
                for index, sub_value in enumerate(value):
                    tg.start_soon(process_fn, index, sub_value.read)
            value = serialize_sequence_value(field=field, value=results)

        v_, errors_ = field.validate(value, values, loc=loc)

        if isinstance(errors_, list):
            errors.extend(errors_)
        elif errors_:
            errors.append(errors_)
        else:
            values[field.name] = v_
    return values, errors

767 

768 

def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Build the single field describing an endpoint's request body.

    The body parameters of the whole dependency tree are collected with
    ``get_flat_dependant``:

    * With none, returns ``None``.
    * With a single distinct parameter name and a falsy ``embed`` on the first
      one, that parameter is returned unchanged.
    * Otherwise every body parameter's ``field_info.embed`` is set to
      ``True`` (mutating it), and a model named ``"Body_" + name`` is built
      from them. That model is wrapped in a field aliased ``"body"``.

    The wrapping field's info class is:

    * ``params.File`` if any parameter is a File;
    * else ``params.Form`` if any is a Form;
    * else ``params.Body``.

    It gets ``default=None`` when no parameter is required, and a
    ``media_type`` when all ``Body`` parameters agree on one. ``check_file_field``
    is run on whichever field is returned.
    """
    flat_dependant = get_flat_dependant(dependant)
    body_params = flat_dependant.body_params
    if not body_params:
        return None
    first_param = body_params[0]
    embed = getattr(first_param.field_info, "embed", None)
    distinct_names = {param.name for param in body_params}
    if len(distinct_names) == 1 and not embed:
        check_file_field(first_param)
        return first_param
    # If one field requires to embed, all have to be embedded
    # in case a sub-dependency is evaluated with a single unique body field
    # That is combined (embedded) with other body fields
    for param in body_params:
        setattr(param.field_info, "embed", True)  # noqa: B010
    BodyModel = create_body_model(fields=body_params, model_name="Body_" + name)
    required = any(f.required for f in body_params)
    field_info_kwargs: Dict[str, Any] = {"annotation": BodyModel, "alias": "body"}
    if not required:
        field_info_kwargs["default"] = None
    field_info_class: Type[params.Body]
    if any(isinstance(f.field_info, params.File) for f in body_params):
        field_info_class = params.File
    elif any(isinstance(f.field_info, params.Form) for f in body_params):
        field_info_class = params.Form
    else:
        field_info_class = params.Body

    media_types = [
        f.field_info.media_type
        for f in body_params
        if isinstance(f.field_info, params.Body)
    ]
    if len(set(media_types)) == 1:
        field_info_kwargs["media_type"] = media_types[0]
    final_field = create_response_field(
        name="body",
        type_=BodyModel,
        required=required,
        alias="body",
        field_info=field_info_class(**field_info_kwargs),
    )
    check_file_field(final_field)
    return final_field