You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							221 lines
						
					
					
						
							8.0 KiB
						
					
					
				
			
		
		
	
	
							221 lines
						
					
					
						
							8.0 KiB
						
					
					
				from typing import Any, Dict, List, Optional, Union
 | 
						|
 | 
						|
from fastapi.exceptions import HTTPException
 | 
						|
from fastapi.openapi.models import OAuth2 as OAuth2Model
 | 
						|
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
 | 
						|
from fastapi.param_functions import Form
 | 
						|
from fastapi.security.base import SecurityBase
 | 
						|
from fastapi.security.utils import get_authorization_scheme_param
 | 
						|
from starlette.requests import Request
 | 
						|
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
 | 
						|
 | 
						|
 | 
						|
class OAuth2PasswordRequestForm:
    """
    This is a dependency class, use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestForm = Depends()):
            username = form_data.username
            password = form_data.password
            for scope in form_data.scopes:  # list of str
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return {"username": username, "scopes": form_data.scopes}

    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        Nevertheless, this dependency class is permissive and allows not passing it. If you want to enforce it,
        use instead the OAuth2PasswordRequestFormStrict dependency. When it IS sent, it must be exactly
        "password".
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret

    Instance attributes: grant_type, username, password, client_id, client_secret (as received),
    and scopes, a list of str built by splitting the "scope" field on whitespace.
    """

    def __init__(
        self,
        # Anchored on both ends: an unanchored "password" pattern also accepted
        # values that merely start with "password" (e.g. "passwordX").
        grant_type: Optional[str] = Form(None, regex="^password$"),
        username: str = Form(...),
        password: str = Form(...),
        scope: str = Form(""),
        client_id: Optional[str] = Form(None),
        client_secret: Optional[str] = Form(None),
    ):
        self.grant_type = grant_type
        self.username = username
        self.password = password
        # str.split() with no argument collapses runs of whitespace and yields
        # [] for an empty scope string.
        self.scopes = scope.split()
        self.client_id = client_id
        self.client_secret = client_secret
 | 
						|
 | 
						|
 | 
						|
class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
    """
    This is a dependency class, use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestFormStrict = Depends()):
            username = form_data.username
            password = form_data.password
            for scope in form_data.scopes:  # list of str
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return {"username": username, "scopes": form_data.scopes}

    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        This dependency is strict about it: the field is required and must be exactly "password".
        If you want to be permissive, use instead the OAuth2PasswordRequestForm dependency class.
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret

    The resulting instance exposes the same attributes as OAuth2PasswordRequestForm
    (grant_type, username, password, scopes, client_id, client_secret).
    """

    def __init__(
        self,
        # Required (...) and anchored on both ends so only the exact string
        # "password" is accepted; an unanchored pattern also let through values
        # that merely start with "password".
        grant_type: str = Form(..., regex="^password$"),
        username: str = Form(...),
        password: str = Form(...),
        scope: str = Form(""),
        client_id: Optional[str] = Form(None),
        client_secret: Optional[str] = Form(None),
    ):
        # All attribute assignment (including scope splitting) is delegated to
        # the permissive base class; only the grant_type validation differs.
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )
 | 
						|
 | 
						|
 | 
						|
class OAuth2(SecurityBase):
    """
    Base OAuth2 security scheme.

    Builds the OpenAPI ``OAuth2Model`` from the given flows and, when called as a
    dependency, returns the raw ``Authorization`` header value.

    Args:
        flows: OAuth flows, either as an ``OAuthFlowsModel`` or a plain dict.
        scheme_name: Name of the scheme; defaults to the concrete class name.
        description: Optional description stored on the OpenAPI model.
        auto_error: When True, a missing/empty ``Authorization`` header raises
            HTTP 403; when False, the dependency returns None instead.
    """

    def __init__(
        self,
        *,
        # NOTE: this default OAuthFlowsModel() is created once, at function
        # definition time, and shared by every instance that omits `flows`;
        # treat it as read-only.
        flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        # Annotated as plain bool, matching the subclasses; only truthiness is used.
        auto_error: bool = True,
    ):
        self.model = OAuth2Model(flows=flows, description=description)
        # Falls back to the subclass name (e.g. "OAuth2PasswordBearer").
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the full ``Authorization`` header value, unparsed.

        Raises:
            HTTPException: 403 "Not authenticated" if the header is missing or
                empty and ``auto_error`` is True.
        Returns:
            None if the header is missing or empty and ``auto_error`` is False.
        """
        # headers.get returns None when the header is absent, hence Optional.
        authorization: Optional[str] = request.headers.get("Authorization")
        if not authorization:
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        return authorization
 | 
						|
 | 
						|
 | 
						|
class OAuth2PasswordBearer(OAuth2):
    """
    OAuth2 scheme declared with a "password" flow whose dependency yields the
    credentials of a ``Bearer`` ``Authorization`` header.

    Args:
        tokenUrl: URL of the token endpoint, placed in the password flow.
        scheme_name: Optional scheme name (base class falls back to the class name).
        scopes: Mapping of scope name to description; a missing or empty value
            becomes ``{}``.
        description: Optional description for the OpenAPI model.
        auto_error: Raise HTTP 401 on a missing/non-Bearer header when True,
            otherwise return None.
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        # `scopes or {}` swaps any falsy value (None / empty dict) for a new dict.
        password_flow = {"tokenUrl": tokenUrl, "scopes": scopes or {}}
        super().__init__(
            flows=OAuthFlowsModel(password=password_flow),
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the parameter part of an ``Authorization: Bearer ...`` header,
        as split by ``get_authorization_scheme_param``.

        The scheme match is case-insensitive. For a missing/empty header or any
        other scheme: raise HTTP 401 with ``WWW-Authenticate: Bearer`` when
        ``auto_error`` is True, else return None.
        """
        header_value = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(header_value)
        # Happy path first: a non-empty header carrying the Bearer scheme.
        if header_value and scheme.lower() == "bearer":
            return credentials
        if not self.auto_error:
            return None
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
 | 
						|
 | 
						|
 | 
						|
class OAuth2AuthorizationCodeBearer(OAuth2):
    """
    OAuth2 scheme declared with an "authorizationCode" flow whose dependency
    yields the credentials of a ``Bearer`` ``Authorization`` header.

    Args:
        authorizationUrl: Authorization endpoint URL for the flow.
        tokenUrl: Token endpoint URL for the flow.
        refreshUrl: Optional refresh endpoint URL (passed through even when None).
        scheme_name: Optional scheme name (base class falls back to the class name).
        scopes: Mapping of scope name to description; a missing or empty value
            becomes ``{}``.
        description: Optional description for the OpenAPI model.
        auto_error: Raise HTTP 401 on a missing/non-Bearer header when True,
            otherwise return None.
    """

    def __init__(
        self,
        authorizationUrl: str,
        tokenUrl: str,
        refreshUrl: Optional[str] = None,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        code_flow = dict(
            authorizationUrl=authorizationUrl,
            tokenUrl=tokenUrl,
            refreshUrl=refreshUrl,
            # Any falsy value (None / empty dict) is replaced with a new dict.
            scopes=scopes or {},
        )
        super().__init__(
            flows=OAuthFlowsModel(authorizationCode=code_flow),
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the parameter part of an ``Authorization: Bearer ...`` header,
        as split by ``get_authorization_scheme_param``.

        The scheme match is case-insensitive. For a missing/empty header or any
        other scheme: raise HTTP 401 with ``WWW-Authenticate: Bearer`` when
        ``auto_error`` is True, else return None.
        """
        header_value = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(header_value)
        # Happy path first: a non-empty header carrying the Bearer scheme.
        if header_value and scheme.lower() == "bearer":
            return credentials
        if not self.auto_error:
            return None  # pragma: nocover
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
 | 
						|
 | 
						|
 | 
						|
class SecurityScopes:
    """
    Holder for a list of scope strings plus their space-joined form.

    Attributes:
        scopes: The given list, or a new empty list when None/empty was passed.
        scope_str: ``" ".join(scopes)`` computed once at construction; it is not
            refreshed if ``scopes`` is mutated afterwards.
    """

    def __init__(self, scopes: Optional[List[str]] = None):
        # Normalize any falsy input to a fresh list; a non-empty list is kept
        # as the same object (no copy).
        if not scopes:
            scopes = []
        self.scopes = scopes
        self.scope_str = " ".join(scopes)
 |