You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
2.1 KiB

import typing
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
# Assertion message used when an allowed-host pattern has a malformed wildcard.
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."


class TrustedHostMiddleware:
    """ASGI middleware that only forwards requests carrying an allowed ``Host``.

    Each entry of ``allowed_hosts`` is either an exact host name
    (``"example.com"``), a leading-wildcard pattern (``"*.example.com"``),
    or the bare ``"*"`` which disables the check entirely. IPv6 literals
    are matched in their bracketed form (``"[::1]"``). Matching ignores
    letter case and any ``:port`` suffix on the header value.

    Requests with a host that is not allowed receive a plain-text 400
    response, unless ``"www." + host`` is an allowed entry and
    ``www_redirect`` is true, in which case they are redirected to the
    ``www.`` variant of the requested URL.
    """

    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
        www_redirect: bool = True,
    ) -> None:
        """Validate and store the host patterns.

        Args:
            app: The wrapped ASGI application.
            allowed_hosts: Host patterns to accept; ``None`` means ``["*"]``.
            www_redirect: Whether a host that is only allowed with a
                ``www.`` prefix should be redirected instead of rejected.

        Raises:
            AssertionError: If a pattern contains ``*`` anywhere but as a
                leading ``*.`` prefix (or as the bare ``"*"``).
        """
        # Materialise once so the same items are validated and stored, even
        # if the caller handed in a one-shot iterable.
        hosts = ["*"] if allowed_hosts is None else list(allowed_hosts)

        # NOTE: validation deliberately stays assert-based so that callers
        # keep seeing AssertionError; be aware asserts are skipped under -O.
        for pattern in hosts:
            assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
            if pattern.startswith("*") and pattern != "*":
                assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD

        self.app = app
        self.allowed_hosts = hosts
        self.allow_any = "*" in hosts
        self.www_redirect = www_redirect

    @staticmethod
    def _hostname_from_header(host_header: str) -> str:
        """Return the lowercased host part of a ``Host`` value, without port.

        A bracketed IPv6 literal such as ``"[::1]:8000"`` is returned as
        ``"[::1]"``; splitting it on the first ``":"`` would yield ``"["``.
        """
        if host_header.startswith("["):
            closing = host_header.find("]")
            if closing != -1:
                return host_header[: closing + 1].lower()
        return host_header.split(":")[0].lower()

    def _match_host(self, host: str) -> typing.Tuple[bool, bool]:
        """Check an already-lowercased ``host`` against ``self.allowed_hosts``.

        Returns:
            ``(is_valid_host, found_www_redirect)``. The second flag is true
            when ``"www." + host`` equals a pattern seen before any match.
        """
        found_www_redirect = False
        for raw_pattern in self.allowed_hosts:
            # Host names are case-insensitive, so compare in lowercase.
            pattern = raw_pattern.lower()
            if host == pattern or (
                pattern.startswith("*") and host.endswith(pattern[1:])
            ):
                return True, found_www_redirect
            if "www." + host == pattern:
                found_www_redirect = True
        return False, found_www_redirect

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward, redirect, or reject the connection based on its host.

        Scopes whose type is neither ``"http"`` nor ``"websocket"``, and all
        scopes when ``"*"`` is allowed, are passed straight to the wrapped app.
        """
        if self.allow_any or scope["type"] not in (
            "http",
            "websocket",
        ):  # pragma: no cover
            await self.app(scope, receive, send)
            return

        headers = Headers(scope=scope)
        host = self._hostname_from_header(headers.get("host", ""))
        is_valid_host, found_www_redirect = self._match_host(host)

        if is_valid_host:
            await self.app(scope, receive, send)
            return

        response: Response
        if found_www_redirect and self.www_redirect:
            # Only the "www."-prefixed host is allowed: send the client there.
            url = URL(scope=scope)
            redirect_url = url.replace(netloc="www." + url.netloc)
            response = RedirectResponse(url=str(redirect_url))
        else:
            response = PlainTextResponse("Invalid host header", status_code=400)
        await response(scope, receive, send)