You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					166 lines
				
				5.8 KiB
			
		
		
			
		
	
	
					166 lines
				
				5.8 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import binascii
							 | 
						||
| 
								 | 
							
								from base64 import b64decode
							 | 
						||
| 
								 | 
							
								from typing import Optional
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from fastapi.exceptions import HTTPException
							 | 
						||
| 
								 | 
							
								from fastapi.openapi.models import HTTPBase as HTTPBaseModel
							 | 
						||
| 
								 | 
							
								from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
							 | 
						||
| 
								 | 
							
								from fastapi.security.base import SecurityBase
							 | 
						||
| 
								 | 
							
								from fastapi.security.utils import get_authorization_scheme_param
							 | 
						||
| 
								 | 
							
								from pydantic import BaseModel
							 | 
						||
| 
								 | 
							
								from starlette.requests import Request
							 | 
						||
| 
								 | 
							
								from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Username/password pair that ``HTTPBasic.__call__`` returns after
								# decoding an ``Authorization: Basic ...`` header.
								# NOTE: documented with comments rather than a class docstring so the
								# model's generated schema is left exactly as it was.
								class HTTPBasicCredentials(BaseModel):
								    # Text before the first ":" of the decoded credentials.
								    username: str
								    # Text after the first ":" of the decoded credentials.
								    password: str
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Scheme/credentials pair returned by the ``HTTPBase``, ``HTTPBearer`` and
								# ``HTTPDigest`` callables after splitting the ``Authorization`` header.
								# NOTE: documented with comments rather than a class docstring so the
								# model's generated schema is left exactly as it was.
								class HTTPAuthorizationCredentials(BaseModel):
								    # First part of the header as split by get_authorization_scheme_param.
								    scheme: str
								    # Remainder of the header after the scheme.
								    credentials: str
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class HTTPBase(SecurityBase):
								    """Generic security scheme that reads the ``Authorization`` header.

								    Calling an instance with a request splits the header into a scheme and
								    credentials using ``get_authorization_scheme_param`` and returns both,
								    without checking which scheme the client actually sent.
								    """

								    def __init__(
								        self,
								        *,
								        scheme: str,
								        scheme_name: Optional[str] = None,
								        description: Optional[str] = None,
								        auto_error: bool = True,
								    ):
								        """Store the OpenAPI model and the error-handling mode.

								        Args:
								            scheme: HTTP auth scheme recorded in the OpenAPI model.
								            scheme_name: Name of this scheme; defaults to the class name.
								            description: Optional description recorded in the OpenAPI model.
								            auto_error: When true, ``__call__`` raises on missing credentials
								                instead of returning ``None``.
								        """
								        self.model = HTTPBaseModel(scheme=scheme, description=description)
								        self.scheme_name = scheme_name or type(self).__name__
								        self.auto_error = auto_error

								    async def __call__(
								        self, request: Request
								    ) -> Optional[HTTPAuthorizationCredentials]:
								        """Extract scheme and credentials from the request.

								        Returns:
								            The parsed credentials, or ``None`` when the header is missing or
								            incomplete and ``auto_error`` is false.

								        Raises:
								            HTTPException: 403 "Not authenticated" when the header is missing
								                or incomplete and ``auto_error`` is true.
								        """
								        # The header may be absent, so the lookup can yield None.
								        header_value: Optional[str] = request.headers.get("Authorization")
								        scheme, credentials = get_authorization_scheme_param(header_value)
								        if header_value and scheme and credentials:
								            return HTTPAuthorizationCredentials(
								                scheme=scheme, credentials=credentials
								            )
								        if not self.auto_error:
								            return None
								        raise HTTPException(
								            status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class HTTPBasic(HTTPBase):
								    """HTTP Basic security scheme.

								    Calling an instance with a request base64-decodes the parameter of an
								    ``Authorization: Basic ...`` header and splits it at the first ":" into
								    a username and a password.
								    """

								    def __init__(
								        self,
								        *,
								        scheme_name: Optional[str] = None,
								        realm: Optional[str] = None,
								        description: Optional[str] = None,
								        auto_error: bool = True,
								    ):
								        """Store the OpenAPI model, realm and error-handling mode.

								        Args:
								            scheme_name: Name of this scheme; defaults to the class name.
								            realm: Optional realm echoed in the ``WWW-Authenticate`` header.
								            description: Optional description recorded in the OpenAPI model.
								            auto_error: When true, a missing or non-basic header raises
								                instead of returning ``None``.
								        """
								        self.model = HTTPBaseModel(scheme="basic", description=description)
								        self.scheme_name = scheme_name or type(self).__name__
								        self.realm = realm
								        self.auto_error = auto_error

								    async def __call__(  # type: ignore
								        self, request: Request
								    ) -> Optional[HTTPBasicCredentials]:
								        """Decode Basic credentials from the request.

								        Returns:
								            The decoded username/password, or ``None`` when the header is
								            missing or not ``basic`` and ``auto_error`` is false.

								        Raises:
								            HTTPException: 401 "Not authenticated" when the header is missing
								                or not ``basic`` and ``auto_error`` is true; 401 "Invalid
								                authentication credentials" when the parameter cannot be
								                decoded or contains no ":" (regardless of ``auto_error``).
								        """
								        header_value: Optional[str] = request.headers.get("Authorization")
								        scheme, encoded = get_authorization_scheme_param(header_value)

								        # Every 401 from this scheme carries the same challenge header.
								        challenge = f'Basic realm="{self.realm}"' if self.realm else "Basic"
								        unauthorized_headers = {"WWW-Authenticate": challenge}

								        if not header_value or scheme.lower() != "basic":
								            if not self.auto_error:
								                return None
								            raise HTTPException(
								                status_code=HTTP_401_UNAUTHORIZED,
								                detail="Not authenticated",
								                headers=unauthorized_headers,
								            )

								        invalid_user_credentials_exc = HTTPException(
								            status_code=HTTP_401_UNAUTHORIZED,
								            detail="Invalid authentication credentials",
								            headers=unauthorized_headers,
								        )
								        try:
								            decoded = b64decode(encoded).decode("ascii")
								        except (ValueError, UnicodeDecodeError, binascii.Error):
								            raise invalid_user_credentials_exc

								        # partition() yields an empty separator when no ":" is present.
								        username, separator, password = decoded.partition(":")
								        if not separator:
								            raise invalid_user_credentials_exc
								        return HTTPBasicCredentials(username=username, password=password)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class HTTPBearer(HTTPBase):
								    """HTTP Bearer security scheme.

								    Calling an instance with a request returns the scheme and credentials
								    of an ``Authorization: Bearer ...`` header; the credentials themselves
								    are not inspected here.
								    """

								    def __init__(
								        self,
								        *,
								        bearerFormat: Optional[str] = None,
								        scheme_name: Optional[str] = None,
								        description: Optional[str] = None,
								        auto_error: bool = True,
								    ):
								        """Store the OpenAPI model and the error-handling mode.

								        Args:
								            bearerFormat: Optional format hint recorded in the OpenAPI model.
								            scheme_name: Name of this scheme; defaults to the class name.
								            description: Optional description recorded in the OpenAPI model.
								            auto_error: When true, ``__call__`` raises instead of returning
								                ``None`` on a missing header or a non-bearer scheme.
								        """
								        self.model = HTTPBearerModel(
								            bearerFormat=bearerFormat, description=description
								        )
								        self.scheme_name = scheme_name or type(self).__name__
								        self.auto_error = auto_error

								    async def __call__(
								        self, request: Request
								    ) -> Optional[HTTPAuthorizationCredentials]:
								        """Extract Bearer credentials from the request.

								        Returns:
								            The parsed credentials, or ``None`` when ``auto_error`` is false
								            and the header is missing, incomplete or not ``bearer``.

								        Raises:
								            HTTPException: 403 "Not authenticated" for a missing or incomplete
								                header, 403 "Invalid authentication credentials" for a
								                non-bearer scheme; both only when ``auto_error`` is true.
								        """
								        header_value: Optional[str] = request.headers.get("Authorization")
								        scheme, credentials = get_authorization_scheme_param(header_value)

								        if not (header_value and scheme and credentials):
								            if not self.auto_error:
								                return None
								            raise HTTPException(
								                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
								            )

								        if scheme.lower() != "bearer":
								            if not self.auto_error:
								                return None
								            raise HTTPException(
								                status_code=HTTP_403_FORBIDDEN,
								                detail="Invalid authentication credentials",
								            )

								        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class HTTPDigest(HTTPBase):
								    """HTTP Digest security scheme.

								    Calling an instance with a request returns the scheme and credentials
								    of an ``Authorization: Digest ...`` header. Only the scheme name is
								    checked; the digest parameters are returned as-is.
								    """

								    def __init__(
								        self,
								        *,
								        scheme_name: Optional[str] = None,
								        description: Optional[str] = None,
								        auto_error: bool = True,
								    ):
								        """Store the OpenAPI model and the error-handling mode.

								        Args:
								            scheme_name: Name of this scheme; defaults to the class name.
								            description: Optional description recorded in the OpenAPI model.
								            auto_error: When true, ``__call__`` raises instead of returning
								                ``None`` on a missing header or a non-digest scheme.
								        """
								        self.model = HTTPBaseModel(scheme="digest", description=description)
								        self.scheme_name = scheme_name or self.__class__.__name__
								        self.auto_error = auto_error

								    async def __call__(
								        self, request: Request
								    ) -> Optional[HTTPAuthorizationCredentials]:
								        """Extract Digest credentials from the request.

								        Returns:
								            The parsed credentials, or ``None`` when ``auto_error`` is false
								            and the header is missing, incomplete or not ``digest``.

								        Raises:
								            HTTPException: 403 "Not authenticated" for a missing or incomplete
								                header, 403 "Invalid authentication credentials" for a
								                non-digest scheme; both only when ``auto_error`` is true.
								        """
								        authorization: Optional[str] = request.headers.get("Authorization")
								        scheme, credentials = get_authorization_scheme_param(authorization)
								        if not (authorization and scheme and credentials):
								            if self.auto_error:
								                raise HTTPException(
								                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
								                )
								            return None
								        if scheme.lower() != "digest":
								            # Honor auto_error here too, matching HTTPBearer: with
								            # auto_error=False a wrong scheme yields None so the caller can
								            # fall back to another scheme instead of getting a forced 403.
								            if self.auto_error:
								                raise HTTPException(
								                    status_code=HTTP_403_FORBIDDEN,
								                    detail="Invalid authentication credentials",
								                )
								            return None
								        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
							 |