You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
71 lines
2.5 KiB
71 lines
2.5 KiB
|
3 years ago
|
import typing
|
||
|
|
|
||
|
|
import anyio
|
||
|
|
|
||
|
|
from starlette.requests import Request
|
||
|
|
from starlette.responses import Response, StreamingResponse
|
||
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||
|
|
|
||
|
|
# Type of the ``call_next`` callable handed to a dispatch function: it accepts
# the incoming Request and returns an awaitable that resolves to a Response.
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

# Type of a dispatch function: it is called with the Request and ``call_next``
# and returns an awaitable that resolves to the Response to send.
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint],
    typing.Awaitable[Response],
]
|
||
|
|
|
||
|
|
|
||
|
|
class BaseHTTPMiddleware:
    """Base class for middleware that intercepts HTTP requests.

    Each HTTP request is wrapped in a :class:`Request` and passed to a
    dispatch callable together with ``call_next``. ``call_next`` runs the
    wrapped ASGI app and returns its output as a :class:`StreamingResponse`.

    Supply the dispatch callable in one of two ways:

    - override :meth:`dispatch` in a subclass, or
    - pass ``dispatch=`` to the constructor.

    Scopes whose ``type`` is not ``"http"`` are forwarded to the wrapped app
    untouched.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Store the wrapped app and select the dispatch callable.

        Args:
            app: The downstream ASGI application.
            dispatch: Optional dispatch callable. When ``None``, this
                instance's :meth:`dispatch` method is used instead.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        Non-HTTP scopes bypass the middleware entirely. For HTTP scopes, the
        response returned by the dispatch callable is sent to the client.
        After that, anything still running in this call's task group is
        cancelled.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def call_next(request: Request) -> Response:
            """Run the downstream app and return its response.

            The app runs as a task in the enclosing task group. The ASGI
            messages it sends are relayed through an in-memory object stream.
            The first message must be ``http.response.start``. The returned
            response streams the bodies of the following
            ``http.response.body`` messages.

            Raises:
                Exception: Whatever the downstream app raised. If it failed
                    before sending any message, this is raised here. If it
                    failed after the response started, it is raised while
                    the body is being streamed.
                RuntimeError: If the app finished without sending a message.
            """
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def coro() -> None:
                nonlocal app_exc

                # Leaving this context closes send_stream, which ends the
                # reader's iteration (EndOfStream). Because app_exc is
                # assigned inside the context, it is always set before the
                # reader observes the end of the stream.
                async with send_stream:
                    try:
                        await self.app(scope, request.receive, send_stream.send)
                    except Exception as exc:
                        # Recorded rather than raised here so the reader
                        # side can re-raise it in the caller's context.
                        app_exc = exc

            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
            except anyio.EndOfStream:
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        yield message.get("body", b"")

                # The stream only ends once the app has finished. If the app
                # failed after the response had started, surface that error
                # instead of silently ending the body early.
                if app_exc is not None:
                    raise app_exc

            response = StreamingResponse(
                status_code=message["status"], content=body_stream()
            )
            response.raw_headers = message["headers"]
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            # The response has been sent. Cancel any task still running in
            # the group, e.g. a downstream app that has not returned yet.
            task_group.cancel_scope.cancel()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Handle *request* and return the response to send.

        Subclasses must override this unless a ``dispatch`` callable was
        passed to the constructor. Use ``await call_next(request)`` to obtain
        the downstream app's response.
        """
        raise NotImplementedError()  # pragma: no cover
|