You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							53 lines
						
					
					
						
							1.7 KiB
						
					
					
				
			
		
		
	
	
							53 lines
						
					
					
						
							1.7 KiB
						
					
					
				import typing
 | 
						|
 | 
						|
from starlette.authentication import (
 | 
						|
    AuthCredentials,
 | 
						|
    AuthenticationBackend,
 | 
						|
    AuthenticationError,
 | 
						|
    UnauthenticatedUser,
 | 
						|
)
 | 
						|
from starlette.requests import HTTPConnection
 | 
						|
from starlette.responses import PlainTextResponse, Response
 | 
						|
from starlette.types import ASGIApp, Receive, Scope, Send
 | 
						|
 | 
						|
 | 
						|
class AuthenticationMiddleware:
    """ASGI middleware that authenticates HTTP and WebSocket connections.

    For every ``http`` or ``websocket`` scope the backend's ``authenticate``
    coroutine is awaited with an ``HTTPConnection`` built from the scope.
    The resulting pair is stored in ``scope["auth"]`` and ``scope["user"]``
    before the wrapped application is called. Scopes of any other type are
    passed straight through to the wrapped application.
    """

    def __init__(
        self,
        app: ASGIApp,
        backend: AuthenticationBackend,
        on_error: typing.Optional[
            typing.Callable[[HTTPConnection, AuthenticationError], Response]
        ] = None,
    ) -> None:
        """Store the wrapped app, the backend and the error handler.

        Args:
            app: The ASGI application to call after authentication.
            backend: Object whose ``authenticate(conn)`` coroutine is awaited
                for each connection.
            on_error: Callable invoked with the connection and the raised
                ``AuthenticationError``; it must return the response to send.
                When ``None``, ``default_on_error`` is used.
        """
        self.app = app
        self.backend = backend
        self.on_error: typing.Callable[
            [HTTPConnection, AuthenticationError], Response
        ] = (on_error if on_error is not None else self.default_on_error)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Authenticate the connection, then delegate to the wrapped app.

        If the backend raises ``AuthenticationError``, the wrapped app is not
        called: a WebSocket connection is closed with code 1000, and an HTTP
        connection receives the response returned by ``on_error``.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        conn = HTTPConnection(scope)
        try:
            auth_result = await self.backend.authenticate(conn)
        except AuthenticationError as exc:
            # The handler is invoked for both connection types, but its
            # response is only sent over HTTP; a WebSocket is simply closed.
            response = self.on_error(conn, exc)
            if scope["type"] == "websocket":
                await send({"type": "websocket.close", "code": 1000})
            else:
                await response(scope, receive, send)
            return

        # A backend returning None yields empty credentials and an
        # unauthenticated user.
        if auth_result is None:
            auth_result = AuthCredentials(), UnauthenticatedUser()
        scope["auth"], scope["user"] = auth_result
        await self.app(scope, receive, send)

    @staticmethod
    def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
        """Return a plain-text 400 response whose body is ``str(exc)``."""
        return PlainTextResponse(str(exc), status_code=400)
 |