You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					78 lines
				
				3.0 KiB
			
		
		
			
		
	
	
					78 lines
				
				3.0 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import json
							 | 
						||
| 
								 | 
							
								import typing
							 | 
						||
| 
								 | 
							
								from base64 import b64decode, b64encode
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import itsdangerous
							 | 
						||
| 
								 | 
							
								from itsdangerous.exc import BadSignature
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from starlette.datastructures import MutableHeaders, Secret
							 | 
						||
| 
								 | 
							
								from starlette.requests import HTTPConnection
							 | 
						||
| 
								 | 
							
								from starlette.types import ASGIApp, Message, Receive, Scope, Send
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SessionMiddleware:
							 | 
						||
| 
								 | 
							
								    def __init__(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        app: ASGIApp,
							 | 
						||
| 
								 | 
							
								        secret_key: typing.Union[str, Secret],
							 | 
						||
| 
								 | 
							
								        session_cookie: str = "session",
							 | 
						||
| 
								 | 
							
								        max_age: int = 14 * 24 * 60 * 60,  # 14 days, in seconds
							 | 
						||
| 
								 | 
							
								        same_site: str = "lax",
							 | 
						||
| 
								 | 
							
								        https_only: bool = False,
							 | 
						||
| 
								 | 
							
								    ) -> None:
							 | 
						||
| 
								 | 
							
								        self.app = app
							 | 
						||
| 
								 | 
							
								        self.signer = itsdangerous.TimestampSigner(str(secret_key))
							 | 
						||
| 
								 | 
							
								        self.session_cookie = session_cookie
							 | 
						||
| 
								 | 
							
								        self.max_age = max_age
							 | 
						||
| 
								 | 
							
								        self.security_flags = "httponly; samesite=" + same_site
							 | 
						||
| 
								 | 
							
								        if https_only:  # Secure flag can be used with HTTPS only
							 | 
						||
| 
								 | 
							
								            self.security_flags += "; secure"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
							 | 
						||
| 
								 | 
							
								        if scope["type"] not in ("http", "websocket"):  # pragma: no cover
							 | 
						||
| 
								 | 
							
								            await self.app(scope, receive, send)
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        connection = HTTPConnection(scope)
							 | 
						||
| 
								 | 
							
								        initial_session_was_empty = True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if self.session_cookie in connection.cookies:
							 | 
						||
| 
								 | 
							
								            data = connection.cookies[self.session_cookie].encode("utf-8")
							 | 
						||
| 
								 | 
							
								            try:
							 | 
						||
| 
								 | 
							
								                data = self.signer.unsign(data, max_age=self.max_age)
							 | 
						||
| 
								 | 
							
								                scope["session"] = json.loads(b64decode(data))
							 | 
						||
| 
								 | 
							
								                initial_session_was_empty = False
							 | 
						||
| 
								 | 
							
								            except BadSignature:
							 | 
						||
| 
								 | 
							
								                scope["session"] = {}
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            scope["session"] = {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        async def send_wrapper(message: Message) -> None:
							 | 
						||
| 
								 | 
							
								            if message["type"] == "http.response.start":
							 | 
						||
| 
								 | 
							
								                path = scope.get("root_path", "") or "/"
							 | 
						||
| 
								 | 
							
								                if scope["session"]:
							 | 
						||
| 
								 | 
							
								                    # We have session data to persist.
							 | 
						||
| 
								 | 
							
								                    data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
							 | 
						||
| 
								 | 
							
								                    data = self.signer.sign(data)
							 | 
						||
| 
								 | 
							
								                    headers = MutableHeaders(scope=message)
							 | 
						||
| 
								 | 
							
								                    header_value = "%s=%s; path=%s; Max-Age=%d; %s" % (
							 | 
						||
| 
								 | 
							
								                        self.session_cookie,
							 | 
						||
| 
								 | 
							
								                        data.decode("utf-8"),
							 | 
						||
| 
								 | 
							
								                        path,
							 | 
						||
| 
								 | 
							
								                        self.max_age,
							 | 
						||
| 
								 | 
							
								                        self.security_flags,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    headers.append("Set-Cookie", header_value)
							 | 
						||
| 
								 | 
							
								                elif not initial_session_was_empty:
							 | 
						||
| 
								 | 
							
								                    # The session has been cleared.
							 | 
						||
| 
								 | 
							
								                    headers = MutableHeaders(scope=message)
							 | 
						||
| 
								 | 
							
								                    header_value = "{}={}; {}".format(
							 | 
						||
| 
								 | 
							
								                        self.session_cookie,
							 | 
						||
| 
								 | 
							
								                        f"null; path={path}; expires=Thu, 01 Jan 1970 00:00:00 GMT;",
							 | 
						||
| 
								 | 
							
								                        self.security_flags,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    headers.append("Set-Cookie", header_value)
							 | 
						||
| 
								 | 
							
								            await send(message)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        await self.app(scope, receive, send_wrapper)
							 |