You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

98 lines
3.0 KiB

import typing
from os import PathLike
from starlette.background import BackgroundTask
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
try:
    import jinja2

    # Jinja 3.0 renamed @contextfunction to @pass_context, and the old name is
    # slated for removal in 3.1, so prefer the new decorator when it exists.
    pass_context = (
        jinja2.pass_context
        if hasattr(jinja2, "pass_context")
        else jinja2.contextfunction  # pragma: nocover
    )
except ImportError:  # pragma: nocover
    # jinja2 is optional; Jinja2Templates checks for None before using it.
    jinja2 = None  # type: ignore
class _TemplateResponse(Response):
    """
    HTML response whose body is the result of rendering a template.

    The template object and the context it was rendered with are kept on the
    instance so that ``__call__`` can forward them over the ASGI ``send``
    channel when the request lists the "http.response.template" extension.
    """

    media_type = "text/html"

    def __init__(
        self,
        template: typing.Any,
        context: dict,
        status_code: int = 200,
        headers: dict = None,
        media_type: str = None,
        background: BackgroundTask = None,
    ):
        """Render ``template`` with ``context`` and use the output as the body."""
        self.template = template
        self.context = context
        rendered = template.render(context)
        super().__init__(rendered, status_code, headers, media_type, background)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Emit the template message when requested, then send the response."""
        request = self.context.get("request", {})
        # Only send the extra message when the request's "extensions" mapping
        # contains the "http.response.template" key.
        if "http.response.template" in request.get("extensions", {}):
            template_message = {
                "type": "http.response.template",
                "template": self.template,
                "context": self.context,
            }
            await send(template_message)
        await super().__call__(scope, receive, send)
class Jinja2Templates:
    """
    Load Jinja2 templates from a directory and render them as responses.

    Usage:

        templates = Jinja2Templates("templates")

        return templates.TemplateResponse("index.html", {"request": request})

    Extra keyword arguments given to the constructor are forwarded to
    ``jinja2.Environment``. When not supplied, ``loader`` defaults to a
    ``FileSystemLoader`` over ``directory`` and ``autoescape`` defaults to
    ``True``.
    """

    def __init__(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> None:
        """
        Build the Jinja2 environment for ``directory`` and store it as ``env``.

        Fails the assertion below when the optional ``jinja2`` import at
        module level did not succeed.
        """
        assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
        self.env = self._create_env(directory, **env_options)

    def _create_env(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> "jinja2.Environment":
        """
        Create the environment and register the ``url_for`` template global.

        ``env_options`` are passed straight to ``jinja2.Environment``; the
        ``loader`` and ``autoescape`` entries are only filled in when the
        caller did not provide them.
        """

        @pass_context
        def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
            # Delegates to the ``url_for`` method of the "request" object
            # stored in the template context.
            request = context["request"]
            return request.url_for(name, **path_params)

        # setdefault keeps any caller-supplied value and otherwise falls back
        # to the previously hard-coded configuration.
        env_options.setdefault("loader", jinja2.FileSystemLoader(directory))
        env_options.setdefault("autoescape", True)

        env = jinja2.Environment(**env_options)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> "jinja2.Template":
        """Return the template called ``name`` from the environment."""
        return self.env.get_template(name)

    def TemplateResponse(
        self,
        name: str,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[dict] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> _TemplateResponse:
        """
        Look up template ``name`` and wrap it in a ``_TemplateResponse``.

        Raises:
            ValueError: if ``context`` has no "request" key.
        """
        if "request" not in context:
            raise ValueError('context must include a "request" key')
        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )