You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							71 lines
						
					
					
						
							2.5 KiB
						
					
					
				
			
		
		
	
	
							71 lines
						
					
					
						
							2.5 KiB
						
					
					
				import typing
 | 
						|
 | 
						|
import anyio
 | 
						|
 | 
						|
from starlette.requests import Request
 | 
						|
from starlette.responses import Response, StreamingResponse
 | 
						|
from starlette.types import ASGIApp, Receive, Scope, Send
 | 
						|
 | 
						|
# Signature of the ``call_next`` callable handed to a dispatch function:
# given a request, it asynchronously produces a response.
RequestResponseEndpoint = typing.Callable[
    [Request],
    typing.Awaitable[Response],
]

# Signature of a dispatch function: it receives the incoming request together
# with a ``call_next`` endpoint and asynchronously produces a response.
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint],
    typing.Awaitable[Response],
]


class BaseHTTPMiddleware:
    """ASGI middleware driven by a single ``dispatch`` coroutine.

    Either subclass and override :meth:`dispatch`, or pass a ``dispatch``
    callable to the constructor. The dispatch function receives the incoming
    :class:`Request` and a ``call_next`` coroutine; awaiting
    ``call_next(request)`` runs the wrapped app and returns its output as a
    :class:`StreamingResponse`.

    Scopes whose ``type`` is not ``"http"`` bypass dispatch entirely and are
    forwarded unchanged to the wrapped app.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Wrap ``app``.

        :param app: The downstream ASGI application.
        :param dispatch: Optional dispatch function. When omitted, this
            instance's :meth:`dispatch` method is used instead.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the dispatch function for HTTP scopes; pass others through.

        :raises RuntimeError: If the wrapped app finishes without sending a
            response start message and without raising.
        :raises Exception: Any exception raised by the wrapped app, whether it
            occurs before the response starts or while the body is streaming.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def call_next(request: Request) -> Response:
            # Holds an exception raised by the wrapped app inside the
            # background task, so it can be re-raised in the caller's task.
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def coro() -> None:
                nonlocal app_exc

                # app_exc is assigned *inside* this context, i.e. before the
                # send side is closed, so a reader that sees the stream end
                # can rely on app_exc already being set.
                async with send_stream:
                    try:
                        await self.app(scope, request.receive, send_stream.send)
                    except Exception as exc:
                        app_exc = exc

            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The app closed the stream without starting a response.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                async with recv_stream:
                    async for body_message in recv_stream:
                        assert body_message["type"] == "http.response.body"
                        yield body_message.get("body", b"")

                # The app may fail after the response has already started.
                # Surface that error instead of silently ending the body early.
                if app_exc is not None:
                    raise app_exc

            response = StreamingResponse(
                status_code=message["status"], content=body_stream()
            )
            response.raw_headers = message["headers"]
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            # The response has been fully sent; cancel whatever is still
            # running in the task group (e.g. the app task from call_next).
            task_group.cancel_scope.cancel()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Produce the response for ``request``; override in subclasses.

        Implementations typically ``await call_next(request)`` to run the
        wrapped app, then inspect or modify the returned response.
        """
        raise NotImplementedError()  # pragma: no cover