You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					53 lines
				
				1.7 KiB
			
		
		
			
		
	
	
					53 lines
				
				1.7 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import typing
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from starlette.authentication import (
							 | 
						||
| 
								 | 
							
								    AuthCredentials,
							 | 
						||
| 
								 | 
							
								    AuthenticationBackend,
							 | 
						||
| 
								 | 
							
								    AuthenticationError,
							 | 
						||
| 
								 | 
							
								    UnauthenticatedUser,
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								from starlette.requests import HTTPConnection
							 | 
						||
| 
								 | 
							
								from starlette.responses import PlainTextResponse, Response
							 | 
						||
| 
								 | 
							
								from starlette.types import ASGIApp, Receive, Scope, Send
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class AuthenticationMiddleware:
								    """ASGI middleware that authenticates HTTP and WebSocket connections.

								    For ``http`` and ``websocket`` scopes, ``backend.authenticate`` is awaited
								    with an ``HTTPConnection`` built from the scope. The returned
								    ``(credentials, user)`` pair is written to ``scope["auth"]`` and
								    ``scope["user"]`` before the wrapped app is called. If the backend
								    returns ``None``, ``AuthCredentials()`` and ``UnauthenticatedUser()`` are
								    stored instead. Any other scope type is forwarded to the wrapped app
								    unchanged.
								    """

								    def __init__(
								        self,
								        app: ASGIApp,
								        backend: AuthenticationBackend,
								        on_error: typing.Optional[
								            typing.Callable[[HTTPConnection, AuthenticationError], Response]
								        ] = None,
								    ) -> None:
								        """Wrap ``app`` with authentication performed by ``backend``.

								        Args:
								            app: The ASGI application to call after authentication.
								            backend: Object whose ``authenticate(conn)`` coroutine returns a
								                ``(credentials, user)`` pair, or ``None`` for an anonymous
								                connection, and may raise ``AuthenticationError``.
								            on_error: Builds the response for a failed authentication from the
								                connection and the raised ``AuthenticationError``. When ``None``
								                (the default), ``default_on_error`` is used.
								        """
								        self.app = app
								        self.backend = backend
								        self.on_error: typing.Callable[
								            [HTTPConnection, AuthenticationError], Response
								        ] = (on_error if on_error is not None else self.default_on_error)

								    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
								        """Authenticate the connection, then delegate to the wrapped app.

								        On ``AuthenticationError`` the wrapped app is not called. For HTTP the
								        response built by ``on_error`` is sent. For websockets the connection
								        is closed instead.
								        """
								        if scope["type"] not in ("http", "websocket"):
								            # Only connection scopes carry credentials; forward the rest untouched.
								            await self.app(scope, receive, send)
								            return

								        conn = HTTPConnection(scope)
								        try:
								            auth_result = await self.backend.authenticate(conn)
								        except AuthenticationError as exc:
								            # on_error runs for websockets too, so a custom handler's side
								            # effects still happen even though its response is not sent below.
								            response = self.on_error(conn, exc)
								            if scope["type"] == "websocket":
								                # NOTE(review): 1000 is the "normal closure" code; 1008 (policy
								                # violation) may describe an auth failure better. Kept as-is so
								                # existing clients see the same close code.
								                await send({"type": "websocket.close", "code": 1000})
								            else:
								                await response(scope, receive, send)
								            return

								        if auth_result is None:
								            # Backend declined to authenticate: mark the connection anonymous.
								            auth_result = AuthCredentials(), UnauthenticatedUser()
								        scope["auth"], scope["user"] = auth_result
								        await self.app(scope, receive, send)

								    @staticmethod
								    def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
								        """Return a plain-text 400 response whose body is ``str(exc)``."""
								        return PlainTextResponse(str(exc), status_code=400)
							 |