You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					881 lines
				
				34 KiB
			
		
		
			
		
	
	
					881 lines
				
				34 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								"""
							 | 
						||
| 
								 | 
							
								The starlette extension to rate-limit requests
							 | 
						||
| 
								 | 
							
								"""
							 | 
						||
| 
								 | 
							
								import asyncio
							 | 
						||
| 
								 | 
							
								import functools
							 | 
						||
| 
								 | 
							
								import inspect
							 | 
						||
| 
								 | 
							
								import itertools
							 | 
						||
| 
								 | 
							
								import logging
							 | 
						||
| 
								 | 
							
								import time
							 | 
						||
| 
								 | 
							
								from datetime import datetime
							 | 
						||
| 
								 | 
							
								from email.utils import formatdate, parsedate_to_datetime
							 | 
						||
| 
								 | 
							
								from functools import wraps
							 | 
						||
| 
								 | 
							
								from typing import (
							 | 
						||
| 
								 | 
							
								    Any,
							 | 
						||
| 
								 | 
							
								    Callable,
							 | 
						||
| 
								 | 
							
								    Dict,
							 | 
						||
| 
								 | 
							
								    List,
							 | 
						||
| 
								 | 
							
								    Optional,
							 | 
						||
| 
								 | 
							
								    Set,
							 | 
						||
| 
								 | 
							
								    Tuple,
							 | 
						||
| 
								 | 
							
								    TypeVar,
							 | 
						||
| 
								 | 
							
								    Union,
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from limits import RateLimitItem  # type: ignore
							 | 
						||
| 
								 | 
							
								from limits.aio.storage import Storage as AsyncStorage  # type: ignore
							 | 
						||
| 
								 | 
							
								from limits.errors import ConfigurationError  # type: ignore
							 | 
						||
| 
								 | 
							
								from limits.storage import MemoryStorage, storage_from_string  # type: ignore
							 | 
						||
| 
								 | 
							
								from limits.storage import Storage  # type: ignore
							 | 
						||
| 
								 | 
							
								from limits.strategies import STRATEGIES, RateLimiter  # type: ignore
							 | 
						||
| 
								 | 
							
								from starlette.config import Config
							 | 
						||
| 
								 | 
							
								from starlette.datastructures import MutableHeaders
							 | 
						||
| 
								 | 
							
								from starlette.requests import Request
							 | 
						||
| 
								 | 
							
								from starlette.responses import JSONResponse, Response
							 | 
						||
| 
								 | 
							
								from typing_extensions import Literal
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from .errors import RateLimitExceeded
							 | 
						||
| 
								 | 
							
								from .wrappers import Limit, LimitGroup
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# used to annotate get_app_config method
							 | 
						||
| 
								 | 
							
								T = TypeVar("T")
							 | 
						||
| 
								 | 
							
								# Define an alias for the most commonly used type
							 | 
						||
| 
								 | 
							
								StrOrCallableStr = Union[str, Callable[..., str]]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class C:
							 | 
						||
| 
								 | 
							
								    ENABLED = "RATELIMIT_ENABLED"
							 | 
						||
| 
								 | 
							
								    HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
							 | 
						||
| 
								 | 
							
								    STORAGE_URL = "RATELIMIT_STORAGE_URL"
							 | 
						||
| 
								 | 
							
								    STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
							 | 
						||
| 
								 | 
							
								    STRATEGY = "RATELIMIT_STRATEGY"
							 | 
						||
| 
								 | 
							
								    GLOBAL_LIMITS = "RATELIMIT_GLOBAL"
							 | 
						||
| 
								 | 
							
								    DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
							 | 
						||
| 
								 | 
							
								    APPLICATION_LIMITS = "RATELIMIT_APPLICATION"
							 | 
						||
| 
								 | 
							
								    HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
							 | 
						||
| 
								 | 
							
								    HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
							 | 
						||
| 
								 | 
							
								    HEADER_RESET = "RATELIMIT_HEADER_RESET"
							 | 
						||
| 
								 | 
							
								    SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"
							 | 
						||
| 
								 | 
							
								    IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"
							 | 
						||
| 
								 | 
							
								    IN_MEMORY_FALLBACK_ENABLED = "RATELIMIT_IN_MEMORY_FALLBACK_ENABLED"
							 | 
						||
| 
								 | 
							
								    HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
							 | 
						||
| 
								 | 
							
								    HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
							 | 
						||
| 
								 | 
							
								    KEY_PREFIX = "RATELIMIT_KEY_PREFIX"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class HEADERS:
							 | 
						||
| 
								 | 
							
								    RESET = 1
							 | 
						||
| 
								 | 
							
								    REMAINING = 2
							 | 
						||
| 
								 | 
							
								    LIMIT = 3
							 | 
						||
| 
								 | 
							
								    RETRY_AFTER = 4
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								MAX_BACKEND_CHECKS = 5
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> Response:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Build a simple JSON response that includes the details of the rate limit
							 | 
						||
| 
								 | 
							
								    that was hit. If no limit is hit, the countdown is added to headers.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    response = JSONResponse(
							 | 
						||
| 
								 | 
							
								        {"error": f"Rate limit exceeded: {exc.detail}"}, status_code=429
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    response = request.app.state.limiter._inject_headers(
							 | 
						||
| 
								 | 
							
								        response, request.state.view_rate_limit
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    return response
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Limiter:
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Initializes the slowapi rate limiter.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    ** parameter **
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **app**: `Starlette/FastAPI` instance to initialize the extension
							 | 
						||
| 
								 | 
							
								     with.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **default_limits**: a variable list of strings or callables returning strings denoting global
							 | 
						||
| 
								 | 
							
								     limits to apply to all routes. `ratelimit-string` for  more details.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **application_limits**: a variable list of strings or callables returning strings for limits that
							 | 
						||
| 
								 | 
							
								     are applied to the entire application (i.e a shared limit for all routes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **key_func**: a callable that returns the domain to rate limit by.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **headers_enabled**: whether ``X-RateLimit`` response headers are written.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **strategy:** the strategy to use. refer to `ratelimit-strategy`
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **storage_uri**: the storage location. refer to `ratelimit-conf`
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    * **storage_options**: kwargs to pass to the storage implementation upon
							 | 
						||
| 
								 | 
							
								      instantiation.
							 | 
						||
| 
								 | 
							
								    * **auto_check**: whether to automatically check the rate limit in the before_request
							 | 
						||
| 
								 | 
							
								     chain of the application. default ``True``
							 | 
						||
| 
								 | 
							
								    * **swallow_errors**: whether to swallow errors when hitting a rate limit.
							 | 
						||
| 
								 | 
							
								     An exception will still be logged. default ``False``
							 | 
						||
| 
								 | 
							
								    * **in_memory_fallback**: a variable list of strings or callables returning strings denoting fallback
							 | 
						||
| 
								 | 
							
								     limits to apply when the storage is down.
							 | 
						||
| 
								 | 
							
								    * **in_memory_fallback_enabled**: simply falls back to in memory storage
							 | 
						||
| 
								 | 
							
								     when the main storage is down and inherits the original limits.
							 | 
						||
| 
								 | 
							
								    * **key_prefix**: prefix prepended to rate limiter keys.
							 | 
						||
| 
								 | 
							
								    * **enabled**: set to False to deactivate the limiter (default: True)
							 | 
						||
| 
								 | 
							
								    * **config_filename**: name of the config file for Starlette from which to load settings
							 | 
						||
| 
								 | 
							
								     for the rate limiter. Defaults to ".env".
							 | 
						||
| 
								 | 
							
								    * **key_style**: set to "url" to use the url, "endpoint" to use the view_func
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        # app: Starlette = None,
							 | 
						||
| 
								 | 
							
								        key_func: Callable[..., str],
							 | 
						||
| 
								 | 
							
								        default_limits: List[StrOrCallableStr] = [],
							 | 
						||
| 
								 | 
							
								        application_limits: List[StrOrCallableStr] = [],
							 | 
						||
| 
								 | 
							
								        headers_enabled: bool = False,
							 | 
						||
| 
								 | 
							
								        strategy: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        storage_uri: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        storage_options: Dict[str, str] = {},
							 | 
						||
| 
								 | 
							
								        auto_check: bool = True,
							 | 
						||
| 
								 | 
							
								        swallow_errors: bool = False,
							 | 
						||
| 
								 | 
							
								        in_memory_fallback: List[StrOrCallableStr] = [],
							 | 
						||
| 
								 | 
							
								        in_memory_fallback_enabled: bool = False,
							 | 
						||
| 
								 | 
							
								        retry_after: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        key_prefix: str = "",
							 | 
						||
| 
								 | 
							
								        enabled: bool = True,
							 | 
						||
| 
								 | 
							
								        config_filename: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        key_style: Literal["endpoint", "url"] = "url",
							 | 
						||
| 
								 | 
							
								    ) -> None:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Configure the rate limiter at app level
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        # assert app is not None, "Passing the app instance to the limiter is required"
							 | 
						||
| 
								 | 
							
								        # self.app = app
							 | 
						||
| 
								 | 
							
								        # app.state.limiter = self
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.logger = logging.getLogger("slowapi")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.app_config = Config(
							 | 
						||
| 
								 | 
							
								            config_filename if config_filename is not None else ".env"
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.enabled = enabled
							 | 
						||
| 
								 | 
							
								        self._default_limits = []
							 | 
						||
| 
								 | 
							
								        self._application_limits = []
							 | 
						||
| 
								 | 
							
								        self._in_memory_fallback: List[LimitGroup] = []
							 | 
						||
| 
								 | 
							
								        self._in_memory_fallback_enabled = (
							 | 
						||
| 
								 | 
							
								            in_memory_fallback_enabled or len(in_memory_fallback) > 0
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self._exempt_routes: Set[str] = set()
							 | 
						||
| 
								 | 
							
								        self._request_filters: List[Callable[..., bool]] = []
							 | 
						||
| 
								 | 
							
								        self._headers_enabled = headers_enabled
							 | 
						||
| 
								 | 
							
								        self._header_mapping: Dict[int, str] = {}
							 | 
						||
| 
								 | 
							
								        self._retry_after: Optional[str] = retry_after
							 | 
						||
| 
								 | 
							
								        self._strategy = strategy
							 | 
						||
| 
								 | 
							
								        self._storage_uri = storage_uri
							 | 
						||
| 
								 | 
							
								        self._storage_options = storage_options
							 | 
						||
| 
								 | 
							
								        self._auto_check = auto_check
							 | 
						||
| 
								 | 
							
								        self._swallow_errors = swallow_errors
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self._key_func = key_func
							 | 
						||
| 
								 | 
							
								        self._key_prefix = key_prefix
							 | 
						||
| 
								 | 
							
								        self._key_style = key_style
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for limit in set(default_limits):
							 | 
						||
| 
								 | 
							
								            self._default_limits.extend(
							 | 
						||
| 
								 | 
							
								                [
							 | 
						||
| 
								 | 
							
								                    LimitGroup(
							 | 
						||
| 
								 | 
							
								                        limit, self._key_func, None, False, None, None, None, 1, False
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                ]
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        for limit in application_limits:
							 | 
						||
| 
								 | 
							
								            self._application_limits.extend(
							 | 
						||
| 
								 | 
							
								                [
							 | 
						||
| 
								 | 
							
								                    LimitGroup(
							 | 
						||
| 
								 | 
							
								                        limit,
							 | 
						||
| 
								 | 
							
								                        self._key_func,
							 | 
						||
| 
								 | 
							
								                        "global",
							 | 
						||
| 
								 | 
							
								                        False,
							 | 
						||
| 
								 | 
							
								                        None,
							 | 
						||
| 
								 | 
							
								                        None,
							 | 
						||
| 
								 | 
							
								                        None,
							 | 
						||
| 
								 | 
							
								                        1,
							 | 
						||
| 
								 | 
							
								                        False,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                ]
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        for limit in in_memory_fallback:
							 | 
						||
| 
								 | 
							
								            self._in_memory_fallback.extend(
							 | 
						||
| 
								 | 
							
								                [
							 | 
						||
| 
								 | 
							
								                    LimitGroup(
							 | 
						||
| 
								 | 
							
								                        limit, self._key_func, None, False, None, None, None, 1, False
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                ]
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        self._route_limits: Dict[str, List[Limit]] = {}
							 | 
						||
| 
								 | 
							
								        self._dynamic_route_limits: Dict[str, List[LimitGroup]] = {}
							 | 
						||
| 
								 | 
							
								        # a flag to note if the storage backend is dead (not available)
							 | 
						||
| 
								 | 
							
								        self._storage_dead: bool = False
							 | 
						||
| 
								 | 
							
								        self._fallback_limiter = None
							 | 
						||
| 
								 | 
							
								        self.__check_backend_count = 0
							 | 
						||
| 
								 | 
							
								        self.__last_check_backend = time.time()
							 | 
						||
| 
								 | 
							
								        self.__marked_for_limiting: Dict[str, List[Callable]] = {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        class BlackHoleHandler(logging.StreamHandler):
							 | 
						||
| 
								 | 
							
								            def emit(*_):
							 | 
						||
| 
								 | 
							
								                return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.logger.addHandler(BlackHoleHandler())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.enabled = self.get_app_config(C.ENABLED, self.enabled)
							 | 
						||
| 
								 | 
							
								        self._swallow_errors = self.get_app_config(
							 | 
						||
| 
								 | 
							
								            C.SWALLOW_ERRORS, self._swallow_errors
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self._headers_enabled = self._headers_enabled or self.get_app_config(
							 | 
						||
| 
								 | 
							
								            C.HEADERS_ENABLED, False
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self._storage_options.update(self.get_app_config(C.STORAGE_OPTIONS, {}))
							 | 
						||
| 
								 | 
							
								        self._storage: Union[Storage, AsyncStorage] = storage_from_string(
							 | 
						||
| 
								 | 
							
								            self._storage_uri or self.get_app_config(C.STORAGE_URL, "memory://"),
							 | 
						||
| 
								 | 
							
								            **self._storage_options,
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        strategy = self._strategy or self.get_app_config(C.STRATEGY, "fixed-window")
							 | 
						||
| 
								 | 
							
								        if strategy not in STRATEGIES:
							 | 
						||
| 
								 | 
							
								            raise ConfigurationError("Invalid rate limiting strategy %s" % strategy)
							 | 
						||
| 
								 | 
							
								        self._limiter: RateLimiter = STRATEGIES[strategy](self._storage)
							 | 
						||
| 
								 | 
							
								        self._header_mapping.update(
							 | 
						||
| 
								 | 
							
								            {
							 | 
						||
| 
								 | 
							
								                HEADERS.RESET: self._header_mapping.get(
							 | 
						||
| 
								 | 
							
								                    HEADERS.RESET,
							 | 
						||
| 
								 | 
							
								                    self.get_app_config(C.HEADER_RESET, "X-RateLimit-Reset"),
							 | 
						||
| 
								 | 
							
								                ),
							 | 
						||
| 
								 | 
							
								                HEADERS.REMAINING: self._header_mapping.get(
							 | 
						||
| 
								 | 
							
								                    HEADERS.REMAINING,
							 | 
						||
| 
								 | 
							
								                    self.get_app_config(C.HEADER_REMAINING, "X-RateLimit-Remaining"),
							 | 
						||
| 
								 | 
							
								                ),
							 | 
						||
| 
								 | 
							
								                HEADERS.LIMIT: self._header_mapping.get(
							 | 
						||
| 
								 | 
							
								                    HEADERS.LIMIT,
							 | 
						||
| 
								 | 
							
								                    self.get_app_config(C.HEADER_LIMIT, "X-RateLimit-Limit"),
							 | 
						||
| 
								 | 
							
								                ),
							 | 
						||
| 
								 | 
							
								                HEADERS.RETRY_AFTER: self._header_mapping.get(
							 | 
						||
| 
								 | 
							
								                    HEADERS.RETRY_AFTER,
							 | 
						||
| 
								 | 
							
								                    self.get_app_config(C.HEADER_RETRY_AFTER, "Retry-After"),
							 | 
						||
| 
								 | 
							
								                ),
							 | 
						||
| 
								 | 
							
								            }
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self._retry_after = self._retry_after or self.get_app_config(
							 | 
						||
| 
								 | 
							
								            C.HEADER_RETRY_AFTER_VALUE
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self._key_prefix = self._key_prefix or self.get_app_config(C.KEY_PREFIX)
							 | 
						||
| 
								 | 
							
								        app_limits: Optional[StrOrCallableStr] = self.get_app_config(
							 | 
						||
| 
								 | 
							
								            C.APPLICATION_LIMITS, None
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        if not self._application_limits and app_limits:
							 | 
						||
| 
								 | 
							
								            self._application_limits = [
							 | 
						||
| 
								 | 
							
								                LimitGroup(
							 | 
						||
| 
								 | 
							
								                    app_limits,
							 | 
						||
| 
								 | 
							
								                    self._key_func,
							 | 
						||
| 
								 | 
							
								                    "global",
							 | 
						||
| 
								 | 
							
								                    False,
							 | 
						||
| 
								 | 
							
								                    None,
							 | 
						||
| 
								 | 
							
								                    None,
							 | 
						||
| 
								 | 
							
								                    None,
							 | 
						||
| 
								 | 
							
								                    1,
							 | 
						||
| 
								 | 
							
								                    False,
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            ]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        conf_limits: Optional[StrOrCallableStr] = self.get_app_config(
							 | 
						||
| 
								 | 
							
								            C.DEFAULT_LIMITS, None
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        if not self._default_limits and conf_limits:
							 | 
						||
| 
								 | 
							
								            self._default_limits = [
							 | 
						||
| 
								 | 
							
								                LimitGroup(
							 | 
						||
| 
								 | 
							
								                    conf_limits, self._key_func, None, False, None, None, None, 1, False
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            ]
							 | 
						||
| 
								 | 
							
								        fallback_enabled = self.get_app_config(C.IN_MEMORY_FALLBACK_ENABLED, False)
							 | 
						||
| 
								 | 
							
								        fallback_limits: Optional[StrOrCallableStr] = self.get_app_config(
							 | 
						||
| 
								 | 
							
								            C.IN_MEMORY_FALLBACK, None
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        if not self._in_memory_fallback and fallback_limits:
							 | 
						||
| 
								 | 
							
								            self._in_memory_fallback = [
							 | 
						||
| 
								 | 
							
								                LimitGroup(
							 | 
						||
| 
								 | 
							
								                    fallback_limits,
							 | 
						||
| 
								 | 
							
								                    self._key_func,
							 | 
						||
| 
								 | 
							
								                    None,
							 | 
						||
| 
								 | 
							
								                    False,
							 | 
						||
| 
								 | 
							
								                    None,
							 | 
						||
| 
								 | 
							
								                    None,
							 | 
						||
| 
								 | 
							
								                    None,
							 | 
						||
| 
								 | 
							
								                    1,
							 | 
						||
| 
								 | 
							
								                    False,
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            ]
							 | 
						||
| 
								 | 
							
								        if not self._in_memory_fallback_enabled:
							 | 
						||
| 
								 | 
							
								            self._in_memory_fallback_enabled = (
							 | 
						||
| 
								 | 
							
								                fallback_enabled or len(self._in_memory_fallback) > 0
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if self._in_memory_fallback_enabled:
							 | 
						||
| 
								 | 
							
								            self._fallback_storage = MemoryStorage()
							 | 
						||
| 
								 | 
							
								            self._fallback_limiter = STRATEGIES[strategy](self._fallback_storage)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def slowapi_startup(self) -> None:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Starlette startup event handler that links the app with the Limiter instance.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        app.state.limiter = self  # type: ignore
							 | 
						||
| 
								 | 
							
								        app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)  # type: ignore
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_app_config(self, key: str, default_value: T = None) -> T:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Place holder until we find a better way to load config from app
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return self.app_config(key, default=default_value, cast=type(default_value))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __should_check_backend(self) -> bool:
							 | 
						||
| 
								 | 
							
								        if self.__check_backend_count > MAX_BACKEND_CHECKS:
							 | 
						||
| 
								 | 
							
								            self.__check_backend_count = 0
							 | 
						||
| 
								 | 
							
								        if time.time() - self.__last_check_backend > pow(2, self.__check_backend_count):
							 | 
						||
| 
								 | 
							
								            self.__last_check_backend = time.time()
							 | 
						||
| 
								 | 
							
								            self.__check_backend_count += 1
							 | 
						||
| 
								 | 
							
								            return True
							 | 
						||
| 
								 | 
							
								        return False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def reset(self) -> None:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        resets the storage if it supports being reset
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            self._storage.reset()
							 | 
						||
| 
								 | 
							
								            self.logger.info("Storage has been reset and all limits cleared")
							 | 
						||
| 
								 | 
							
								        except NotImplementedError:
							 | 
						||
| 
								 | 
							
								            self.logger.warning("This storage type does not support being reset")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @property
							 | 
						||
| 
								 | 
							
								    def limiter(self) -> RateLimiter:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        The backend that keeps track of consumption of endpoints vs limits
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        if self._storage_dead and self._in_memory_fallback_enabled:
							 | 
						||
| 
								 | 
							
								            assert (
							 | 
						||
| 
								 | 
							
								                self._fallback_limiter
							 | 
						||
| 
								 | 
							
								            ), "Fallback limiter is needed when in memory fallback is enabled"
							 | 
						||
| 
								 | 
							
								            return self._fallback_limiter
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            return self._limiter
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _inject_headers(
							 | 
						||
| 
								 | 
							
								        self, response: Response, current_limit: Tuple[RateLimitItem, List[str]]
							 | 
						||
| 
								 | 
							
								    ) -> Response:
							 | 
						||
| 
								 | 
							
								        if self.enabled and self._headers_enabled and current_limit is not None:
							 | 
						||
| 
								 | 
							
								            if not isinstance(response, Response):
							 | 
						||
| 
								 | 
							
								                raise Exception(
							 | 
						||
| 
								 | 
							
								                    "parameter `response` must be an instance of starlette.responses.Response"
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            try:
							 | 
						||
| 
								 | 
							
								                window_stats: Tuple[int, int] = self.limiter.get_window_stats(
							 | 
						||
| 
								 | 
							
								                    current_limit[0], *current_limit[1]
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                reset_in = 1 + window_stats[0]
							 | 
						||
| 
								 | 
							
								                response.headers.append(
							 | 
						||
| 
								 | 
							
								                    self._header_mapping[HEADERS.LIMIT], str(current_limit[0].amount)
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                response.headers.append(
							 | 
						||
| 
								 | 
							
								                    self._header_mapping[HEADERS.REMAINING], str(window_stats[1])
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                response.headers.append(
							 | 
						||
| 
								 | 
							
								                    self._header_mapping[HEADERS.RESET], str(reset_in)
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                # response may have an existing retry after
							 | 
						||
| 
								 | 
							
								                existing_retry_after_header = response.headers.get("Retry-After")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                if existing_retry_after_header is not None:
							 | 
						||
| 
								 | 
							
								                    reset_in = max(
							 | 
						||
| 
								 | 
							
								                        self._determine_retry_time(existing_retry_after_header),
							 | 
						||
| 
								 | 
							
								                        reset_in,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                response.headers[self._header_mapping[HEADERS.RETRY_AFTER]] = (
							 | 
						||
| 
								 | 
							
								                    formatdate(reset_in)
							 | 
						||
| 
								 | 
							
								                    if self._retry_after == "http-date"
							 | 
						||
| 
								 | 
							
								                    else str(int(reset_in - time.time()))
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            except:
							 | 
						||
| 
								 | 
							
								                if self._in_memory_fallback and not self._storage_dead:
							 | 
						||
| 
								 | 
							
								                    self.logger.warning(
							 | 
						||
| 
								 | 
							
								                        "Rate limit storage unreachable - falling back to"
							 | 
						||
| 
								 | 
							
								                        " in-memory storage"
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    self._storage_dead = True
							 | 
						||
| 
								 | 
							
								                    response = self._inject_headers(response, current_limit)
							 | 
						||
| 
								 | 
							
								                if self._swallow_errors:
							 | 
						||
| 
								 | 
							
								                    self.logger.exception(
							 | 
						||
| 
								 | 
							
								                        "Failed to update rate limit headers. Swallowing error"
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                else:
							 | 
						||
| 
								 | 
							
								                    raise
							 | 
						||
| 
								 | 
							
								        return response
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _inject_asgi_headers(
							 | 
						||
| 
								 | 
							
								        self, headers: MutableHeaders, current_limit: Tuple[RateLimitItem, List[str]]
							 | 
						||
| 
								 | 
							
								    ) -> MutableHeaders:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Injects 'X-RateLimit-Reset', 'X-RateLimit-Remaining', 'X-RateLimit-Limit'
							 | 
						||
| 
								 | 
							
								        and 'Retry-After' headers into :headers parameter if needed.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Basically the same as _inject_headers, but without access to the Response object.
							 | 
						||
| 
								 | 
							
								        -> supports ASGI Middlewares.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        if self.enabled and self._headers_enabled and current_limit is not None:
							 | 
						||
| 
								 | 
							
								            try:
							 | 
						||
| 
								 | 
							
								                window_stats: Tuple[int, int] = self.limiter.get_window_stats(
							 | 
						||
| 
								 | 
							
								                    current_limit[0], *current_limit[1]
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                reset_in = 1 + window_stats[0]
							 | 
						||
| 
								 | 
							
								                headers[self._header_mapping[HEADERS.LIMIT]] = str(
							 | 
						||
| 
								 | 
							
								                    current_limit[0].amount
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                headers[self._header_mapping[HEADERS.REMAINING]] = str(window_stats[1])
							 | 
						||
| 
								 | 
							
								                headers[self._header_mapping[HEADERS.RESET]] = str(reset_in)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                # response may have an existing retry after
							 | 
						||
| 
								 | 
							
								                existing_retry_after_header = headers.get("Retry-After")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                if existing_retry_after_header is not None:
							 | 
						||
| 
								 | 
							
								                    reset_in = max(
							 | 
						||
| 
								 | 
							
								                        self._determine_retry_time(existing_retry_after_header),
							 | 
						||
| 
								 | 
							
								                        reset_in,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                headers[self._header_mapping[HEADERS.RETRY_AFTER]] = (
							 | 
						||
| 
								 | 
							
								                    formatdate(reset_in)
							 | 
						||
| 
								 | 
							
								                    if self._retry_after == "http-date"
							 | 
						||
| 
								 | 
							
								                    else str(int(reset_in - time.time()))
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            except Exception:
							 | 
						||
| 
								 | 
							
								                if self._in_memory_fallback and not self._storage_dead:
							 | 
						||
| 
								 | 
							
								                    self.logger.warning(
							 | 
						||
| 
								 | 
							
								                        "Rate limit storage unreachable - falling back to"
							 | 
						||
| 
								 | 
							
								                        " in-memory storage"
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    self._storage_dead = True
							 | 
						||
| 
								 | 
							
								                    headers = self._inject_asgi_headers(headers, current_limit)
							 | 
						||
| 
								 | 
							
								                if self._swallow_errors:
							 | 
						||
| 
								 | 
							
								                    self.logger.exception(
							 | 
						||
| 
								 | 
							
								                        "Failed to update rate limit headers. Swallowing error"
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                else:
							 | 
						||
| 
								 | 
							
								                    raise
							 | 
						||
| 
								 | 
							
								        return headers
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __evaluate_limits(
							 | 
						||
| 
								 | 
							
								        self, request: Request, endpoint: str, limits: List[Limit]
							 | 
						||
| 
								 | 
							
								    ) -> None:
							 | 
						||
| 
								 | 
							
								        failed_limit = None
							 | 
						||
| 
								 | 
							
								        limit_for_header = None
							 | 
						||
| 
								 | 
							
								        for lim in limits:
							 | 
						||
| 
								 | 
							
								            limit_scope = lim.scope or endpoint
							 | 
						||
| 
								 | 
							
								            if lim.is_exempt:
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								            if lim.methods is not None and request.method.lower() not in lim.methods:
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								            if lim.per_method:
							 | 
						||
| 
								 | 
							
								                limit_scope += ":%s" % request.method
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if "request" in inspect.signature(lim.key_func).parameters.keys():
							 | 
						||
| 
								 | 
							
								                limit_key = lim.key_func(request)
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                limit_key = lim.key_func()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            args = [limit_key, limit_scope]
							 | 
						||
| 
								 | 
							
								            if all(args):
							 | 
						||
| 
								 | 
							
								                if self._key_prefix:
							 | 
						||
| 
								 | 
							
								                    args = [self._key_prefix] + args
							 | 
						||
| 
								 | 
							
								                if not limit_for_header or lim.limit < limit_for_header[0]:
							 | 
						||
| 
								 | 
							
								                    limit_for_header = (lim.limit, args)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                cost = lim.cost(request) if callable(lim.cost) else lim.cost
							 | 
						||
| 
								 | 
							
								                if not self.limiter.hit(lim.limit, *args, cost=cost):
							 | 
						||
| 
								 | 
							
								                    self.logger.warning(
							 | 
						||
| 
								 | 
							
								                        "ratelimit %s (%s) exceeded at endpoint: %s",
							 | 
						||
| 
								 | 
							
								                        lim.limit,
							 | 
						||
| 
								 | 
							
								                        limit_key,
							 | 
						||
| 
								 | 
							
								                        limit_scope,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    failed_limit = lim
							 | 
						||
| 
								 | 
							
								                    limit_for_header = (lim.limit, args)
							 | 
						||
| 
								 | 
							
								                    break
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                self.logger.error(
							 | 
						||
| 
								 | 
							
								                    "Skipping limit: %s. Empty value found in parameters.", lim.limit
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								        # keep track of which limit was hit, to be picked up for the response header
							 | 
						||
| 
								 | 
							
								        request.state.view_rate_limit = limit_for_header
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if failed_limit:
							 | 
						||
| 
								 | 
							
								            raise RateLimitExceeded(failed_limit)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _determine_retry_time(self, retry_header_value) -> int:
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            retry_after_date: Optional[datetime] = parsedate_to_datetime(
							 | 
						||
| 
								 | 
							
								                retry_header_value
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        except (TypeError, ValueError):
							 | 
						||
| 
								 | 
							
								            retry_after_date = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if retry_after_date is not None:
							 | 
						||
| 
								 | 
							
								            return int(time.mktime(retry_after_date.timetuple()))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            retry_after_int: int = int(retry_header_value)
							 | 
						||
| 
								 | 
							
								        except TypeError:
							 | 
						||
| 
								 | 
							
								            raise ValueError(
							 | 
						||
| 
								 | 
							
								                "Retry-After Header does not meet RFC2616 - value is not of http-date or int type."
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return int(time.time() + retry_after_int)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _check_request_limit(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        request: Request,
							 | 
						||
| 
								 | 
							
								        endpoint_func: Optional[Callable[..., Any]],
							 | 
						||
| 
								 | 
							
								        in_middleware: bool = True,
							 | 
						||
| 
								 | 
							
								    ) -> None:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Determine if the request is within limits
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        endpoint_url = request["path"] or ""
							 | 
						||
| 
								 | 
							
								        view_func = endpoint_func
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        endpoint_func_name = (
							 | 
						||
| 
								 | 
							
								            f"{view_func.__module__}.{view_func.__name__}" if view_func else ""
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        _endpoint_key = endpoint_url if self._key_style == "url" else endpoint_func_name
							 | 
						||
| 
								 | 
							
								        # cases where we don't need to check the limits
							 | 
						||
| 
								 | 
							
								        if (
							 | 
						||
| 
								 | 
							
								            not _endpoint_key
							 | 
						||
| 
								 | 
							
								            or not self.enabled
							 | 
						||
| 
								 | 
							
								            # or we are sending a static file
							 | 
						||
| 
								 | 
							
								            # or view_func == current_app.send_static_file
							 | 
						||
| 
								 | 
							
								            or endpoint_func_name in self._exempt_routes
							 | 
						||
| 
								 | 
							
								            or any(fn() for fn in self._request_filters)
							 | 
						||
| 
								 | 
							
								        ):
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								        limits: List[Limit] = []
							 | 
						||
| 
								 | 
							
								        dynamic_limits: List[Limit] = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not in_middleware:
							 | 
						||
| 
								 | 
							
								            limits = (
							 | 
						||
| 
								 | 
							
								                self._route_limits[endpoint_func_name]
							 | 
						||
| 
								 | 
							
								                if endpoint_func_name in self._route_limits
							 | 
						||
| 
								 | 
							
								                else []
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								            dynamic_limits = []
							 | 
						||
| 
								 | 
							
								            if endpoint_func_name in self._dynamic_route_limits:
							 | 
						||
| 
								 | 
							
								                for lim in self._dynamic_route_limits[endpoint_func_name]:
							 | 
						||
| 
								 | 
							
								                    try:
							 | 
						||
| 
								 | 
							
								                        dynamic_limits.extend(list(lim.with_request(request)))
							 | 
						||
| 
								 | 
							
								                    except ValueError as e:
							 | 
						||
| 
								 | 
							
								                        self.logger.error(
							 | 
						||
| 
								 | 
							
								                            "failed to load ratelimit for view function %s (%s)",
							 | 
						||
| 
								 | 
							
								                            endpoint_func_name,
							 | 
						||
| 
								 | 
							
								                            e,
							 | 
						||
| 
								 | 
							
								                        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            all_limits: List[Limit] = []
							 | 
						||
| 
								 | 
							
								            if self._storage_dead and self._fallback_limiter:
							 | 
						||
| 
								 | 
							
								                if in_middleware and endpoint_func_name in self.__marked_for_limiting:
							 | 
						||
| 
								 | 
							
								                    pass
							 | 
						||
| 
								 | 
							
								                else:
							 | 
						||
| 
								 | 
							
								                    if self.__should_check_backend() and self._storage.check():
							 | 
						||
| 
								 | 
							
								                        self.logger.info("Rate limit storage recovered")
							 | 
						||
| 
								 | 
							
								                        self._storage_dead = False
							 | 
						||
| 
								 | 
							
								                        self.__check_backend_count = 0
							 | 
						||
| 
								 | 
							
								                    else:
							 | 
						||
| 
								 | 
							
								                        all_limits = list(itertools.chain(*self._in_memory_fallback))
							 | 
						||
| 
								 | 
							
								            if not all_limits:
							 | 
						||
| 
								 | 
							
								                route_limits: List[Limit] = limits + dynamic_limits
							 | 
						||
| 
								 | 
							
								                all_limits = (
							 | 
						||
| 
								 | 
							
								                    list(itertools.chain(*self._application_limits))
							 | 
						||
| 
								 | 
							
								                    if in_middleware
							 | 
						||
| 
								 | 
							
								                    else []
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                all_limits += route_limits
							 | 
						||
| 
								 | 
							
								                combined_defaults = all(
							 | 
						||
| 
								 | 
							
								                    not limit.override_defaults for limit in route_limits
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                if (
							 | 
						||
| 
								 | 
							
								                    not route_limits
							 | 
						||
| 
								 | 
							
								                    and not (
							 | 
						||
| 
								 | 
							
								                        in_middleware
							 | 
						||
| 
								 | 
							
								                        and endpoint_func_name in self.__marked_for_limiting
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    or combined_defaults
							 | 
						||
| 
								 | 
							
								                ):
							 | 
						||
| 
								 | 
							
								                    all_limits += list(itertools.chain(*self._default_limits))
							 | 
						||
| 
								 | 
							
								            # actually check the limits, so far we've only computed the list of limits to check
							 | 
						||
| 
								 | 
							
								            self.__evaluate_limits(request, _endpoint_key, all_limits)
							 | 
						||
| 
								 | 
							
								        except Exception as e:  # no qa
							 | 
						||
| 
								 | 
							
								            if isinstance(e, RateLimitExceeded):
							 | 
						||
| 
								 | 
							
								                raise
							 | 
						||
| 
								 | 
							
								            if self._in_memory_fallback_enabled and not self._storage_dead:
							 | 
						||
| 
								 | 
							
								                self.logger.warn(
							 | 
						||
| 
								 | 
							
								                    "Rate limit storage unreachable - falling back to"
							 | 
						||
| 
								 | 
							
								                    " in-memory storage"
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                self._storage_dead = True
							 | 
						||
| 
								 | 
							
								                self._check_request_limit(request, endpoint_func, in_middleware)
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                if self._swallow_errors:
							 | 
						||
| 
								 | 
							
								                    self.logger.exception("Failed to rate limit. Swallowing error")
							 | 
						||
| 
								 | 
							
								                else:
							 | 
						||
| 
								 | 
							
								                    raise
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __limit_decorator(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        limit_value: StrOrCallableStr,
							 | 
						||
| 
								 | 
							
								        key_func: Optional[Callable[..., str]] = None,
							 | 
						||
| 
								 | 
							
								        shared: bool = False,
							 | 
						||
| 
								 | 
							
								        scope: Optional[StrOrCallableStr] = None,
							 | 
						||
| 
								 | 
							
								        per_method: bool = False,
							 | 
						||
| 
								 | 
							
								        methods: Optional[List[str]] = None,
							 | 
						||
| 
								 | 
							
								        error_message: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        exempt_when: Optional[Callable[..., bool]] = None,
							 | 
						||
| 
								 | 
							
								        cost: Union[int, Callable[..., int]] = 1,
							 | 
						||
| 
								 | 
							
								        override_defaults: bool = True,
							 | 
						||
| 
								 | 
							
								    ) -> Callable[..., Any]:
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        _scope = scope if shared else None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        def decorator(func: Callable[..., Response]):
							 | 
						||
| 
								 | 
							
								            keyfunc = key_func or self._key_func
							 | 
						||
| 
								 | 
							
								            name = f"{func.__module__}.{func.__name__}"
							 | 
						||
| 
								 | 
							
								            dynamic_limit = None
							 | 
						||
| 
								 | 
							
								            static_limits: List[Limit] = []
							 | 
						||
| 
								 | 
							
								            if callable(limit_value):
							 | 
						||
| 
								 | 
							
								                dynamic_limit = LimitGroup(
							 | 
						||
| 
								 | 
							
								                    limit_value,
							 | 
						||
| 
								 | 
							
								                    keyfunc,
							 | 
						||
| 
								 | 
							
								                    _scope,
							 | 
						||
| 
								 | 
							
								                    per_method,
							 | 
						||
| 
								 | 
							
								                    methods,
							 | 
						||
| 
								 | 
							
								                    error_message,
							 | 
						||
| 
								 | 
							
								                    exempt_when,
							 | 
						||
| 
								 | 
							
								                    cost,
							 | 
						||
| 
								 | 
							
								                    override_defaults,
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                try:
							 | 
						||
| 
								 | 
							
								                    static_limits = list(
							 | 
						||
| 
								 | 
							
								                        LimitGroup(
							 | 
						||
| 
								 | 
							
								                            limit_value,
							 | 
						||
| 
								 | 
							
								                            keyfunc,
							 | 
						||
| 
								 | 
							
								                            _scope,
							 | 
						||
| 
								 | 
							
								                            per_method,
							 | 
						||
| 
								 | 
							
								                            methods,
							 | 
						||
| 
								 | 
							
								                            error_message,
							 | 
						||
| 
								 | 
							
								                            exempt_when,
							 | 
						||
| 
								 | 
							
								                            cost,
							 | 
						||
| 
								 | 
							
								                            override_defaults,
							 | 
						||
| 
								 | 
							
								                        )
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                except ValueError as e:
							 | 
						||
| 
								 | 
							
								                    self.logger.error(
							 | 
						||
| 
								 | 
							
								                        "Failed to configure throttling for %s (%s)",
							 | 
						||
| 
								 | 
							
								                        name,
							 | 
						||
| 
								 | 
							
								                        e,
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								            self.__marked_for_limiting.setdefault(name, []).append(func)
							 | 
						||
| 
								 | 
							
								            if dynamic_limit:
							 | 
						||
| 
								 | 
							
								                self._dynamic_route_limits.setdefault(name, []).append(dynamic_limit)
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                self._route_limits.setdefault(name, []).extend(static_limits)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            connection_type: Optional[str] = None
							 | 
						||
| 
								 | 
							
								            sig = inspect.signature(func)
							 | 
						||
| 
								 | 
							
								            for idx, parameter in enumerate(sig.parameters.values()):
							 | 
						||
| 
								 | 
							
								                if parameter.name == "request" or parameter.name == "websocket":
							 | 
						||
| 
								 | 
							
								                    connection_type = parameter.name
							 | 
						||
| 
								 | 
							
								                    break
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                raise Exception(
							 | 
						||
| 
								 | 
							
								                    f'No "request" or "websocket" argument on function "{func}"'
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if asyncio.iscoroutinefunction(func):
							 | 
						||
| 
								 | 
							
								                # Handle async request/response functions.
							 | 
						||
| 
								 | 
							
								                @functools.wraps(func)
							 | 
						||
| 
								 | 
							
								                async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
							 | 
						||
| 
								 | 
							
								                    # get the request object from the decorated endpoint function
							 | 
						||
| 
								 | 
							
								                    if self.enabled:
							 | 
						||
| 
								 | 
							
								                        request = kwargs.get("request", args[idx] if args else None)
							 | 
						||
| 
								 | 
							
								                        if not isinstance(request, Request):
							 | 
						||
| 
								 | 
							
								                            raise Exception(
							 | 
						||
| 
								 | 
							
								                                "parameter `request` must be an instance of starlette.requests.Request"
							 | 
						||
| 
								 | 
							
								                            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                        if self._auto_check and not getattr(
							 | 
						||
| 
								 | 
							
								                            request.state, "_rate_limiting_complete", False
							 | 
						||
| 
								 | 
							
								                        ):
							 | 
						||
| 
								 | 
							
								                            self._check_request_limit(request, func, False)
							 | 
						||
| 
								 | 
							
								                            request.state._rate_limiting_complete = True
							 | 
						||
| 
								 | 
							
								                    response = await func(*args, **kwargs)  # type: ignore
							 | 
						||
| 
								 | 
							
								                    if self.enabled:
							 | 
						||
| 
								 | 
							
								                        if not isinstance(response, Response):
							 | 
						||
| 
								 | 
							
								                            # get the response object from the decorated endpoint function
							 | 
						||
| 
								 | 
							
								                            self._inject_headers(
							 | 
						||
| 
								 | 
							
								                                kwargs.get("response"), request.state.view_rate_limit  # type: ignore
							 | 
						||
| 
								 | 
							
								                            )
							 | 
						||
| 
								 | 
							
								                        else:
							 | 
						||
| 
								 | 
							
								                            self._inject_headers(
							 | 
						||
| 
								 | 
							
								                                response, request.state.view_rate_limit
							 | 
						||
| 
								 | 
							
								                            )
							 | 
						||
| 
								 | 
							
								                    return response
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                return async_wrapper
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                # Handle sync request/response functions.
							 | 
						||
| 
								 | 
							
								                @functools.wraps(func)
							 | 
						||
| 
								 | 
							
								                def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
							 | 
						||
| 
								 | 
							
								                    # get the request object from the decorated endpoint function
							 | 
						||
| 
								 | 
							
								                    if self.enabled:
							 | 
						||
| 
								 | 
							
								                        request = kwargs.get("request", args[idx] if args else None)
							 | 
						||
| 
								 | 
							
								                        if not isinstance(request, Request):
							 | 
						||
| 
								 | 
							
								                            raise Exception(
							 | 
						||
| 
								 | 
							
								                                "parameter `request` must be an instance of starlette.requests.Request"
							 | 
						||
| 
								 | 
							
								                            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                        if self._auto_check and not getattr(
							 | 
						||
| 
								 | 
							
								                            request.state, "_rate_limiting_complete", False
							 | 
						||
| 
								 | 
							
								                        ):
							 | 
						||
| 
								 | 
							
								                            self._check_request_limit(request, func, False)
							 | 
						||
| 
								 | 
							
								                            request.state._rate_limiting_complete = True
							 | 
						||
| 
								 | 
							
								                    response = func(*args, **kwargs)
							 | 
						||
| 
								 | 
							
								                    if self.enabled:
							 | 
						||
| 
								 | 
							
								                        if not isinstance(response, Response):
							 | 
						||
| 
								 | 
							
								                            # get the response object from the decorated endpoint function
							 | 
						||
| 
								 | 
							
								                            self._inject_headers(
							 | 
						||
| 
								 | 
							
								                                kwargs.get("response"), request.state.view_rate_limit  # type: ignore
							 | 
						||
| 
								 | 
							
								                            )
							 | 
						||
| 
								 | 
							
								                        else:
							 | 
						||
| 
								 | 
							
								                            self._inject_headers(
							 | 
						||
| 
								 | 
							
								                                response, request.state.view_rate_limit
							 | 
						||
| 
								 | 
							
								                            )
							 | 
						||
| 
								 | 
							
								                    return response
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                return sync_wrapper
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return decorator
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def limit(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        limit_value: Union[str, Callable[[str], str]],
							 | 
						||
| 
								 | 
							
								        key_func: Optional[Callable[..., str]] = None,
							 | 
						||
| 
								 | 
							
								        per_method: bool = False,
							 | 
						||
| 
								 | 
							
								        methods: Optional[List[str]] = None,
							 | 
						||
| 
								 | 
							
								        error_message: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        exempt_when: Optional[Callable[..., bool]] = None,
							 | 
						||
| 
								 | 
							
								        cost: Union[int, Callable[..., int]] = 1,
							 | 
						||
| 
								 | 
							
								        override_defaults: bool = True,
							 | 
						||
| 
								 | 
							
								    ) -> Callable:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Decorator to be used for rate limiting individual routes.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        * **limit_value**: rate limit string or a callable that returns a string.
							 | 
						||
| 
								 | 
							
								         :ref:`ratelimit-string` for more details.
							 | 
						||
| 
								 | 
							
								        * **key_func**: function/lambda to extract the unique identifier for
							 | 
						||
| 
								 | 
							
								         the rate limit. defaults to remote address of the request.
							 | 
						||
| 
								 | 
							
								        * **per_method**: whether the limit is sub categorized into the http
							 | 
						||
| 
								 | 
							
								         method of the request.
							 | 
						||
| 
								 | 
							
								        * **methods**: if specified, only the methods in this list will be rate
							 | 
						||
| 
								 | 
							
								         limited (default: None).
							 | 
						||
| 
								 | 
							
								        * **error_message**: string (or callable that returns one) to override the
							 | 
						||
| 
								 | 
							
								         error message used in the response.
							 | 
						||
| 
								 | 
							
								        * **exempt_when**: function returning a boolean indicating whether to exempt
							 | 
						||
| 
								 | 
							
								        the route from the limit
							 | 
						||
| 
								 | 
							
								        * **cost**: integer (or callable that returns one) which is the cost of a hit
							 | 
						||
| 
								 | 
							
								        * **override_defaults**: whether to override the default limits (default: True)
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return self.__limit_decorator(
							 | 
						||
| 
								 | 
							
								            limit_value,
							 | 
						||
| 
								 | 
							
								            key_func,
							 | 
						||
| 
								 | 
							
								            per_method=per_method,
							 | 
						||
| 
								 | 
							
								            methods=methods,
							 | 
						||
| 
								 | 
							
								            error_message=error_message,
							 | 
						||
| 
								 | 
							
								            exempt_when=exempt_when,
							 | 
						||
| 
								 | 
							
								            cost=cost,
							 | 
						||
| 
								 | 
							
								            override_defaults=override_defaults,
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def shared_limit(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        limit_value: Union[str, Callable[[str], str]],
							 | 
						||
| 
								 | 
							
								        scope: StrOrCallableStr,
							 | 
						||
| 
								 | 
							
								        key_func: Optional[Callable[..., str]] = None,
							 | 
						||
| 
								 | 
							
								        error_message: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        exempt_when: Optional[Callable[..., bool]] = None,
							 | 
						||
| 
								 | 
							
								        cost: Union[int, Callable[..., int]] = 1,
							 | 
						||
| 
								 | 
							
								        override_defaults: bool = True,
							 | 
						||
| 
								 | 
							
								    ) -> Callable:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Decorator to be applied to multiple routes sharing the same rate limit.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        * **limit_value**: rate limit string or a callable that returns a string.
							 | 
						||
| 
								 | 
							
								         :ref:`ratelimit-string` for more details.
							 | 
						||
| 
								 | 
							
								        * **scope**: a string or callable that returns a string
							 | 
						||
| 
								 | 
							
								         for defining the rate limiting scope.
							 | 
						||
| 
								 | 
							
								        * **key_func**: function/lambda to extract the unique identifier for
							 | 
						||
| 
								 | 
							
								         the rate limit. defaults to remote address of the request.
							 | 
						||
| 
								 | 
							
								        * **per_method**: whether the limit is sub categorized into the http
							 | 
						||
| 
								 | 
							
								         method of the request.
							 | 
						||
| 
								 | 
							
								        * **methods**: if specified, only the methods in this list will be rate
							 | 
						||
| 
								 | 
							
								         limited (default: None).
							 | 
						||
| 
								 | 
							
								        * **error_message**: string (or callable that returns one) to override the
							 | 
						||
| 
								 | 
							
								         error message used in the response.
							 | 
						||
| 
								 | 
							
								        * **exempt_when**: function returning a boolean indicating whether to exempt
							 | 
						||
| 
								 | 
							
								        the route from the limit
							 | 
						||
| 
								 | 
							
								        * **cost**: integer (or callable that returns one) which is the cost of a hit
							 | 
						||
| 
								 | 
							
								        * **override_defaults**: whether to override the default limits (default: True)
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return self.__limit_decorator(
							 | 
						||
| 
								 | 
							
								            limit_value,
							 | 
						||
| 
								 | 
							
								            key_func,
							 | 
						||
| 
								 | 
							
								            True,
							 | 
						||
| 
								 | 
							
								            scope,
							 | 
						||
| 
								 | 
							
								            error_message=error_message,
							 | 
						||
| 
								 | 
							
								            exempt_when=exempt_when,
							 | 
						||
| 
								 | 
							
								            cost=cost,
							 | 
						||
| 
								 | 
							
								            override_defaults=override_defaults,
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def exempt(self, obj):
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Decorator to mark a view as exempt from rate limits.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        name = "%s.%s" % (obj.__module__, obj.__name__)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self._exempt_routes.add(name)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if asyncio.iscoroutinefunction(obj):
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            @wraps(obj)
							 | 
						||
| 
								 | 
							
								            async def __async_inner(*a, **k):
							 | 
						||
| 
								 | 
							
								                return await obj(*a, **k)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            return __async_inner
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            @wraps(obj)
							 | 
						||
| 
								 | 
							
								            def __inner(*a, **k):
							 | 
						||
| 
								 | 
							
								                return obj(*a, **k)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            return __inner
							 |