You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							752 lines
						
					
					
						
							27 KiB
						
					
					
				
			
		
		
	
	
							752 lines
						
					
					
						
							27 KiB
						
					
					
				import dataclasses
 | 
						|
import inspect
 | 
						|
from contextlib import contextmanager
 | 
						|
from copy import deepcopy
 | 
						|
from typing import (
 | 
						|
    Any,
 | 
						|
    Callable,
 | 
						|
    Coroutine,
 | 
						|
    Dict,
 | 
						|
    List,
 | 
						|
    Mapping,
 | 
						|
    Optional,
 | 
						|
    Sequence,
 | 
						|
    Tuple,
 | 
						|
    Type,
 | 
						|
    Union,
 | 
						|
    cast,
 | 
						|
)
 | 
						|
 | 
						|
import anyio
 | 
						|
from fastapi import params
 | 
						|
from fastapi.concurrency import (
 | 
						|
    AsyncExitStack,
 | 
						|
    asynccontextmanager,
 | 
						|
    contextmanager_in_threadpool,
 | 
						|
)
 | 
						|
from fastapi.dependencies.models import Dependant, SecurityRequirement
 | 
						|
from fastapi.logger import logger
 | 
						|
from fastapi.security.base import SecurityBase
 | 
						|
from fastapi.security.oauth2 import OAuth2, SecurityScopes
 | 
						|
from fastapi.security.open_id_connect_url import OpenIdConnect
 | 
						|
from fastapi.utils import create_response_field, get_path_param_names
 | 
						|
from pydantic import BaseModel, create_model
 | 
						|
from pydantic.error_wrappers import ErrorWrapper
 | 
						|
from pydantic.errors import MissingError
 | 
						|
from pydantic.fields import (
 | 
						|
    SHAPE_LIST,
 | 
						|
    SHAPE_SEQUENCE,
 | 
						|
    SHAPE_SET,
 | 
						|
    SHAPE_SINGLETON,
 | 
						|
    SHAPE_TUPLE,
 | 
						|
    SHAPE_TUPLE_ELLIPSIS,
 | 
						|
    FieldInfo,
 | 
						|
    ModelField,
 | 
						|
    Required,
 | 
						|
)
 | 
						|
from pydantic.schema import get_annotation_from_field_info
 | 
						|
from pydantic.typing import ForwardRef, evaluate_forwardref
 | 
						|
from pydantic.utils import lenient_issubclass
 | 
						|
from starlette.background import BackgroundTasks
 | 
						|
from starlette.concurrency import run_in_threadpool
 | 
						|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
 | 
						|
from starlette.requests import HTTPConnection, Request
 | 
						|
from starlette.responses import Response
 | 
						|
from starlette.websockets import WebSocket
 | 
						|
 | 
						|
# Container type used to rebuild a collected value for each sequence-like
# pydantic field shape.
sequence_shape_to_type = {
    SHAPE_LIST: list,
    SHAPE_SET: set,
    SHAPE_TUPLE: tuple,
    SHAPE_SEQUENCE: list,
    SHAPE_TUPLE_ELLIPSIS: list,
}
# Every pydantic field shape that is handled as a sequence of values
# (exactly the keys of the mapping above).
sequence_shapes = set(sequence_shape_to_type)
# Built-in container types that are treated as sequences.
sequence_types = (list, set, tuple)
 | 
						|
 | 
						|
 | 
						|
# Message used when no "multipart" module can be imported at all.
multipart_not_installed_error = "".join(
    [
        'Form data requires "python-multipart" to be installed. \n',
        'You can install "python-multipart" with: \n\n',
        "pip install python-multipart\n",
    ]
)
# Message used when a "multipart" module is importable but it lacks
# "multipart.multipart.parse_options_header".
multipart_incorrect_install_error = "".join(
    [
        'Form data requires "python-multipart" to be installed. ',
        'It seems you installed "multipart" instead. \n',
        'You can remove "multipart" with: \n\n',
        "pip uninstall multipart\n\n",
        'And then install "python-multipart" with: \n\n',
        "pip install python-multipart\n",
    ]
)
 | 
						|
 | 
						|
 | 
						|
def check_file_field(field: ModelField) -> None:
    """Verify that form parsing support is importable for a form field.

    Does nothing unless ``field.field_info`` is a ``params.Form`` instance.

    Raises:
        RuntimeError: if ``multipart`` cannot be imported, or if it can but
            ``multipart.multipart.parse_options_header`` cannot. The message
            is also logged as an error before raising.
    """
    if not isinstance(field.field_info, params.Form):
        return

    try:
        # __version__ is available in both multiparts, and can be mocked
        from multipart import __version__  # type: ignore

        assert __version__
    except ImportError:
        logger.error(multipart_not_installed_error)
        raise RuntimeError(multipart_not_installed_error)

    try:
        # parse_options_header is only available in the right multipart
        from multipart.multipart import parse_options_header  # type: ignore

        assert parse_options_header
    except ImportError:
        logger.error(multipart_incorrect_install_error)
        raise RuntimeError(multipart_incorrect_install_error)
 | 
						|
 | 
						|
 | 
						|
def get_param_sub_dependant(
    *, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None
) -> Dependant:
    """Build the sub-dependant declared as the default of a parameter.

    ``param.default`` is taken to be a ``params.Depends``. When it carries no
    (truthy) ``dependency``, the parameter's annotation is used as the
    dependency instead.
    """
    depends: params.Depends = param.default
    # Fall back to the annotation when the declaration names no dependency.
    target = depends.dependency or param.annotation
    return get_sub_dependant(
        depends=depends,
        dependency=target,
        path=path,
        name=param.name,
        security_scopes=security_scopes,
    )
 | 
						|
 | 
						|
 | 
						|
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build a sub-dependant for a dependency that is not tied to a parameter.

    The resulting sub-dependant has no name, because ``get_sub_dependant`` is
    called without one. Asserts that ``depends.dependency`` is callable.
    """
    dependency = depends.dependency
    assert callable(
        dependency
    ), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends, dependency=dependency, path=path)
 | 
						|
 | 
						|
 | 
						|
def get_sub_dependant(
    *,
    depends: params.Depends,
    dependency: Callable[..., Any],
    path: str,
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the ``Dependant`` for a single dependency declaration.

    Args:
        depends: The declaration; when it is a ``params.Security`` its
            ``scopes`` are added to the inherited scopes.
        dependency: The callable to analyse with ``get_dependant``.
        path: Path passed through to ``get_dependant``.
        name: Name stored on the resulting dependant, if any.
        security_scopes: Scopes inherited from the caller. This list is never
            modified.

    Returns:
        The dependant for ``dependency``, with ``security_scopes`` set to the
        inherited scopes plus any declared on ``depends``. If ``dependency``
        is a ``SecurityBase`` instance, a ``SecurityRequirement`` is appended
        to its ``security_requirements``.
    """
    # Work on a private copy. Extending the caller's list in place made scopes
    # declared here visible to sibling and parent dependencies that share the
    # list, and made a stored list grow every time it was passed in again.
    scopes: List[str] = list(security_scopes or [])
    if isinstance(depends, params.Security):
        scopes.extend(depends.scopes)

    security_requirement = None
    if isinstance(dependency, SecurityBase):
        # Only OAuth2 / OpenID Connect schemes receive the collected scopes;
        # any other security scheme gets an empty scope list.
        use_scopes: List[str] = []
        if isinstance(dependency, (OAuth2, OpenIdConnect)):
            use_scopes = scopes
        security_requirement = SecurityRequirement(
            security_scheme=dependency, scopes=use_scopes
        )

    sub_dependant = get_dependant(
        path=path,
        call=dependency,
        name=name,
        security_scopes=scopes,
        use_cache=depends.use_cache,
    )
    if security_requirement:
        sub_dependant.security_requirements.append(security_requirement)
    sub_dependant.security_scopes = scopes
    return sub_dependant
 | 
						|
 | 
						|
 | 
						|
# Type of ``Dependant.cache_key``: the dependency callable (or None) paired
# with a tuple of strings -- presumably the security scopes; confirm against
# the ``Dependant`` model.
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
 | 
						|
 | 
						|
 | 
						|
def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[List[CacheKey]] = None,
) -> Dependant:
    """Collapse a dependant tree into a single ``Dependant``.

    The result holds copies of ``dependant``'s own path, query, header,
    cookie and body params and security requirements, followed by those of
    all of its sub-dependencies, gathered recursively.

    Args:
        dependant: Root of the tree to flatten.
        skip_repeats: When true, a sub-dependency whose ``cache_key`` was
            already recorded in ``visited`` is not processed again.
        visited: Cache keys seen so far; appended to in place and shared with
            the recursive calls. A new list is used when ``None``.
    """
    seen = [] if visited is None else visited
    seen.append(dependant.cache_key)

    flat = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        security_schemes=dependant.security_requirements.copy(),
        use_cache=dependant.use_cache,
        path=dependant.path,
    )
    merged_attrs = (
        "path_params",
        "query_params",
        "header_params",
        "cookie_params",
        "body_params",
        "security_requirements",
    )
    for child in dependant.dependencies:
        if skip_repeats and child.cache_key in seen:
            continue
        flat_child = get_flat_dependant(child, skip_repeats=skip_repeats, visited=seen)
        for attr in merged_attrs:
            getattr(flat, attr).extend(getattr(flat_child, attr))
    return flat
 | 
						|
 | 
						|
 | 
						|
def get_flat_params(dependant: Dependant) -> List[ModelField]:
    """Return all non-body params of a dependant tree as one list.

    The tree is flattened with repeated sub-dependencies skipped; the result
    lists path params first, then query, header and cookie params.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    return [
        *flat.path_params,
        *flat.query_params,
        *flat.header_params,
        *flat.cookie_params,
    ]
 | 
						|
 | 
						|
 | 
						|
def is_scalar_field(field: ModelField) -> bool:
    """Tell whether a field holds a single, non-structured value.

    A field is scalar when its shape is ``SHAPE_SINGLETON``, its type is not a
    ``BaseModel``, list/set/tuple/dict or dataclass, its ``field_info`` is not
    a ``params.Body``, and all of its sub-fields (if any) are scalar too.
    """
    if field.shape != SHAPE_SINGLETON:
        return False
    field_type = field.type_
    if lenient_issubclass(field_type, BaseModel):
        return False
    if lenient_issubclass(field_type, sequence_types + (dict,)):
        return False
    if dataclasses.is_dataclass(field_type):
        return False
    if isinstance(field.field_info, params.Body):
        return False
    sub_fields = field.sub_fields
    if sub_fields:
        return all(is_scalar_field(sub_field) for sub_field in sub_fields)
    return True
 | 
						|
 | 
						|
 | 
						|
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Tell whether a field is a sequence whose items are all scalar.

    True when the field has one of the ``sequence_shapes``, its type is not a
    ``BaseModel`` and every sub-field is scalar, or -- failing the shape and
    model test -- when its type is itself a list, set or tuple type.
    """
    has_sequence_shape = field.shape in sequence_shapes
    if has_sequence_shape and not lenient_issubclass(field.type_, BaseModel):
        sub_fields = field.sub_fields
        return sub_fields is None or all(
            is_scalar_field(sub_field) for sub_field in sub_fields
        )
    return bool(lenient_issubclass(field.type_, sequence_types))
 | 
						|
 | 
						|
 | 
						|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return the signature of ``call`` with string annotations resolved.

    Each parameter is rebuilt with the same name, kind and default; its
    annotation comes from ``get_typed_annotation``, evaluated against
    ``call.__globals__`` (an empty dict when that attribute is missing).
    """
    original = inspect.signature(call)
    namespace = getattr(call, "__globals__", {})
    resolved = []
    for param in original.parameters.values():
        resolved.append(
            inspect.Parameter(
                name=param.name,
                kind=param.kind,
                default=param.default,
                annotation=get_typed_annotation(param, namespace),
            )
        )
    return inspect.Signature(resolved)
 | 
						|
 | 
						|
 | 
						|
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
    """Return the annotation of ``param``, resolving it when it is a string.

    A string annotation is wrapped in a ``ForwardRef`` and evaluated with
    ``globalns`` as both the global and the local namespace; any other
    annotation is returned unchanged.
    """
    raw = param.annotation
    if not isinstance(raw, str):
        return raw
    return evaluate_forwardref(ForwardRef(raw), globalns, globalns)
 | 
						|
 | 
						|
 | 
						|
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Analyse the signature of ``call`` and describe it as a ``Dependant``.

    Each parameter is classified in turn:

    * a ``params.Depends`` default becomes a sub-dependency;
    * an annotation handled by ``add_non_field_param_to_dependency`` is
      recorded by name on the dependant;
    * a parameter named in ``path`` becomes a path param (it must be scalar);
    * a scalar field, or a scalar sequence declared with ``Query`` or
      ``Header``, is added to the params bucket given by its ``in_``;
    * anything else must be a ``params.Body`` field and becomes a body param.
    """
    path_param_names = get_path_param_names(path)
    signature_params = get_typed_signature(call).parameters
    dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)

    for param_name, param in signature_params.items():
        if isinstance(param.default, params.Depends):
            dependant.dependencies.append(
                get_param_sub_dependant(
                    param=param, path=path, security_scopes=security_scopes
                )
            )
            continue
        if add_non_field_param_to_dependency(param=param, dependant=dependant):
            continue

        # First pass: treat the parameter as a query param to learn its shape.
        param_field = get_param_field(
            param=param, default_field_info=params.Query, param_name=param_name
        )
        if param_name in path_param_names:
            assert is_scalar_field(
                field=param_field
            ), "Path params must be of one of the supported types"
            # The default is kept only when it was declared with Path(...).
            param_field = get_param_field(
                param=param,
                param_name=param_name,
                default_field_info=params.Path,
                force_type=params.ParamTypes.path,
                ignore_default=not isinstance(param.default, params.Path),
            )
            add_param_to_fields(field=param_field, dependant=dependant)
        elif is_scalar_field(field=param_field) or (
            isinstance(param.default, (params.Query, params.Header))
            and is_scalar_sequence_field(param_field)
        ):
            add_param_to_fields(field=param_field, dependant=dependant)
        else:
            assert isinstance(
                param_field.field_info, params.Body
            ), f"Param: {param_field.name} can only be a request body, using Body(...)"
            dependant.body_params.append(param_field)
    return dependant
 | 
						|
 | 
						|
 | 
						|
def add_non_field_param_to_dependency(
    *, param: inspect.Parameter, dependant: Dependant
) -> Optional[bool]:
    """Record a parameter that receives a framework object instead of data.

    If the annotation is a subclass of one of the special types below, the
    parameter's name is stored on the matching ``dependant`` attribute and
    ``True`` is returned. Otherwise nothing is stored and ``None`` is
    returned.
    """
    # First match wins, so the order of the original checks is preserved.
    # NOTE(review): presumably HTTPConnection must come after Request and
    # WebSocket because those derive from it in Starlette -- confirm.
    special_params = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (HTTPConnection, "http_connection_param_name"),
        (Response, "response_param_name"),
        (BackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )
    for special_type, attr_name in special_params:
        if lenient_issubclass(param.annotation, special_type):
            setattr(dependant, attr_name, param.name)
            return True
    return None
 | 
						|
 | 
						|
 | 
						|
def get_param_field(
    *,
    param: inspect.Parameter,
    param_name: str,
    default_field_info: Type[params.Param] = params.Param,
    force_type: Optional[params.ParamTypes] = None,
    ignore_default: bool = False,
) -> ModelField:
    """Create the ``ModelField`` that validates one function parameter.

    Args:
        param: The parameter to describe.
        param_name: Name passed to ``get_annotation_from_field_info``.
        default_field_info: Class used to build the field info when the
            parameter's default is not already a ``FieldInfo``; it also
            supplies ``in_`` for a ``params.Param`` default that has none.
        force_type: When truthy, overrides ``in_`` on a ``FieldInfo`` default.
        ignore_default: When exactly ``True``... see note below; only the
            value ``False`` lets the parameter's default be used.

    Returns:
        The field. Its ``field_info`` is replaced by ``params.Body`` when no
        explicit ``FieldInfo`` was given and the field is not scalar, and by
        ``params.File`` when none was given and the type is an ``UploadFile``.
    """
    default_value = Required
    # The declared default is used only when one exists and it is not ignored
    # (the flag is compared by identity with False, as in the original).
    if not param.default == param.empty and ignore_default is False:
        default_value = param.default

    had_schema = isinstance(default_value, FieldInfo)
    if had_schema:
        field_info = default_value
        default_value = field_info.default
        needs_location = (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        )
        if needs_location:
            field_info.in_ = default_field_info.in_
        if force_type:
            field_info.in_ = force_type  # type: ignore
    else:
        field_info = default_field_info(default_value)

    required = default_value == Required

    annotation: Any = Any if param.annotation == param.empty else param.annotation
    annotation = get_annotation_from_field_info(annotation, field_info, param_name)

    # An explicit alias wins; otherwise underscores become dashes when the
    # field info asks for it, and the parameter name is used as a last resort.
    explicit_alias = field_info.alias
    if explicit_alias:
        alias = explicit_alias
    elif getattr(field_info, "convert_underscores", None):
        alias = param.name.replace("_", "-")
    else:
        alias = param.name

    field = create_response_field(
        name=param.name,
        type_=annotation,
        default=None if required else default_value,
        alias=alias,
        required=required,
        field_info=field_info,
    )
    field.required = required

    if not had_schema:
        if not is_scalar_field(field=field):
            field.field_info = params.Body(field_info.default)
        if lenient_issubclass(field.type_, UploadFile):
            field.field_info = params.File(field_info.default)

    return field
 | 
						|
 | 
						|
 | 
						|
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append ``field`` to the params list of ``dependant`` matching its ``in_``.

    Path, query and header params go to their respective lists; anything else
    is asserted to be a cookie param and appended to ``cookie_params``.
    """
    field_info = cast(params.Param, field.field_info)
    buckets = (
        (params.ParamTypes.path, dependant.path_params),
        (params.ParamTypes.query, dependant.query_params),
        (params.ParamTypes.header, dependant.header_params),
    )
    for param_type, bucket in buckets:
        if field_info.in_ == param_type:
            bucket.append(field)
            return
    assert (
        field_info.in_ == params.ParamTypes.cookie
    ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
    dependant.cookie_params.append(field)
 | 
						|
 | 
						|
 | 
						|
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Tell whether calling ``call`` produces a coroutine.

    Routines are tested directly. Classes are never coroutine callables. For
    any other object the test is applied to its ``__call__`` attribute.
    """
    target: Any = call
    if not inspect.isroutine(call):
        if inspect.isclass(call):
            return False
        target = getattr(call, "__call__", None)
    return inspect.iscoroutinefunction(target)
 | 
						|
 | 
						|
 | 
						|
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Tell whether ``call``, or its ``__call__``, is an async generator function."""
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)
    )
 | 
						|
 | 
						|
 | 
						|
def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Tell whether ``call``, or its ``__call__``, is a generator function."""
    return inspect.isgeneratorfunction(call) or inspect.isgeneratorfunction(
        getattr(call, "__call__", None)
    )
 | 
						|
 | 
						|
 | 
						|
async def solve_generator(
    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
    """Enter a generator-based dependency on ``stack`` and return its value.

    A sync generator function is wrapped with ``contextmanager`` and
    ``contextmanager_in_threadpool``; an async generator function is wrapped
    with ``asynccontextmanager``. The resulting context manager is entered on
    ``stack``, which therefore controls when it is exited.

    Args:
        call: The generator or async generator function to run.
        stack: Exit stack on which the context manager is entered.
        sub_values: Keyword arguments passed to ``call``.

    Returns:
        Whatever entering the context manager yields.

    Raises:
        TypeError: if ``call`` is neither kind of generator function.
    """
    if is_gen_callable(call):
        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(**sub_values)
    else:
        # Previously this fell through with ``cm`` unbound and surfaced as an
        # opaque UnboundLocalError; fail with an explicit message instead.
        raise TypeError(
            f"solve_generator() requires a generator or async generator "
            f"function, got {call!r}"
        )
    return await stack.enter_async_context(cm)
 | 
						|
 | 
						|
 | 
						|
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
) -> Tuple[
    Dict[str, Any],
    List[ErrorWrapper],
    Optional[BackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Resolve every argument needed to call ``dependant.call``.

    Sub-dependencies are solved recursively first, then path, query, header,
    cookie and body params are extracted from the request and validated, and
    finally the special framework objects are injected by parameter name.

    Args:
        request: The incoming request or websocket.
        dependant: The dependant whose arguments are being resolved.
        body: Already-parsed request body, if any.
        background_tasks: Shared background tasks; created on demand when a
            parameter asks for them and none exists yet.
        response: Shared response object; a blank one is created when falsy.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            mapping, when non-empty, can replace dependency callables.
        dependency_cache: Results of already-solved dependencies, keyed by
            ``cache_key``; a new dict is used when falsy.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    if not response:
        response = Response(
            content=None,
            status_code=None,  # type: ignore
            headers=None,  # type: ignore # in Starlette
            media_type=None,  # type: ignore # in Starlette
            background=None,  # type: ignore # in Starlette
        )
    if not dependency_cache:
        dependency_cache = {}

    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # These two self-assignments only narrow the types for the checker.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Swap in the override (if one is registered for this callable)
            # and re-analyse it, keeping the original name and scopes.
            overrides = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            )
            original_call = sub_dependant.call
            call = overrides.get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )
            use_sub_dependant.security_scopes = sub_dependant.security_scopes

        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            errors.extend(sub_errors)
            continue

        cache_key = sub_dependant.cache_key
        if sub_dependant.use_cache and cache_key in dependency_cache:
            solved = dependency_cache[cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            # Generator dependencies are entered on the exit stack stored in
            # the request scope.
            stack = request.scope.get("fastapi_astack")
            assert isinstance(stack, AsyncExitStack)
            solved = await solve_generator(
                call=call, stack=stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            solved = await run_in_threadpool(call, **sub_values)

        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if cache_key not in dependency_cache:
            dependency_cache[cache_key] = solved

    param_sources = (
        (dependant.path_params, request.path_params),
        (dependant.query_params, request.query_params),
        (dependant.header_params, request.headers),
        (dependant.cookie_params, request.cookies),
    )
    for declared_params, received_params in param_sources:
        param_values, param_errors = request_params_to_args(
            declared_params, received_params
        )
        values.update(param_values)
        errors.extend(param_errors)

    if dependant.body_params:
        body_values, body_errors = await request_body_to_args(
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)

    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
 | 
						|
 | 
						|
 | 
						|
def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate declared params from received request data.

    For each field the raw value is looked up by alias. Scalar sequence
    fields read from ``QueryParams`` or ``Headers`` collect every value with
    ``getlist`` and fall back to the field default when none was sent.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        validation errors. A missing required field yields a ``MissingError``;
        a missing optional field yields a deep copy of its default.
    """
    values = {}
    errors = []
    for field in required_params:
        multi_valued = is_scalar_sequence_field(field) and isinstance(
            received_params, (QueryParams, Headers)
        )
        if multi_valued:
            raw = received_params.getlist(field.alias) or field.default
        else:
            raw = received_params.get(field.alias)

        field_info = field.field_info
        assert isinstance(
            field_info, params.Param
        ), "Params must be subclasses of Param"

        if raw is not None:
            validated, problems = field.validate(
                raw, values, loc=(field_info.in_.value, field.alias)
            )
            if isinstance(problems, ErrorWrapper):
                errors.append(problems)
            elif isinstance(problems, list):
                errors.extend(problems)
            else:
                values[field.name] = validated
        elif field.required:
            errors.append(
                ErrorWrapper(MissingError(), loc=(field_info.in_.value, field.alias))
            )
        else:
            values[field.name] = deepcopy(field.default)
    return values, errors
 | 
						|
 | 
						|
 | 
						|
async def _read_upload_files(files: Any) -> List[Any]:
    """Read every upload in ``files`` concurrently, keeping the input order."""
    contents: List[Any] = [None] * len(files)

    async def read_into(index: int, read: Callable[[], Coroutine[Any, Any, Any]]) -> None:
        # Store by index so the result order does not depend on which read
        # happens to finish first.
        contents[index] = await read()

    async with anyio.create_task_group() as tg:
        for index, upload in enumerate(files):
            tg.start_soon(read_into, index, upload.read)
    return contents


async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate the declared body params from a received body.

    When there is exactly one body param and its field info does not request
    ``embed``, the whole body is treated as that param's value and errors are
    located at ``("body",)``; otherwise each param is looked up in the body by
    alias and located at ``("body", alias)``.

    Args:
        required_params: The declared body fields.
        received_body: The parsed body (a dict or ``FormData``), or ``None``.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        validation errors. A missing (or, for form fields, empty) required
        value yields a missing-field error; a missing optional value yields a
        deep copy of the field default.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    embed = getattr(first_field.field_info, "embed", None)
    field_alias_omitted = len(required_params) == 1 and not embed
    if field_alias_omitted:
        received_body = {first_field.alias: received_body}

    for field in required_params:
        # Use THIS field's info. Testing the first field's info for every
        # field meant, for example, that a ``bytes`` File declared after a
        # plain Form field was never read from its UploadFile.
        field_info = field.field_info
        loc: Tuple[str, ...]
        if field_alias_omitted:
            loc = ("body",)
        else:
            loc = ("body", field.alias)

        value: Optional[Any] = None
        if received_body is not None:
            if (
                field.shape in sequence_shapes or field.type_ in sequence_types
            ) and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body is not a mapping, so the field cannot be there.
                    errors.append(get_missing_field_error(loc))
                    continue

        is_form = isinstance(field_info, params.Form)
        if (
            value is None
            or (is_form and value == "")
            or (is_form and field.shape in sequence_shapes and len(value) == 0)
        ):
            if field.required:
                errors.append(get_missing_field_error(loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue

        # ``bytes`` file fields receive the file contents, not the upload.
        if isinstance(field_info, params.File) and lenient_issubclass(
            field.type_, bytes
        ):
            if isinstance(value, UploadFile):
                value = await value.read()
            elif field.shape in sequence_shapes and isinstance(
                value, sequence_types
            ):
                contents = await _read_upload_files(value)
                value = sequence_shape_to_type[field.shape](contents)

        v_, errors_ = field.validate(value, values, loc=loc)

        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors
 | 
						|
 | 
						|
 | 
						|
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper:
    """Return an ``ErrorWrapper`` holding a new ``MissingError`` at ``loc``.

    ``loc`` is passed through unchanged as the wrapper's ``loc`` argument.
    """
    return ErrorWrapper(MissingError(), loc=loc)
 | 
						|
 | 
						|
 | 
						|
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Build the single field that stands for the request body of ``dependant``.

    The body parameters are read from the flattened dependant.

    * With no body parameters, ``None`` is returned.
    * With exactly one distinct body parameter name and no truthy ``embed``
      attribute on the first parameter's field info, that first parameter is
      passed to ``check_file_field`` and returned as-is.
    * Otherwise every body parameter is flagged ``embed=True`` and registered
      on a model named ``"Body_" + name``; a new field called ``"body"`` whose
      type is that model is created, passed to ``check_file_field`` and
      returned.
    """
    body_params = get_flat_dependant(dependant).body_params
    if not body_params:
        return None

    head_param = body_params[0]
    embed_requested = getattr(head_param.field_info, "embed", None)
    distinct_names = {body_param.name for body_param in body_params}
    if not embed_requested and len(distinct_names) == 1:
        check_file_field(head_param)
        return head_param

    # Once the parameters are combined into one model, each of them has to be
    # embedded -- including a parameter that was the only body field of a
    # sub-dependency and would not have been embedded on its own.
    for body_param in body_params:
        setattr(body_param.field_info, "embed", True)

    BodyModel: Type[BaseModel] = create_model("Body_" + name)
    for body_param in body_params:
        BodyModel.__fields__[body_param.name] = body_param

    # The combined body is required as soon as one of its parts is.
    is_required = any(body_param.required for body_param in body_params)

    field_infos = [body_param.field_info for body_param in body_params]
    info_kwargs: Dict[str, Any] = {"default": None}

    # Pick the field-info class: File is checked first, then Form, and plain
    # Body is the fallback.
    if any(isinstance(info, params.File) for info in field_infos):
        info_class: Type[params.Body] = params.File
    elif any(isinstance(info, params.Form) for info in field_infos):
        info_class = params.Form
    else:
        info_class = params.Body

        # Forward a media type only when all Body parameters agree on it.
        media_types = [
            info.media_type for info in field_infos if isinstance(info, params.Body)
        ]
        if len(set(media_types)) == 1:
            info_kwargs["media_type"] = media_types[0]

    combined_field = create_response_field(
        name="body",
        type_=BodyModel,
        required=is_required,
        alias="body",
        field_info=info_class(**info_kwargs),
    )
    check_file_field(combined_field)
    return combined_field
 |