You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					136 lines
				
				4.4 KiB
			
		
		
			
		
	
	
					136 lines
				
				4.4 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import inspect
							 | 
						||
| 
								 | 
							
								import typing
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from starlette.requests import Request
							 | 
						||
| 
								 | 
							
								from starlette.responses import Response
							 | 
						||
| 
								 | 
							
								from starlette.routing import BaseRoute, Mount, Route
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								try:
							 | 
						||
| 
								 | 
							
								    import yaml
							 | 
						||
| 
								 | 
							
								except ImportError:  # pragma: nocover
							 | 
						||
| 
								 | 
							
								    yaml = None  # type: ignore
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class OpenAPIResponse(Response):
							 | 
						||
| 
								 | 
							
								    media_type = "application/vnd.oai.openapi"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def render(self, content: typing.Any) -> bytes:
							 | 
						||
| 
								 | 
							
								        assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
							 | 
						||
| 
								 | 
							
								        assert isinstance(
							 | 
						||
| 
								 | 
							
								            content, dict
							 | 
						||
| 
								 | 
							
								        ), "The schema passed to OpenAPIResponse should be a dictionary."
							 | 
						||
| 
								 | 
							
								        return yaml.dump(content, default_flow_style=False).encode("utf-8")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class EndpointInfo(typing.NamedTuple):
							 | 
						||
| 
								 | 
							
								    path: str
							 | 
						||
| 
								 | 
							
								    http_method: str
							 | 
						||
| 
								 | 
							
								    func: typing.Callable
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class BaseSchemaGenerator:
							 | 
						||
| 
								 | 
							
								    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
							 | 
						||
| 
								 | 
							
								        raise NotImplementedError()  # pragma: no cover
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_endpoints(
							 | 
						||
| 
								 | 
							
								        self, routes: typing.List[BaseRoute]
							 | 
						||
| 
								 | 
							
								    ) -> typing.List[EndpointInfo]:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Given the routes, yields the following information:
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        - path
							 | 
						||
| 
								 | 
							
								            eg: /users/
							 | 
						||
| 
								 | 
							
								        - http_method
							 | 
						||
| 
								 | 
							
								            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
							 | 
						||
| 
								 | 
							
								        - func
							 | 
						||
| 
								 | 
							
								            method ready to extract the docstring
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        endpoints_info: list = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for route in routes:
							 | 
						||
| 
								 | 
							
								            if isinstance(route, Mount):
							 | 
						||
| 
								 | 
							
								                routes = route.routes or []
							 | 
						||
| 
								 | 
							
								                sub_endpoints = [
							 | 
						||
| 
								 | 
							
								                    EndpointInfo(
							 | 
						||
| 
								 | 
							
								                        path="".join((route.path, sub_endpoint.path)),
							 | 
						||
| 
								 | 
							
								                        http_method=sub_endpoint.http_method,
							 | 
						||
| 
								 | 
							
								                        func=sub_endpoint.func,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    for sub_endpoint in self.get_endpoints(routes)
							 | 
						||
| 
								 | 
							
								                ]
							 | 
						||
| 
								 | 
							
								                endpoints_info.extend(sub_endpoints)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            elif not isinstance(route, Route) or not route.include_in_schema:
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
							 | 
						||
| 
								 | 
							
								                for method in route.methods or ["GET"]:
							 | 
						||
| 
								 | 
							
								                    if method == "HEAD":
							 | 
						||
| 
								 | 
							
								                        continue
							 | 
						||
| 
								 | 
							
								                    endpoints_info.append(
							 | 
						||
| 
								 | 
							
								                        EndpointInfo(route.path, method.lower(), route.endpoint)
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                for method in ["get", "post", "put", "patch", "delete", "options"]:
							 | 
						||
| 
								 | 
							
								                    if not hasattr(route.endpoint, method):
							 | 
						||
| 
								 | 
							
								                        continue
							 | 
						||
| 
								 | 
							
								                    func = getattr(route.endpoint, method)
							 | 
						||
| 
								 | 
							
								                    endpoints_info.append(
							 | 
						||
| 
								 | 
							
								                        EndpointInfo(route.path, method.lower(), func)
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return endpoints_info
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def parse_docstring(self, func_or_method: typing.Callable) -> dict:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Given a function, parse the docstring as YAML and return a dictionary of info.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        docstring = func_or_method.__doc__
							 | 
						||
| 
								 | 
							
								        if not docstring:
							 | 
						||
| 
								 | 
							
								            return {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # We support having regular docstrings before the schema
							 | 
						||
| 
								 | 
							
								        # definition. Here we return just the schema part from
							 | 
						||
| 
								 | 
							
								        # the docstring.
							 | 
						||
| 
								 | 
							
								        docstring = docstring.split("---")[-1]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        parsed = yaml.safe_load(docstring)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not isinstance(parsed, dict):
							 | 
						||
| 
								 | 
							
								            # A regular docstring (not yaml formatted) can return
							 | 
						||
| 
								 | 
							
								            # a simple string here, which wouldn't follow the schema.
							 | 
						||
| 
								 | 
							
								            return {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return parsed
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def OpenAPIResponse(self, request: Request) -> Response:
							 | 
						||
| 
								 | 
							
								        routes = request.app.routes
							 | 
						||
| 
								 | 
							
								        schema = self.get_schema(routes=routes)
							 | 
						||
| 
								 | 
							
								        return OpenAPIResponse(schema)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SchemaGenerator(BaseSchemaGenerator):
							 | 
						||
| 
								 | 
							
								    def __init__(self, base_schema: dict) -> None:
							 | 
						||
| 
								 | 
							
								        self.base_schema = base_schema
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
							 | 
						||
| 
								 | 
							
								        schema = dict(self.base_schema)
							 | 
						||
| 
								 | 
							
								        schema.setdefault("paths", {})
							 | 
						||
| 
								 | 
							
								        endpoints_info = self.get_endpoints(routes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for endpoint in endpoints_info:
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            parsed = self.parse_docstring(endpoint.func)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if not parsed:
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if endpoint.path not in schema["paths"]:
							 | 
						||
| 
								 | 
							
								                schema["paths"][endpoint.path] = {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            schema["paths"][endpoint.path][endpoint.http_method] = parsed
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return schema
							 |