You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					221 lines
				
				8.0 KiB
			
		
		
			
		
	
	
					221 lines
				
				8.0 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								from typing import Any, Dict, List, Optional, Union
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from fastapi.exceptions import HTTPException
							 | 
						||
| 
								 | 
							
								from fastapi.openapi.models import OAuth2 as OAuth2Model
							 | 
						||
| 
								 | 
							
								from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
							 | 
						||
| 
								 | 
							
								from fastapi.param_functions import Form
							 | 
						||
| 
								 | 
							
								from fastapi.security.base import SecurityBase
							 | 
						||
| 
								 | 
							
								from fastapi.security.utils import get_authorization_scheme_param
							 | 
						||
| 
								 | 
							
								from starlette.requests import Request
							 | 
						||
| 
								 | 
							
								from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class OAuth2PasswordRequestForm:
								    """
								    Dependency class that collects the OAuth2 "password" flow form fields.

								    Use it as a dependency and read the values straight from the injected
								    instance (there is no separate parsing step):

								        @app.post("/login")
								        def login(form_data: OAuth2PasswordRequestForm = Depends()):
								            print(form_data.username)
								            print(form_data.password)
								            for scope in form_data.scopes:
								                print(scope)
								            if form_data.client_id:
								                print(form_data.client_id)
								            if form_data.client_secret:
								                print(form_data.client_secret)
								            return {"username": form_data.username, "scopes": form_data.scopes}

								    It creates the following Form request parameters in your endpoint:

								    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
								        Nevertheless, this dependency class is permissive and allows not passing it. If you want to enforce it,
								        use instead the OAuth2PasswordRequestFormStrict dependency.
								        When it is passed, it has to be exactly "password".
								    username: username string. The OAuth2 spec requires the exact field name "username".
								    password: password string. The OAuth2 spec requires the exact field name "password".
								    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
								        "items:read items:write users:read profile openid"
								    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
								        using HTTP Basic auth, as: client_id:client_secret
								    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
								        using HTTP Basic auth, as: client_id:client_secret

								    The instance exposes the attributes grant_type, username, password,
								    scopes (the "scope" field split on whitespace into a list), client_id
								    and client_secret.
								    """

								    def __init__(
								        self,
								        # Anchored on purpose: the bare pattern "password" would also accept
								        # values such as "password123" that only begin with the keyword.
								        grant_type: str = Form(None, regex="^password$"),
								        username: str = Form(...),
								        password: str = Form(...),
								        scope: str = Form(""),
								        client_id: Optional[str] = Form(None),
								        client_secret: Optional[str] = Form(None),
								    ):
								        self.grant_type = grant_type
								        self.username = username
								        self.password = password
								        # The form carries one space-separated string; expose it as a list.
								        self.scopes = scope.split()
								        self.client_id = client_id
								        self.client_secret = client_secret
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
								    """
								    Strict variant of OAuth2PasswordRequestForm: grant_type is mandatory.

								    Use it as a dependency and read the values straight from the injected
								    instance (there is no separate parsing step):

								        @app.post("/login")
								        def login(form_data: OAuth2PasswordRequestFormStrict = Depends()):
								            print(form_data.username)
								            print(form_data.password)
								            for scope in form_data.scopes:
								                print(scope)
								            if form_data.client_id:
								                print(form_data.client_id)
								            if form_data.client_secret:
								                print(form_data.client_secret)
								            return {"username": form_data.username, "scopes": form_data.scopes}

								    It creates the following Form request parameters in your endpoint:

								    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
								        This dependency is strict about it. If you want to be permissive, use instead the
								        OAuth2PasswordRequestForm dependency class.
								    username: username string. The OAuth2 spec requires the exact field name "username".
								    password: password string. The OAuth2 spec requires the exact field name "password".
								    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
								        "items:read items:write users:read profile openid"
								    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
								        using HTTP Basic auth, as: client_id:client_secret
								    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
								        using HTTP Basic auth, as: client_id:client_secret
								    """

								    def __init__(
								        self,
								        # Required (...) and anchored: the bare pattern "password" would let
								        # values such as "password123" through, defeating the strict check.
								        grant_type: str = Form(..., regex="^password$"),
								        username: str = Form(...),
								        password: str = Form(...),
								        scope: str = Form(""),
								        client_id: Optional[str] = Form(None),
								        client_secret: Optional[str] = Form(None),
								    ):
								        # Attribute assignment and scope splitting are done by the parent class.
								        super().__init__(
								            grant_type=grant_type,
								            username=username,
								            password=password,
								            scope=scope,
								            client_id=client_id,
								            client_secret=client_secret,
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class OAuth2(SecurityBase):
								    """
								    Generic OAuth2 security scheme.

								    Builds an OAuth2Model from the given flows and description, and, when
								    called with a request, hands back the content of its "Authorization"
								    header.
								    """

								    def __init__(
								        self,
								        *,
								        flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
								        scheme_name: Optional[str] = None,
								        description: Optional[str] = None,
								        auto_error: Optional[bool] = True
								    ):
								        self.model = OAuth2Model(flows=flows, description=description)
								        # Without an explicit name the scheme is named after the concrete class.
								        self.scheme_name = scheme_name or self.__class__.__name__
								        self.auto_error = auto_error

								    async def __call__(self, request: Request) -> Optional[str]:
								        """
								        Return the "Authorization" header of the request.

								        When the header is missing or empty, raise a 403 HTTPException if
								        auto_error is truthy, otherwise return None.
								        """
								        header_value: Optional[str] = request.headers.get("Authorization")
								        if header_value:
								            return header_value
								        if self.auto_error:
								            raise HTTPException(
								                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
								            )
								        return None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class OAuth2PasswordBearer(OAuth2):
								    """
								    OAuth2 scheme for the "password" flow with bearer credentials.

								    Declares a password flow pointing at tokenUrl and, when called with a
								    request, extracts the credentials of a "Bearer" Authorization header.
								    """

								    def __init__(
								        self,
								        tokenUrl: str,
								        scheme_name: Optional[str] = None,
								        scopes: Optional[Dict[str, str]] = None,
								        description: Optional[str] = None,
								        auto_error: bool = True,
								    ):
								        # A missing or empty scopes mapping is declared as an empty dict.
								        password_flow = {"tokenUrl": tokenUrl, "scopes": scopes or {}}
								        super().__init__(
								            flows=OAuthFlowsModel(password=password_flow),
								            scheme_name=scheme_name,
								            description=description,
								            auto_error=auto_error,
								        )

								    async def __call__(self, request: Request) -> Optional[str]:
								        """
								        Return the credentials part of a "Bearer" Authorization header.

								        If the header is absent or its scheme is not "bearer" (compared
								        case-insensitively), raise a 401 HTTPException carrying a
								        "WWW-Authenticate: Bearer" header when auto_error is truthy;
								        otherwise return None.
								        """
								        header_value: Optional[str] = request.headers.get("Authorization")
								        scheme, credentials = get_authorization_scheme_param(header_value)
								        if header_value and scheme.lower() == "bearer":
								            return credentials
								        if not self.auto_error:
								            return None
								        raise HTTPException(
								            status_code=HTTP_401_UNAUTHORIZED,
								            detail="Not authenticated",
								            headers={"WWW-Authenticate": "Bearer"},
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class OAuth2AuthorizationCodeBearer(OAuth2):
								    """
								    OAuth2 scheme for the "authorizationCode" flow with bearer credentials.

								    Declares an authorization-code flow from the given URLs and scopes and,
								    when called with a request, extracts the credentials of a "Bearer"
								    Authorization header.
								    """

								    def __init__(
								        self,
								        authorizationUrl: str,
								        tokenUrl: str,
								        refreshUrl: Optional[str] = None,
								        scheme_name: Optional[str] = None,
								        scopes: Optional[Dict[str, str]] = None,
								        description: Optional[str] = None,
								        auto_error: bool = True,
								    ):
								        # A missing or empty scopes mapping is declared as an empty dict.
								        code_flow = {
								            "authorizationUrl": authorizationUrl,
								            "tokenUrl": tokenUrl,
								            "refreshUrl": refreshUrl,
								            "scopes": scopes or {},
								        }
								        super().__init__(
								            flows=OAuthFlowsModel(authorizationCode=code_flow),
								            scheme_name=scheme_name,
								            description=description,
								            auto_error=auto_error,
								        )

								    async def __call__(self, request: Request) -> Optional[str]:
								        """
								        Return the credentials part of a "Bearer" Authorization header.

								        If the header is absent or its scheme is not "bearer" (compared
								        case-insensitively), raise a 401 HTTPException carrying a
								        "WWW-Authenticate: Bearer" header when auto_error is truthy;
								        otherwise return None.
								        """
								        header_value: Optional[str] = request.headers.get("Authorization")
								        scheme, credentials = get_authorization_scheme_param(header_value)
								        if header_value and scheme.lower() == "bearer":
								            return credentials
								        if self.auto_error:
								            raise HTTPException(
								                status_code=HTTP_401_UNAUTHORIZED,
								                detail="Not authenticated",
								                headers={"WWW-Authenticate": "Bearer"},
								            )
								        else:
								            return None  # pragma: nocover
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SecurityScopes:
								    """
								    Holder for a list of scope names.

								    Attributes:
								        scopes: the list of scopes passed in, or an empty list when none
								            (or an empty value) was given.
								        scope_str: the same scopes joined by single spaces.
								    """

								    def __init__(self, scopes: Optional[List[str]] = None):
								        if not scopes:
								            scopes = []
								        self.scopes = scopes
								        self.scope_str = " ".join(scopes)
							 |