You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
3.0 KiB

import json
import typing
from base64 import b64decode, b64encode
import itsdangerous
from itsdangerous.exc import BadSignature
from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
class SessionMiddleware:
    """ASGI middleware exposing a signed, cookie-backed session as ``scope["session"]``.

    On the way in, the session cookie (if present and validly signed) is
    decoded into a dict stored at ``scope["session"]``. On the way out, when
    the response starts, the current session dict is written back as a
    ``Set-Cookie`` header, or the cookie is expired if the session was emptied.

    The payload is JSON, base64-encoded and signed with a timestamp signer:
    it is protected against tampering but is not encrypted.
    """

    def __init__(
        self,
        app: ASGIApp,
        secret_key: typing.Union[str, Secret],
        session_cookie: str = "session",
        max_age: int = 14 * 24 * 60 * 60,  # 14 days, in seconds
        same_site: str = "lax",
        https_only: bool = False,
    ) -> None:
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            secret_key: Key used to sign the cookie; converted with ``str()``.
            session_cookie: Name of the cookie that carries the session.
            max_age: Lifetime in seconds, used both for the cookie's
                ``Max-Age`` attribute and as the signature age limit.
            same_site: Value placed after ``samesite=`` in the cookie flags.
            https_only: When true, the ``secure`` flag is added to the cookie.
        """
        self.app = app
        self.signer = itsdangerous.TimestampSigner(str(secret_key))
        self.session_cookie = session_cookie
        self.max_age = max_age
        self.security_flags = "httponly; samesite=" + same_site
        if https_only:  # Secure flag can be used with HTTPS only
            self.security_flags += "; secure"

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Load the session into ``scope``, run the app, and persist the session.

        Scopes other than ``http`` and ``websocket`` are passed straight
        through to the wrapped application.
        """
        if scope["type"] not in ("http", "websocket"):  # pragma: no cover
            await self.app(scope, receive, send)
            return

        connection = HTTPConnection(scope)
        # Tracks whether the client arrived with a valid session, so that an
        # emptied session can be turned into an explicit cookie expiry below.
        initial_session_was_empty = True

        if self.session_cookie in connection.cookies:
            data = connection.cookies[self.session_cookie].encode("utf-8")
            try:
                data = self.signer.unsign(data, max_age=self.max_age)
                scope["session"] = json.loads(b64decode(data))
                initial_session_was_empty = False
            except BadSignature:
                # Invalid signature: start from an empty session. Expired
                # signatures are presumably reported through a BadSignature
                # subclass as well -- confirm against the itsdangerous docs.
                scope["session"] = {}
        else:
            scope["session"] = {}

        async def send_wrapper(message: Message) -> None:
            """Attach the session ``Set-Cookie`` header to the response start."""
            if message["type"] == "http.response.start":
                # Scope the cookie to the application's mount point, falling
                # back to "/" when root_path is missing or empty.
                path = scope.get("root_path", "") or "/"
                if scope["session"]:
                    # We have session data to persist.
                    data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
                    data = self.signer.sign(data)
                    headers = MutableHeaders(scope=message)
                    header_value = "%s=%s; path=%s; Max-Age=%d; %s" % (
                        self.session_cookie,
                        data.decode("utf-8"),
                        path,
                        self.max_age,
                        self.security_flags,
                    )
                    headers.append("Set-Cookie", header_value)
                elif not initial_session_was_empty:
                    # The session has been cleared: expire the cookie.
                    # Attributes are joined with exactly one "; " each; the
                    # previous template emitted "GMT;; httponly", i.e. an
                    # empty attribute between two semicolons.
                    headers = MutableHeaders(scope=message)
                    header_value = (
                        f"{self.session_cookie}=null; path={path}; "
                        "expires=Thu, 01 Jan 1970 00:00:00 GMT; "
                        f"{self.security_flags}"
                    )
                    headers.append("Set-Cookie", header_value)
            await send(message)

        await self.app(scope, receive, send_wrapper)