You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					752 lines
				
				24 KiB
			
		
		
			
		
	
	
					752 lines
				
				24 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import warnings
							 | 
						||
| 
								 | 
							
								import weakref
							 | 
						||
| 
								 | 
							
								from collections import OrderedDict, defaultdict, deque
							 | 
						||
| 
								 | 
							
								from copy import deepcopy
							 | 
						||
| 
								 | 
							
								from itertools import islice, zip_longest
							 | 
						||
| 
								 | 
							
								from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
							 | 
						||
| 
								 | 
							
								from typing import (
							 | 
						||
| 
								 | 
							
								    TYPE_CHECKING,
							 | 
						||
| 
								 | 
							
								    AbstractSet,
							 | 
						||
| 
								 | 
							
								    Any,
							 | 
						||
| 
								 | 
							
								    Callable,
							 | 
						||
| 
								 | 
							
								    Collection,
							 | 
						||
| 
								 | 
							
								    Dict,
							 | 
						||
| 
								 | 
							
								    Generator,
							 | 
						||
| 
								 | 
							
								    Iterable,
							 | 
						||
| 
								 | 
							
								    Iterator,
							 | 
						||
| 
								 | 
							
								    List,
							 | 
						||
| 
								 | 
							
								    Mapping,
							 | 
						||
| 
								 | 
							
								    Optional,
							 | 
						||
| 
								 | 
							
								    Set,
							 | 
						||
| 
								 | 
							
								    Tuple,
							 | 
						||
| 
								 | 
							
								    Type,
							 | 
						||
| 
								 | 
							
								    TypeVar,
							 | 
						||
| 
								 | 
							
								    Union,
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from typing_extensions import Annotated
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from .errors import ConfigError
							 | 
						||
| 
								 | 
							
								from .typing import (
							 | 
						||
| 
								 | 
							
								    NoneType,
							 | 
						||
| 
								 | 
							
								    WithArgsTypes,
							 | 
						||
| 
								 | 
							
								    all_literal_values,
							 | 
						||
| 
								 | 
							
								    display_as_type,
							 | 
						||
| 
								 | 
							
								    get_args,
							 | 
						||
| 
								 | 
							
								    get_origin,
							 | 
						||
| 
								 | 
							
								    is_literal_type,
							 | 
						||
| 
								 | 
							
								    is_union,
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								from .version import version_info
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								if TYPE_CHECKING:
							 | 
						||
| 
								 | 
							
								    from inspect import Signature
							 | 
						||
| 
								 | 
							
								    from pathlib import Path
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    from .config import BaseConfig
							 | 
						||
| 
								 | 
							
								    from .dataclasses import Dataclass
							 | 
						||
| 
								 | 
							
								    from .fields import ModelField
							 | 
						||
| 
								 | 
							
								    from .main import BaseModel
							 | 
						||
| 
								 | 
							
								    from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								__all__ = (
							 | 
						||
| 
								 | 
							
								    'import_string',
							 | 
						||
| 
								 | 
							
								    'sequence_like',
							 | 
						||
| 
								 | 
							
								    'validate_field_name',
							 | 
						||
| 
								 | 
							
								    'lenient_isinstance',
							 | 
						||
| 
								 | 
							
								    'lenient_issubclass',
							 | 
						||
| 
								 | 
							
								    'in_ipython',
							 | 
						||
| 
								 | 
							
								    'deep_update',
							 | 
						||
| 
								 | 
							
								    'update_not_none',
							 | 
						||
| 
								 | 
							
								    'almost_equal_floats',
							 | 
						||
| 
								 | 
							
								    'get_model',
							 | 
						||
| 
								 | 
							
								    'to_camel',
							 | 
						||
| 
								 | 
							
								    'is_valid_field',
							 | 
						||
| 
								 | 
							
								    'smart_deepcopy',
							 | 
						||
| 
								 | 
							
								    'PyObjectStr',
							 | 
						||
| 
								 | 
							
								    'Representation',
							 | 
						||
| 
								 | 
							
								    'GetterDict',
							 | 
						||
| 
								 | 
							
								    'ValueItems',
							 | 
						||
| 
								 | 
							
								    'version_info',  # required here to match behaviour in v1.3
							 | 
						||
| 
								 | 
							
								    'ClassAttribute',
							 | 
						||
| 
								 | 
							
								    'path_type',
							 | 
						||
| 
								 | 
							
								    'ROOT_KEY',
							 | 
						||
| 
								 | 
							
								    'get_unique_discriminator_alias',
							 | 
						||
| 
								 | 
							
								    'get_discriminator_alias_and_values',
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								ROOT_KEY = '__root__'
							 | 
						||
| 
								 | 
							
								# these are types that are returned unchanged by deepcopy
							 | 
						||
| 
								 | 
							
								IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
							 | 
						||
| 
								 | 
							
								    int,
							 | 
						||
| 
								 | 
							
								    float,
							 | 
						||
| 
								 | 
							
								    complex,
							 | 
						||
| 
								 | 
							
								    str,
							 | 
						||
| 
								 | 
							
								    bool,
							 | 
						||
| 
								 | 
							
								    bytes,
							 | 
						||
| 
								 | 
							
								    type,
							 | 
						||
| 
								 | 
							
								    NoneType,
							 | 
						||
| 
								 | 
							
								    FunctionType,
							 | 
						||
| 
								 | 
							
								    BuiltinFunctionType,
							 | 
						||
| 
								 | 
							
								    LambdaType,
							 | 
						||
| 
								 | 
							
								    weakref.ref,
							 | 
						||
| 
								 | 
							
								    CodeType,
							 | 
						||
| 
								 | 
							
								    # note: including ModuleType will differ from behaviour of deepcopy by not producing error.
							 | 
						||
| 
								 | 
							
								    # It might be not a good idea in general, but considering that this function used only internally
							 | 
						||
| 
								 | 
							
								    # against default values of fields, this will allow to actually have a field with module as default value
							 | 
						||
| 
								 | 
							
								    ModuleType,
							 | 
						||
| 
								 | 
							
								    NotImplemented.__class__,
							 | 
						||
| 
								 | 
							
								    Ellipsis.__class__,
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# these are types that if empty, might be copied with simple copy() instead of deepcopy()
							 | 
						||
| 
								 | 
							
								BUILTIN_COLLECTIONS: Set[Type[Any]] = {
							 | 
						||
| 
								 | 
							
								    list,
							 | 
						||
| 
								 | 
							
								    set,
							 | 
						||
| 
								 | 
							
								    tuple,
							 | 
						||
| 
								 | 
							
								    frozenset,
							 | 
						||
| 
								 | 
							
								    dict,
							 | 
						||
| 
								 | 
							
								    OrderedDict,
							 | 
						||
| 
								 | 
							
								    defaultdict,
							 | 
						||
| 
								 | 
							
								    deque,
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def import_string(dotted_path: str) -> Any:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
							 | 
						||
| 
								 | 
							
								    last name in the path. Raise ImportError if the import fails.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    from importlib import import_module
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
							 | 
						||
| 
								 | 
							
								    except ValueError as e:
							 | 
						||
| 
								 | 
							
								        raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    module = import_module(module_path)
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        return getattr(module, class_name)
							 | 
						||
| 
								 | 
							
								    except AttributeError as e:
							 | 
						||
| 
								 | 
							
								        raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def truncate(v: Union[str], *, max_len: int = 80) -> str:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
							 | 
						||
| 
								 | 
							
								    if isinstance(v, str) and len(v) > (max_len - 2):
							 | 
						||
| 
								 | 
							
								        # -3 so quote + string + … + quote has correct length
							 | 
						||
| 
								 | 
							
								        return (v[: (max_len - 3)] + '…').__repr__()
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        v = v.__repr__()
							 | 
						||
| 
								 | 
							
								    except TypeError:
							 | 
						||
| 
								 | 
							
								        v = v.__class__.__repr__(v)  # in case v is a type
							 | 
						||
| 
								 | 
							
								    if len(v) > max_len:
							 | 
						||
| 
								 | 
							
								        v = v[: max_len - 1] + '…'
							 | 
						||
| 
								 | 
							
								    return v
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def sequence_like(v: Any) -> bool:
							 | 
						||
| 
								 | 
							
								    return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Ensure that the field's name does not shadow an existing attribute of the model.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    for base in bases:
							 | 
						||
| 
								 | 
							
								        if getattr(base, field_name, None):
							 | 
						||
| 
								 | 
							
								            raise NameError(
							 | 
						||
| 
								 | 
							
								                f'Field name "{field_name}" shadows a BaseModel attribute; '
							 | 
						||
| 
								 | 
							
								                f'use a different field name with "alias=\'{field_name}\'".'
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        return isinstance(o, class_or_tuple)  # type: ignore[arg-type]
							 | 
						||
| 
								 | 
							
								    except TypeError:
							 | 
						||
| 
								 | 
							
								        return False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        return isinstance(cls, type) and issubclass(cls, class_or_tuple)  # type: ignore[arg-type]
							 | 
						||
| 
								 | 
							
								    except TypeError:
							 | 
						||
| 
								 | 
							
								        if isinstance(cls, WithArgsTypes):
							 | 
						||
| 
								 | 
							
								            return False
							 | 
						||
| 
								 | 
							
								        raise  # pragma: no cover
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def in_ipython() -> bool:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Check whether we're in an ipython environment, including jupyter notebooks.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        eval('__IPYTHON__')
							 | 
						||
| 
								 | 
							
								    except NameError:
							 | 
						||
| 
								 | 
							
								        return False
							 | 
						||
| 
								 | 
							
								    else:  # pragma: no cover
							 | 
						||
| 
								 | 
							
								        return True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								KeyType = TypeVar('KeyType')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
							 | 
						||
| 
								 | 
							
								    updated_mapping = mapping.copy()
							 | 
						||
| 
								 | 
							
								    for updating_mapping in updating_mappings:
							 | 
						||
| 
								 | 
							
								        for k, v in updating_mapping.items():
							 | 
						||
| 
								 | 
							
								            if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
							 | 
						||
| 
								 | 
							
								                updated_mapping[k] = deep_update(updated_mapping[k], v)
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                updated_mapping[k] = v
							 | 
						||
| 
								 | 
							
								    return updated_mapping
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
							 | 
						||
| 
								 | 
							
								    mapping.update({k: v for k, v in update.items() if v is not None})
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Return True if two floats are almost equal
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    return abs(value_1 - value_2) <= delta
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def generate_model_signature(
							 | 
						||
| 
								 | 
							
								    init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
							 | 
						||
| 
								 | 
							
								) -> 'Signature':
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Generate signature for model based on its fields
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    from inspect import Parameter, Signature, signature
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    from .config import Extra
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    present_params = signature(init).parameters.values()
							 | 
						||
| 
								 | 
							
								    merged_params: Dict[str, Parameter] = {}
							 | 
						||
| 
								 | 
							
								    var_kw = None
							 | 
						||
| 
								 | 
							
								    use_var_kw = False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    for param in islice(present_params, 1, None):  # skip self arg
							 | 
						||
| 
								 | 
							
								        if param.kind is param.VAR_KEYWORD:
							 | 
						||
| 
								 | 
							
								            var_kw = param
							 | 
						||
| 
								 | 
							
								            continue
							 | 
						||
| 
								 | 
							
								        merged_params[param.name] = param
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
							 | 
						||
| 
								 | 
							
								        allow_names = config.allow_population_by_field_name
							 | 
						||
| 
								 | 
							
								        for field_name, field in fields.items():
							 | 
						||
| 
								 | 
							
								            param_name = field.alias
							 | 
						||
| 
								 | 
							
								            if field_name in merged_params or param_name in merged_params:
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								            elif not param_name.isidentifier():
							 | 
						||
| 
								 | 
							
								                if allow_names and field_name.isidentifier():
							 | 
						||
| 
								 | 
							
								                    param_name = field_name
							 | 
						||
| 
								 | 
							
								                else:
							 | 
						||
| 
								 | 
							
								                    use_var_kw = True
							 | 
						||
| 
								 | 
							
								                    continue
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            # TODO: replace annotation with actual expected types once #1055 solved
							 | 
						||
| 
								 | 
							
								            kwargs = {'default': field.default} if not field.required else {}
							 | 
						||
| 
								 | 
							
								            merged_params[param_name] = Parameter(
							 | 
						||
| 
								 | 
							
								                param_name, Parameter.KEYWORD_ONLY, annotation=field.outer_type_, **kwargs
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if config.extra is Extra.allow:
							 | 
						||
| 
								 | 
							
								        use_var_kw = True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if var_kw and use_var_kw:
							 | 
						||
| 
								 | 
							
								        # Make sure the parameter for extra kwargs
							 | 
						||
| 
								 | 
							
								        # does not have the same name as a field
							 | 
						||
| 
								 | 
							
								        default_model_signature = [
							 | 
						||
| 
								 | 
							
								            ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
							 | 
						||
| 
								 | 
							
								            ('data', Parameter.VAR_KEYWORD),
							 | 
						||
| 
								 | 
							
								        ]
							 | 
						||
| 
								 | 
							
								        if [(p.name, p.kind) for p in present_params] == default_model_signature:
							 | 
						||
| 
								 | 
							
								            # if this is the standard model signature, use extra_data as the extra args name
							 | 
						||
| 
								 | 
							
								            var_kw_name = 'extra_data'
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            # else start from var_kw
							 | 
						||
| 
								 | 
							
								            var_kw_name = var_kw.name
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # generate a name that's definitely unique
							 | 
						||
| 
								 | 
							
								        while var_kw_name in fields:
							 | 
						||
| 
								 | 
							
								            var_kw_name += '_'
							 | 
						||
| 
								 | 
							
								        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return Signature(parameters=list(merged_params.values()), return_annotation=None)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
							 | 
						||
| 
								 | 
							
								    from .main import BaseModel
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        model_cls = obj.__pydantic_model__  # type: ignore
							 | 
						||
| 
								 | 
							
								    except AttributeError:
							 | 
						||
| 
								 | 
							
								        model_cls = obj
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if not issubclass(model_cls, BaseModel):
							 | 
						||
| 
								 | 
							
								        raise TypeError('Unsupported type, must be either BaseModel or dataclass')
							 | 
						||
| 
								 | 
							
								    return model_cls
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def to_camel(string: str) -> str:
							 | 
						||
| 
								 | 
							
								    return ''.join(word.capitalize() for word in string.split('_'))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								T = TypeVar('T')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def unique_list(
							 | 
						||
| 
								 | 
							
								    input_list: Union[List[T], Tuple[T, ...]],
							 | 
						||
| 
								 | 
							
								    *,
							 | 
						||
| 
								 | 
							
								    name_factory: Callable[[T], str] = str,
							 | 
						||
| 
								 | 
							
								) -> List[T]:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Make a list unique while maintaining order.
							 | 
						||
| 
								 | 
							
								    We update the list if another one with the same name is set
							 | 
						||
| 
								 | 
							
								    (e.g. root validator overridden in subclass)
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    result: List[T] = []
							 | 
						||
| 
								 | 
							
								    result_names: List[str] = []
							 | 
						||
| 
								 | 
							
								    for v in input_list:
							 | 
						||
| 
								 | 
							
								        v_name = name_factory(v)
							 | 
						||
| 
								 | 
							
								        if v_name not in result_names:
							 | 
						||
| 
								 | 
							
								            result_names.append(v_name)
							 | 
						||
| 
								 | 
							
								            result.append(v)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            result[result_names.index(v_name)] = v
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return result
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class PyObjectStr(str):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    String class where repr doesn't include quotes. Useful with Representation when you want to return a string
							 | 
						||
| 
								 | 
							
								    representation of something that valid (or pseudo-valid) python.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr__(self) -> str:
							 | 
						||
| 
								 | 
							
								        return str(self)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Representation:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
							 | 
						||
| 
								 | 
							
								    of objects.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    __slots__: Tuple[str, ...] = tuple()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr_args__(self) -> 'ReprArgs':
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Can either return:
							 | 
						||
| 
								 | 
							
								        * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
							 | 
						||
| 
								 | 
							
								        * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        attrs = ((s, getattr(self, s)) for s in self.__slots__)
							 | 
						||
| 
								 | 
							
								        return [(a, v) for a, v in attrs if v is not None]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr_name__(self) -> str:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Name of the instance's class, used in __repr__.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return self.__class__.__name__
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr_str__(self, join_str: str) -> str:
							 | 
						||
| 
								 | 
							
								        return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        yield self.__repr_name__() + '('
							 | 
						||
| 
								 | 
							
								        yield 1
							 | 
						||
| 
								 | 
							
								        for name, value in self.__repr_args__():
							 | 
						||
| 
								 | 
							
								            if name is not None:
							 | 
						||
| 
								 | 
							
								                yield name + '='
							 | 
						||
| 
								 | 
							
								            yield fmt(value)
							 | 
						||
| 
								 | 
							
								            yield ','
							 | 
						||
| 
								 | 
							
								            yield 0
							 | 
						||
| 
								 | 
							
								        yield -1
							 | 
						||
| 
								 | 
							
								        yield ')'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __str__(self) -> str:
							 | 
						||
| 
								 | 
							
								        return self.__repr_str__(' ')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr__(self) -> str:
							 | 
						||
| 
								 | 
							
								        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class GetterDict(Representation):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Hack to make object's smell just enough like dicts for validate_model.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    __slots__ = ('_obj',)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, obj: Any):
							 | 
						||
| 
								 | 
							
								        self._obj = obj
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __getitem__(self, key: str) -> Any:
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            return getattr(self._obj, key)
							 | 
						||
| 
								 | 
							
								        except AttributeError as e:
							 | 
						||
| 
								 | 
							
								            raise KeyError(key) from e
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get(self, key: Any, default: Any = None) -> Any:
							 | 
						||
| 
								 | 
							
								        return getattr(self._obj, key, default)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def extra_keys(self) -> Set[Any]:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        We don't want to get any other attributes of obj if the model didn't explicitly ask for them
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return set()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def keys(self) -> List[Any]:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python
							 | 
						||
| 
								 | 
							
								        dictionaries.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return list(self)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def values(self) -> List[Any]:
							 | 
						||
| 
								 | 
							
								        return [self[k] for k in self]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def items(self) -> Iterator[Tuple[str, Any]]:
							 | 
						||
| 
								 | 
							
								        for k in self:
							 | 
						||
| 
								 | 
							
								            yield k, self.get(k)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __iter__(self) -> Iterator[str]:
							 | 
						||
| 
								 | 
							
								        for name in dir(self._obj):
							 | 
						||
| 
								 | 
							
								            if not name.startswith('_'):
							 | 
						||
| 
								 | 
							
								                yield name
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __len__(self) -> int:
							 | 
						||
| 
								 | 
							
								        return sum(1 for _ in self)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __contains__(self, item: Any) -> bool:
							 | 
						||
| 
								 | 
							
								        return item in self.keys()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __eq__(self, other: Any) -> bool:
							 | 
						||
| 
								 | 
							
								        return dict(self) == dict(other.items())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr_args__(self) -> 'ReprArgs':
							 | 
						||
| 
								 | 
							
								        return [(None, dict(self))]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr_name__(self) -> str:
							 | 
						||
| 
								 | 
							
								        return f'GetterDict[{display_as_type(self._obj)}]'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class ValueItems(Representation):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Class for more convenient calculation of excluded or included fields on values.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    __slots__ = ('_items', '_type')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
							 | 
						||
| 
								 | 
							
								        items = self._coerce_items(items)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if isinstance(value, (list, tuple)):
							 | 
						||
| 
								 | 
							
								            items = self._normalize_indexes(items, len(value))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self._items: 'MappingIntStrAny' = items
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def is_excluded(self, item: Any) -> bool:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Check if item is fully excluded.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        :param item: key or index of a value
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return self.is_true(self._items.get(item))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def is_included(self, item: Any) -> bool:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Check if value is contained in self._items
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        :param item: key or index of value
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return item in self._items
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param e: key or index of element on value
							 | 
						||
| 
								 | 
							
								        :return: raw values for element if self._items is dict and contain needed element
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        item = self._items.get(e)
							 | 
						||
| 
								 | 
							
								        return item if not self.is_true(item) else None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param items: dict or set of indexes which will be normalized
							 | 
						||
| 
								 | 
							
								        :param v_length: length of sequence indexes of which will be
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
							 | 
						||
| 
								 | 
							
								        {0: True, 2: True, 3: True}
							 | 
						||
| 
								 | 
							
								        >>> self._normalize_indexes({'__all__': True}, 4)
							 | 
						||
| 
								 | 
							
								        {0: True, 1: True, 2: True, 3: True}
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        normalized_items: 'DictIntStrAny' = {}
							 | 
						||
| 
								 | 
							
								        all_items = None
							 | 
						||
| 
								 | 
							
								        for i, v in items.items():
							 | 
						||
| 
								 | 
							
								            if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
							 | 
						||
| 
								 | 
							
								                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
							 | 
						||
| 
								 | 
							
								            if i == '__all__':
							 | 
						||
| 
								 | 
							
								                all_items = self._coerce_value(v)
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								            if not isinstance(i, int):
							 | 
						||
| 
								 | 
							
								                raise TypeError(
							 | 
						||
| 
								 | 
							
								                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
							 | 
						||
| 
								 | 
							
								                    'expected integer keys or keyword "__all__"'
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            normalized_i = v_length + i if i < 0 else i
							 | 
						||
| 
								 | 
							
								            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not all_items:
							 | 
						||
| 
								 | 
							
								            return normalized_items
							 | 
						||
| 
								 | 
							
								        if self.is_true(all_items):
							 | 
						||
| 
								 | 
							
								            for i in range(v_length):
							 | 
						||
| 
								 | 
							
								                normalized_items.setdefault(i, ...)
							 | 
						||
| 
								 | 
							
								            return normalized_items
							 | 
						||
| 
								 | 
							
								        for i in range(v_length):
							 | 
						||
| 
								 | 
							
								            normalized_item = normalized_items.setdefault(i, {})
							 | 
						||
| 
								 | 
							
								            if not self.is_true(normalized_item):
							 | 
						||
| 
								 | 
							
								                normalized_items[i] = self.merge(all_items, normalized_item)
							 | 
						||
| 
								 | 
							
								        return normalized_items
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @classmethod
							 | 
						||
| 
								 | 
							
								    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Merge a ``base`` item with an ``override`` item.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Both ``base`` and ``override`` are converted to dictionaries if possible.
							 | 
						||
| 
								 | 
							
								        Sets are converted to dictionaries with the sets entries as keys and
							 | 
						||
| 
								 | 
							
								        Ellipsis as values.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Each key-value pair existing in ``base`` is merged with ``override``,
							 | 
						||
| 
								 | 
							
								        while the rest of the key-value pairs are updated recursively with this function.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Merging takes place based on the "union" of keys if ``intersect`` is
							 | 
						||
| 
								 | 
							
								        set to ``False`` (default) and on the intersection of keys if
							 | 
						||
| 
								 | 
							
								        ``intersect`` is set to ``True``.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        override = cls._coerce_value(override)
							 | 
						||
| 
								 | 
							
								        base = cls._coerce_value(base)
							 | 
						||
| 
								 | 
							
								        if override is None:
							 | 
						||
| 
								 | 
							
								            return base
							 | 
						||
| 
								 | 
							
								        if cls.is_true(base) or base is None:
							 | 
						||
| 
								 | 
							
								            return override
							 | 
						||
| 
								 | 
							
								        if cls.is_true(override):
							 | 
						||
| 
								 | 
							
								            return base if intersect else override
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # intersection or union of keys while preserving ordering:
							 | 
						||
| 
								 | 
							
								        if intersect:
							 | 
						||
| 
								 | 
							
								            merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            merge_keys = list(base) + [k for k in override if k not in base]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        merged: 'DictIntStrAny' = {}
							 | 
						||
| 
								 | 
							
								        for k in merge_keys:
							 | 
						||
| 
								 | 
							
								            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
							 | 
						||
| 
								 | 
							
								            if merged_item is not None:
							 | 
						||
| 
								 | 
							
								                merged[k] = merged_item
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return merged
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @staticmethod
							 | 
						||
| 
								 | 
							
								    def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
							 | 
						||
| 
								 | 
							
								        if isinstance(items, Mapping):
							 | 
						||
| 
								 | 
							
								            pass
							 | 
						||
| 
								 | 
							
								        elif isinstance(items, AbstractSet):
							 | 
						||
| 
								 | 
							
								            items = dict.fromkeys(items, ...)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            class_name = getattr(items, '__class__', '???')
							 | 
						||
| 
								 | 
							
								            raise TypeError(f'Unexpected type of exclude value {class_name}')
							 | 
						||
| 
								 | 
							
								        return items
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @classmethod
							 | 
						||
| 
								 | 
							
								    def _coerce_value(cls, value: Any) -> Any:
							 | 
						||
| 
								 | 
							
								        if value is None or cls.is_true(value):
							 | 
						||
| 
								 | 
							
								            return value
							 | 
						||
| 
								 | 
							
								        return cls._coerce_items(value)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @staticmethod
							 | 
						||
| 
								 | 
							
								    def is_true(v: Any) -> bool:
							 | 
						||
| 
								 | 
							
								        return v is True or v is ...
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr_args__(self) -> 'ReprArgs':
							 | 
						||
| 
								 | 
							
								        return [(None, self._items)]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class ClassAttribute:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Hide class attribute from its instances
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    __slots__ = (
							 | 
						||
| 
								 | 
							
								        'name',
							 | 
						||
| 
								 | 
							
								        'value',
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, name: str, value: Any) -> None:
							 | 
						||
| 
								 | 
							
								        self.name = name
							 | 
						||
| 
								 | 
							
								        self.value = value
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __get__(self, instance: Any, owner: Type[Any]) -> None:
							 | 
						||
| 
								 | 
							
								        if instance is None:
							 | 
						||
| 
								 | 
							
								            return self.value
							 | 
						||
| 
								 | 
							
								        raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# ``pathlib.Path`` predicate method name -> human readable label.
# ``path_type`` tries these in insertion order and reports the first predicate that holds,
# so the ordering below is significant.
path_types = dict(
    is_dir='directory',
    is_file='file',
    is_mount='mount point',
    is_symlink='symlink',
    is_block_device='block device',
    is_char_device='char device',
    is_fifo='FIFO',
    is_socket='socket',
)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def path_type(p: 'Path') -> str:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Find out what sort of thing a path is.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    assert p.exists(), 'path does not exist'
							 | 
						||
| 
								 | 
							
								    for method, name in path_types.items():
							 | 
						||
| 
								 | 
							
								        if getattr(p, method)():
							 | 
						||
| 
								 | 
							
								            return name
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return 'unknown'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								Obj = TypeVar('Obj')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def smart_deepcopy(obj: Obj) -> Obj:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Return type as is for immutable built-in types
							 | 
						||
| 
								 | 
							
								    Use obj.copy() for built-in empty collections
							 | 
						||
| 
								 | 
							
								    Use copy.deepcopy() for non-empty collections and unknown objects
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    obj_type = obj.__class__
							 | 
						||
| 
								 | 
							
								    if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
							 | 
						||
| 
								 | 
							
								        return obj  # fastest case: obj is immutable and not collection therefore will not be copied anyway
							 | 
						||
| 
								 | 
							
								    elif not obj and obj_type in BUILTIN_COLLECTIONS:
							 | 
						||
| 
								 | 
							
								        # faster way for empty collections, no need to copy its members
							 | 
						||
| 
								 | 
							
								        return obj if obj_type is tuple else obj.copy()  # type: ignore  # tuple doesn't have copy method
							 | 
						||
| 
								 | 
							
								    return deepcopy(obj)  # slowest way when we actually might need a deepcopy
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def is_valid_field(name: str) -> bool:
							 | 
						||
| 
								 | 
							
								    if not name.startswith('_'):
							 | 
						||
| 
								 | 
							
								        return True
							 | 
						||
| 
								 | 
							
								    return ROOT_KEY == name
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def is_valid_private_name(name: str) -> bool:
							 | 
						||
| 
								 | 
							
								    return not is_valid_field(name) and name not in {
							 | 
						||
| 
								 | 
							
								        '__annotations__',
							 | 
						||
| 
								 | 
							
								        '__classcell__',
							 | 
						||
| 
								 | 
							
								        '__doc__',
							 | 
						||
| 
								 | 
							
								        '__module__',
							 | 
						||
| 
								 | 
							
								        '__orig_bases__',
							 | 
						||
| 
								 | 
							
								        '__qualname__',
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Private sentinel used as ``zip_longest`` fill value: it is never identical to a caller's item,
# so iterables of different lengths compare as not identical.
_EMPTY = object()


def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """
    Check that the items of `left` are the same objects as those in `right`.

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    paired = zip_longest(left, right, fillvalue=_EMPTY)
    # ``all`` stops at the first non-identical pair, consuming the inputs no further than needed
    return all(lhs is rhs for lhs, rhs in paired)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
							 | 
						||
| 
								 | 
							
								    """Validate that all aliases are the same and if that's the case return the alias"""
							 | 
						||
| 
								 | 
							
								    unique_aliases = set(all_aliases)
							 | 
						||
| 
								 | 
							
								    if len(unique_aliases) > 1:
							 | 
						||
| 
								 | 
							
								        raise ConfigError(
							 | 
						||
| 
								 | 
							
								            f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								    return unique_aliases.pop()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Get alias and all valid values in the `Literal` type of the discriminator field
							 | 
						||
| 
								 | 
							
								    `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    is_root_model = getattr(tp, '__custom_root_type__', False)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if get_origin(tp) is Annotated:
							 | 
						||
| 
								 | 
							
								        tp = get_args(tp)[0]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if hasattr(tp, '__pydantic_model__'):
							 | 
						||
| 
								 | 
							
								        tp = tp.__pydantic_model__
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if is_union(get_origin(tp)):
							 | 
						||
| 
								 | 
							
								        alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
							 | 
						||
| 
								 | 
							
								        return alias, tuple(v for values in all_values for v in values)
							 | 
						||
| 
								 | 
							
								    elif is_root_model:
							 | 
						||
| 
								 | 
							
								        union_type = tp.__fields__[ROOT_KEY].type_
							 | 
						||
| 
								 | 
							
								        alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if len(set(all_values)) > 1:
							 | 
						||
| 
								 | 
							
								            raise ConfigError(
							 | 
						||
| 
								 | 
							
								                f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return alias, all_values[0]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    else:
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            t_discriminator_type = tp.__fields__[discriminator_key].type_
							 | 
						||
| 
								 | 
							
								        except AttributeError as e:
							 | 
						||
| 
								 | 
							
								            raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
							 | 
						||
| 
								 | 
							
								        except KeyError as e:
							 | 
						||
| 
								 | 
							
								            raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not is_literal_type(t_discriminator_type):
							 | 
						||
| 
								 | 
							
								            raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _get_union_alias_and_all_values(
							 | 
						||
| 
								 | 
							
								    union_type: Type[Any], discriminator_key: str
							 | 
						||
| 
								 | 
							
								) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
							 | 
						||
| 
								 | 
							
								    zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)]
							 | 
						||
| 
								 | 
							
								    # unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
							 | 
						||
| 
								 | 
							
								    all_aliases, all_values = zip(*zipped_aliases_values)
							 | 
						||
| 
								 | 
							
								    return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values
							 |