You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					105 lines
				
				4.0 KiB
			
		
		
			
		
	
	
					105 lines
				
				4.0 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import gzip
							 | 
						||
| 
								 | 
							
								import io
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from starlette.datastructures import Headers, MutableHeaders
							 | 
						||
| 
								 | 
							
								from starlette.types import ASGIApp, Message, Receive, Scope, Send
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class GZipMiddleware:
								    """ASGI middleware that routes gzip-capable HTTP requests through a
								    ``GZipResponder``.

								    Requests whose scope type is not ``"http"``, or whose ``Accept-Encoding``
								    header does not contain ``"gzip"``, are forwarded to the wrapped app
								    unchanged.
								    """

								    def __init__(
								        self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
								    ) -> None:
								        # Both settings are handed to every GZipResponder created in __call__.
								        self.app = app
								        self.minimum_size = minimum_size
								        self.compresslevel = compresslevel

								    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
								        # Start from the plain app and only swap in a compressing responder
								        # when the client has advertised gzip support.
								        handler = self.app
								        if scope["type"] == "http":
								            accepted = Headers(scope=scope).get("Accept-Encoding", "")
								            if "gzip" in accepted:
								                handler = GZipResponder(
								                    self.app, self.minimum_size, compresslevel=self.compresslevel
								                )
								        await handler(scope, receive, send)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class GZipResponder:
								    """Single-request ASGI wrapper that gzip-encodes the wrapped app's response.

								    The ``http.response.start`` message is held back until the first body
								    message arrives, because the headers to send depend on how the body is
								    handled:

								    * the response already has a ``Content-Encoding`` header: everything is
								      forwarded untouched, so an encoded body is never compressed twice;
								    * a single body message shorter than ``minimum_size``: forwarded as is;
								    * a single body message of at least ``minimum_size``: compressed in one
								      go, with ``Content-Length`` rewritten to the compressed size;
								    * a streamed body (``more_body`` set): compressed chunk by chunk, with
								      ``Content-Length`` removed.
								    """

								    def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
								        self.app = app
								        self.minimum_size = minimum_size
								        self.send: Send = unattached_send
								        self.initial_message: Message = {}
								        self.started = False
								        # Set from the start message; True means the app encoded the body itself.
								        self.content_encoding_set = False
								        self.gzip_buffer = io.BytesIO()
								        self.gzip_file = gzip.GzipFile(
								            mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
								        )

								    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
								        self.send = send
								        # Context managers exit in reverse order: the GzipFile is closed first
								        # (it may still write to the buffer), then the buffer itself. This
								        # releases both even when the response was never compressed or the
								        # app raised.
								        with self.gzip_buffer, self.gzip_file:
								            await self.app(scope, receive, self.send_with_gzip)

								    def _drain_buffer(self) -> bytes:
								        """Return the compressed bytes written so far and empty the buffer."""
								        chunk = self.gzip_buffer.getvalue()
								        self.gzip_buffer.seek(0)
								        self.gzip_buffer.truncate()
								        return chunk

								    async def send_with_gzip(self, message: Message) -> None:
								        """Replacement ``send`` callable given to the wrapped app."""
								        message_type = message["type"]
								        if message_type == "http.response.start":
								            # Don't send the initial message until we've determined how to
								            # modify the outgoing headers correctly.
								            self.initial_message = message
								            # ASGI allows "headers" to be omitted; default it so the header
								            # lookups below cannot raise KeyError.
								            raw_headers = message.setdefault("headers", [])
								            self.content_encoding_set = "content-encoding" in Headers(raw=raw_headers)
								        elif message_type == "http.response.body" and self.content_encoding_set:
								            # The app already encoded this response; pass it through untouched
								            # rather than compressing it again and overwriting its header.
								            if not self.started:
								                self.started = True
								                await self.send(self.initial_message)
								            await self.send(message)
								        elif message_type == "http.response.body" and not self.started:
								            self.started = True
								            body = message.get("body", b"")
								            more_body = message.get("more_body", False)
								            if len(body) < self.minimum_size and not more_body:
								                # Don't apply GZip to small outgoing responses.
								                await self.send(self.initial_message)
								                await self.send(message)
								            elif not more_body:
								                # Standard GZip response.
								                self.gzip_file.write(body)
								                self.gzip_file.close()
								                body = self.gzip_buffer.getvalue()

								                headers = MutableHeaders(raw=self.initial_message["headers"])
								                headers["Content-Encoding"] = "gzip"
								                headers["Content-Length"] = str(len(body))
								                headers.add_vary_header("Accept-Encoding")
								                message["body"] = body

								                await self.send(self.initial_message)
								                await self.send(message)
								            else:
								                # Initial body in streaming GZip response.
								                headers = MutableHeaders(raw=self.initial_message["headers"])
								                headers["Content-Encoding"] = "gzip"
								                headers.add_vary_header("Accept-Encoding")
								                # The compressed length is unknown while streaming.
								                del headers["Content-Length"]

								                self.gzip_file.write(body)
								                message["body"] = self._drain_buffer()

								                await self.send(self.initial_message)
								                await self.send(message)

								        elif message_type == "http.response.body":
								            # Remaining body in streaming GZip response.
								            body = message.get("body", b"")
								            more_body = message.get("more_body", False)

								            self.gzip_file.write(body)
								            if not more_body:
								                # Closing flushes the remaining compressed data into the buffer.
								                self.gzip_file.close()

								            message["body"] = self._drain_buffer()

								            await self.send(message)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								async def unattached_send(message: Message) -> None:  # pragma: no cover
								    """Placeholder ``send`` callable that always raises ``RuntimeError``.

								    ``GZipResponder.__init__`` stores it as ``self.send`` until ``__call__``
								    replaces it with the real callable.
								    """
								    reason = "send awaitable not set"
								    raise RuntimeError(reason)
							 |