You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					71 lines
				
				2.5 KiB
			
		
		
			
		
	
	
					71 lines
				
				2.5 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import typing
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import anyio
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from starlette.requests import Request
							 | 
						||
| 
								 | 
							
								from starlette.responses import Response, StreamingResponse
							 | 
						||
| 
								 | 
							
								from starlette.types import ASGIApp, Receive, Scope, Send
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Type of the ``call_next`` callable handed to a dispatch function:
								# it takes a Request and returns an awaitable Response.
								RequestResponseEndpoint = typing.Callable[
								    [Request],
								    typing.Awaitable[Response],
								]
								# Type of a dispatch function: it takes the Request plus ``call_next``
								# and returns an awaitable Response.
								DispatchFunction = typing.Callable[
								    [Request, RequestResponseEndpoint],
								    typing.Awaitable[Response],
								]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class BaseHTTPMiddleware:
								    """ASGI middleware whose HTTP handling is written as one ``dispatch`` coroutine.

								    Either subclass and override :meth:`dispatch`, or pass a ``dispatch``
								    callable to the constructor. Scopes whose type is not ``"http"`` are
								    passed straight through to the wrapped app.
								    """

								    def __init__(
								        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
								    ) -> None:
								        """Wrap *app*.

								        Args:
								            app: The downstream ASGI application.
								            dispatch: Optional function to use instead of :meth:`dispatch`.
								        """
								        self.app = app
								        self.dispatch_func = self.dispatch if dispatch is None else dispatch

								    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
								        """ASGI entry point: run the dispatch function for HTTP scopes."""
								        if scope["type"] != "http":
								            await self.app(scope, receive, send)
								            return

								        async def call_next(request: Request) -> Response:
								            """Run the downstream app in the background and wrap its output.

								            The returned StreamingResponse takes its status and headers from
								            the app's ``http.response.start`` message and streams the bodies of
								            the following ``http.response.body`` messages.

								            Raises:
								                Exception: Whatever the downstream app raised, if it raised
								                    before sending ``http.response.start``.
								                RuntimeError: If the app finished without starting a response.
								            """
								            app_exc: typing.Optional[Exception] = None
								            send_stream, recv_stream = anyio.create_memory_object_stream()

								            async def coro() -> None:
								                nonlocal app_exc

								                # Leaving this block closes send_stream, which is how the
								                # reader side observes end-of-stream.
								                async with send_stream:
								                    try:
								                        await self.app(scope, request.receive, send_stream.send)
								                    except Exception as exc:
								                        # Stored rather than raised here, so that call_next or
								                        # body_stream can re-raise it to the caller. It is set
								                        # before send_stream closes.
								                        app_exc = exc

								            # task_group is bound by the ``async with`` at the end of __call__;
								            # call_next is only ever invoked inside that block.
								            task_group.start_soon(coro)

								            try:
								                message = await recv_stream.receive()
								            except anyio.EndOfStream:
								                if app_exc is not None:
								                    raise app_exc
								                raise RuntimeError("No response returned.")

								            assert message["type"] == "http.response.start"

								            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
								                async with recv_stream:
								                    async for body_message in recv_stream:
								                        assert body_message["type"] == "http.response.body"
								                        yield body_message.get("body", b"")

								                # The app may fail after it has already started the response.
								                # Re-raise that error instead of silently ending the body.
								                if app_exc is not None:
								                    raise app_exc

								            response = StreamingResponse(
								                status_code=message["status"], content=body_stream()
								            )
								            response.raw_headers = message["headers"]
								            return response

								        async with anyio.create_task_group() as task_group:
								            request = Request(scope, receive=receive)
								            response = await self.dispatch_func(request, call_next)
								            await response(scope, receive, send)
								            # The response has been sent; cancel whatever downstream task is
								            # still running in this group.
								            task_group.cancel_scope.cancel()

								    async def dispatch(
								        self, request: Request, call_next: RequestResponseEndpoint
								    ) -> Response:
								        """Handle *request* and return a response; override in subclasses.

								        Args:
								            request: The incoming HTTP request.
								            call_next: Coroutine function that forwards a request to the
								                wrapped app and returns its response.

								        Raises:
								            NotImplementedError: Always, unless overridden.
								        """
								        raise NotImplementedError()  # pragma: no cover
							 |