You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					120 lines
				
				3.3 KiB
			
		
		
			
		
	
	
					120 lines
				
				3.3 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import datetime
							 | 
						||
| 
								 | 
							
								import re
							 | 
						||
| 
								 | 
							
								import sys
							 | 
						||
| 
								 | 
							
								from collections import deque
							 | 
						||
| 
								 | 
							
								from decimal import Decimal
							 | 
						||
| 
								 | 
							
								from enum import Enum
							 | 
						||
| 
								 | 
							
								from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
							 | 
						||
| 
								 | 
							
								from pathlib import Path
							 | 
						||
| 
								 | 
							
								from types import GeneratorType
							 | 
						||
| 
								 | 
							
								from typing import Any, Callable, Dict, Type, Union
							 | 
						||
| 
								 | 
							
								from uuid import UUID
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# ``re.Pattern`` is only available as a public name from Python 3.7 on;
								# on 3.6 the compiled-pattern class is recovered from a throwaway instance.
								Pattern = re.Pattern if sys.version_info >= (3, 7) else type(re.compile('a'))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from .color import Color
							 | 
						||
| 
								 | 
							
								from .networks import NameEmail
							 | 
						||
| 
								 | 
							
								from .types import SecretBytes, SecretStr
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Public names exported by a star-import of this module.
								__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def isoformat(o: Union[datetime.date, datetime.time]) -> str:
								    """Return the ISO 8601 text of a date, datetime or time via its ``isoformat()``."""
								    formatted = o.isoformat()
								    return formatted
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
								    """
								    Encodes a Decimal as int if there's no exponent, otherwise float.

								    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
								    where an integer (but not int typed) is used. Encoding this as a float
								    results in failed round-tripping between encode and parse.
								    Our Id type is a prime example of this.

								    Infinity and NaN values have a non-integer exponent marker in
								    ``as_tuple()`` and are passed to ``float()``.

								    >>> decimal_encoder(Decimal("1.0"))
								    1.0

								    >>> decimal_encoder(Decimal("1"))
								    1
								    """
								    exponent = dec_value.as_tuple().exponent
								    # For special values the exponent is a str marker ('F', 'n' or 'N');
								    # comparing it with 0 would raise TypeError, so only ints qualify.
								    if isinstance(exponent, int) and exponent >= 0:
								        return int(dec_value)
								    return float(dec_value)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Type -> callable that turns an instance of that type into a JSON-compatible
								# value. pydantic_encoder consults this table for each class in the object's
								# MRO (object excluded), so an entry also covers subclasses that have no
								# more specific entry of their own.
								ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
								    bytes: lambda raw: raw.decode(),
								    Color: str,
								    # dates and times -> ISO 8601 text
								    datetime.date: isoformat,
								    datetime.datetime: isoformat,
								    datetime.time: isoformat,
								    # durations -> float number of seconds
								    datetime.timedelta: lambda delta: delta.total_seconds(),
								    Decimal: decimal_encoder,
								    Enum: lambda member: member.value,
								    # iterable containers -> lists
								    frozenset: list,
								    deque: list,
								    GeneratorType: list,
								    # network addresses -> their string form
								    IPv4Address: str,
								    IPv4Interface: str,
								    IPv4Network: str,
								    IPv6Address: str,
								    IPv6Interface: str,
								    IPv6Network: str,
								    NameEmail: str,
								    Path: str,
								    Pattern: lambda compiled: compiled.pattern,
								    SecretBytes: str,
								    SecretStr: str,
								    set: list,
								    UUID: str,
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def pydantic_encoder(obj: Any) -> Any:
								    """
								    Convert ``obj`` into a JSON-compatible value.

								    BaseModel instances are converted with ``.dict()`` and dataclass
								    instances with ``dataclasses.asdict``. Anything else is matched against
								    ``ENCODERS_BY_TYPE`` by walking the object's class MRO (most specific
								    first, ``object`` excluded) and calling the first encoder found.

								    :raises TypeError: if no encoder is registered for any class in the MRO.
								    """
								    from dataclasses import asdict, is_dataclass

								    # NOTE(review): imported lazily, presumably to avoid an import cycle with
								    # .main — confirm.
								    from .main import BaseModel

								    if isinstance(obj, BaseModel):
								        return obj.dict()
								    # is_dataclass() is also true for dataclass *classes*, but asdict() only
								    # accepts instances; a class falls through to the TypeError below.
								    elif is_dataclass(obj) and not isinstance(obj, type):
								        return asdict(obj)

								    # Check the class type and its superclasses for a matching encoder
								    for base in obj.__class__.__mro__[:-1]:
								        try:
								            encoder = ENCODERS_BY_TYPE[base]
								        except KeyError:
								            continue
								        return encoder(obj)
								    else:  # We have exited the for loop without finding a suitable encoder
								        raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
								    """
								    Encode ``obj`` with a caller-supplied encoder, falling back to ``pydantic_encoder``.

								    The object's class and then its bases (MRO order, ``object`` excluded)
								    are looked up in ``type_encoders``; the first match is called with
								    ``obj``. When no class matches, ``pydantic_encoder(obj)`` is returned.
								    """
								    lineage = obj.__class__.__mro__[:-1]
								    for klass in lineage:
								        try:
								            chosen = type_encoders[klass]
								        except KeyError:
								            continue  # nothing registered for this class; try the next base
								        return chosen(obj)
								    # No class in the hierarchy had a custom encoder registered.
								    return pydantic_encoder(obj)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def timedelta_isoformat(td: datetime.timedelta) -> str:
								    """
								    ISO 8601 encoding for timedeltas.

								    Produces ``P{days}DT{hours}H{minutes}M{seconds}.{microseconds}S``, e.g.
								    ``timedelta(days=1, seconds=3661)`` -> ``'P1DT1H1M1.000000S'``.
								    Negative durations are formatted from their absolute value with a
								    leading ``-`` (``timedelta(seconds=-1)`` -> ``'-P0DT0H0M1.000000S'``)
								    instead of mixing a negative day count with positive time fields.
								    """
								    # timedelta normalises so that only ``days`` can be negative.
								    sign = '-' if td.days < 0 else ''
								    magnitude = abs(td)
								    minutes, seconds = divmod(magnitude.seconds, 60)
								    hours, minutes = divmod(minutes, 60)
								    return f'{sign}P{magnitude.days}DT{hours:d}H{minutes:d}M{seconds:d}.{magnitude.microseconds:06d}S'
							 |