You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					61 lines
				
				2.1 KiB
			
		
		
			
		
	
	
					61 lines
				
				2.1 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import typing
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from starlette.datastructures import URL, Headers
							 | 
						||
| 
								 | 
							
								from starlette.responses import PlainTextResponse, RedirectResponse, Response
							 | 
						||
| 
								 | 
							
								from starlette.types import ASGIApp, Receive, Scope, Send
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
# Message reported when an allowed_hosts pattern places "*" anywhere other than
# as a bare "*" or a leading "*." prefix (e.g. "*.example.com").
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class TrustedHostMiddleware:
							 | 
						||
| 
								 | 
							
								    def __init__(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        app: ASGIApp,
							 | 
						||
| 
								 | 
							
								        allowed_hosts: typing.Sequence[str] = None,
							 | 
						||
| 
								 | 
							
								        www_redirect: bool = True,
							 | 
						||
| 
								 | 
							
								    ) -> None:
							 | 
						||
| 
								 | 
							
								        if allowed_hosts is None:
							 | 
						||
| 
								 | 
							
								            allowed_hosts = ["*"]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for pattern in allowed_hosts:
							 | 
						||
| 
								 | 
							
								            assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
							 | 
						||
| 
								 | 
							
								            if pattern.startswith("*") and pattern != "*":
							 | 
						||
| 
								 | 
							
								                assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
							 | 
						||
| 
								 | 
							
								        self.app = app
							 | 
						||
| 
								 | 
							
								        self.allowed_hosts = list(allowed_hosts)
							 | 
						||
| 
								 | 
							
								        self.allow_any = "*" in allowed_hosts
							 | 
						||
| 
								 | 
							
								        self.www_redirect = www_redirect
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
							 | 
						||
| 
								 | 
							
								        if self.allow_any or scope["type"] not in (
							 | 
						||
| 
								 | 
							
								            "http",
							 | 
						||
| 
								 | 
							
								            "websocket",
							 | 
						||
| 
								 | 
							
								        ):  # pragma: no cover
							 | 
						||
| 
								 | 
							
								            await self.app(scope, receive, send)
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        headers = Headers(scope=scope)
							 | 
						||
| 
								 | 
							
								        host = headers.get("host", "").split(":")[0]
							 | 
						||
| 
								 | 
							
								        is_valid_host = False
							 | 
						||
| 
								 | 
							
								        found_www_redirect = False
							 | 
						||
| 
								 | 
							
								        for pattern in self.allowed_hosts:
							 | 
						||
| 
								 | 
							
								            if host == pattern or (
							 | 
						||
| 
								 | 
							
								                pattern.startswith("*") and host.endswith(pattern[1:])
							 | 
						||
| 
								 | 
							
								            ):
							 | 
						||
| 
								 | 
							
								                is_valid_host = True
							 | 
						||
| 
								 | 
							
								                break
							 | 
						||
| 
								 | 
							
								            elif "www." + host == pattern:
							 | 
						||
| 
								 | 
							
								                found_www_redirect = True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if is_valid_host:
							 | 
						||
| 
								 | 
							
								            await self.app(scope, receive, send)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            response: Response
							 | 
						||
| 
								 | 
							
								            if found_www_redirect and self.www_redirect:
							 | 
						||
| 
								 | 
							
								                url = URL(scope=scope)
							 | 
						||
| 
								 | 
							
								                redirect_url = url.replace(netloc="www." + url.netloc)
							 | 
						||
| 
								 | 
							
								                response = RedirectResponse(url=str(redirect_url))
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                response = PlainTextResponse("Invalid host header", status_code=400)
							 | 
						||
| 
								 | 
							
								            await response(scope, receive, send)
							 |