You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							166 lines
						
					
					
						
							5.8 KiB
						
					
					
				
			
		
		
	
	
							166 lines
						
					
					
						
							5.8 KiB
						
					
					
				import binascii
 | 
						|
from base64 import b64decode
 | 
						|
from typing import Optional
 | 
						|
 | 
						|
from fastapi.exceptions import HTTPException
 | 
						|
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
 | 
						|
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
 | 
						|
from fastapi.security.base import SecurityBase
 | 
						|
from fastapi.security.utils import get_authorization_scheme_param
 | 
						|
from pydantic import BaseModel
 | 
						|
from starlette.requests import Request
 | 
						|
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
 | 
						|
 | 
						|
 | 
						|
class HTTPBasicCredentials(BaseModel):
    """Username and password extracted from an HTTP Basic ``Authorization`` header.

    ``HTTPBasic.__call__`` in this module builds instances of this model after
    base64-decoding the header parameter and splitting it on the first ``":"``.
    """

    # As filled by HTTPBasic: the decoded text before the first ":".
    username: str
    # As filled by HTTPBasic: everything after the first ":" (may contain ":").
    password: str
 | 
						|
 | 
						|
 | 
						|
class HTTPAuthorizationCredentials(BaseModel):
    """An ``Authorization`` header split into its scheme and credentials parts.

    Returned by the ``HTTPBase``, ``HTTPBearer`` and ``HTTPDigest`` dependencies
    in this module. Neither field is decoded or validated here.
    """

    # Scheme token exactly as the client sent it (callers compare it lowercased,
    # but the stored value keeps its original case).
    scheme: str
    # The value returned alongside the scheme by get_authorization_scheme_param;
    # presumably the rest of the header after the scheme — confirm in that helper.
    credentials: str
 | 
						|
 | 
						|
 | 
						|
class HTTPBase(SecurityBase):
    """Generic dependency that reads the HTTP ``Authorization`` header.

    The header is split into scheme and credentials, which are returned
    unchanged. This base class does not check that the scheme sent by the
    client matches the ``scheme`` it was configured with. Subclasses such as
    ``HTTPBearer`` add that check.
    """

    def __init__(
        self,
        *,
        scheme: str,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        """Configure the dependency.

        Args:
            scheme: HTTP auth scheme name stored in the ``HTTPBaseModel``.
            scheme_name: Stored as ``self.scheme_name``; defaults to the class
                name when not given.
            description: Optional description stored in the ``HTTPBaseModel``.
            auto_error: When true, a missing or malformed header raises
                ``HTTPException(403)``. When false, the dependency returns
                ``None`` instead.
        """
        self.model = HTTPBaseModel(scheme=scheme, description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Return the request's ``Authorization`` scheme and credentials.

        Returns:
            The parsed credentials, or ``None`` when they are missing or
            incomplete and ``auto_error`` is false.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing
                or lacks a scheme or credentials and ``auto_error`` is true.
        """
        # The header is optional, so ``get`` can return None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
 | 
						|
 | 
						|
 | 
						|
class HTTPBasic(HTTPBase):
    """HTTP Basic authentication dependency.

    Expects ``Authorization: Basic <base64("username:password")>`` and returns
    the decoded :class:`HTTPBasicCredentials`. Every failure responds with 401
    and a ``WWW-Authenticate: Basic`` challenge header. The header includes
    ``realm="..."`` when a realm was configured.
    """

    def __init__(
        self,
        *,
        scheme_name: Optional[str] = None,
        realm: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        """Configure the dependency.

        Args:
            scheme_name: Stored as ``self.scheme_name``; defaults to the class
                name when not given.
            realm: When truthy, it is added to the ``WWW-Authenticate`` challenge
                as ``realm="<realm>"``. It is inserted verbatim, so it must not
                contain a double quote.
            description: Optional description stored in the ``HTTPBaseModel``.
            auto_error: When false, a missing header or a non-Basic scheme makes
                the dependency return ``None``. Malformed Basic credentials
                still raise 401.
        """
        self.model = HTTPBaseModel(scheme="basic", description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.realm = realm
        self.auto_error = auto_error

    async def __call__(  # type: ignore
        self, request: Request
    ) -> Optional[HTTPBasicCredentials]:
        """Decode HTTP Basic credentials from the request.

        Returns:
            The username/password pair, or ``None`` when the header is missing
            or uses another scheme and ``auto_error`` is false.

        Raises:
            HTTPException: 401 "Not authenticated" when the header is missing or
                not Basic and ``auto_error`` is true. 401 "Invalid authentication
                credentials" when the Basic parameter is not base64-encoded ASCII
                or has no ":" separator; this case ignores ``auto_error``.
        """
        # The header is optional, so ``get`` can return None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        if self.realm:
            unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
        else:
            unauthorized_headers = {"WWW-Authenticate": "Basic"}
        if not authorization or scheme.lower() != "basic":
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED,
                    detail="Not authenticated",
                    headers=unauthorized_headers,
                )
            return None
        # Built only once we know a Basic header was sent and may be malformed.
        invalid_user_credentials_exc = HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers=unauthorized_headers,
        )
        try:
            # NOTE: b64decode without validate=True silently drops characters
            # outside the base64 alphabet. Only ASCII credentials are accepted.
            data = b64decode(param).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error):
            # The decode error comes from client input; don't chain it.
            raise invalid_user_credentials_exc from None
        # Split on the FIRST ":" only, so passwords may themselves contain ":".
        username, separator, password = data.partition(":")
        if not separator:
            raise invalid_user_credentials_exc
        return HTTPBasicCredentials(username=username, password=password)
 | 
						|
 | 
						|
 | 
						|
class HTTPBearer(HTTPBase):
    """HTTP Bearer token dependency.

    Accepts ``Authorization: Bearer <token>``. The scheme is compared
    case-insensitively. The token itself is returned as-is and is not validated.
    """

    def __init__(
        self,
        *,
        bearerFormat: Optional[str] = None,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        """Configure the dependency.

        Args:
            bearerFormat: Optional format hint stored in the ``HTTPBearerModel``.
            scheme_name: Stored as ``self.scheme_name``; defaults to the class
                name when not given.
            description: Optional description stored in the ``HTTPBearerModel``.
            auto_error: When true, failures raise ``HTTPException(403)``. When
                false, the dependency returns ``None`` instead.
        """
        self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Return the bearer scheme and token from the request.

        Returns:
            The parsed credentials, with the scheme in the client's original
            case. Returns ``None`` on failure when ``auto_error`` is false.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing or
                incomplete. 403 "Invalid authentication credentials" when the
                scheme is not "bearer". Both are raised only when ``auto_error``
                is true.
        """
        # The header is optional, so ``get`` can return None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        if scheme.lower() != "bearer":
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN,
                    detail="Invalid authentication credentials",
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
 | 
						|
 | 
						|
 | 
						|
class HTTPDigest(HTTPBase):
    """HTTP Digest authentication dependency.

    Only checks that the scheme is "digest" (case-insensitively). The digest
    parameters are returned raw in ``credentials``; they are not parsed or
    verified here.
    """

    def __init__(
        self,
        *,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        """Configure the dependency.

        Args:
            scheme_name: Stored as ``self.scheme_name``; defaults to the class
                name when not given.
            description: Optional description stored in the ``HTTPBaseModel``.
            auto_error: When true, failures raise ``HTTPException(403)``. When
                false, the dependency returns ``None`` instead.
        """
        self.model = HTTPBaseModel(scheme="digest", description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Return the digest scheme and its raw parameters from the request.

        Returns:
            The parsed credentials, or ``None`` on failure when ``auto_error``
            is false.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing or
                incomplete. 403 "Invalid authentication credentials" when the
                scheme is not "digest". Both are raised only when ``auto_error``
                is true.
        """
        # The header is optional, so ``get`` can return None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        if scheme.lower() != "digest":
            # Previously this raised even with auto_error=False. It now matches
            # HTTPBearer and the missing-header branch above.
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN,
                    detail="Invalid authentication credentials",
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
 |