You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

216 lines
6.6 KiB

"""
Rate limiting strategies
"""
import weakref
from abc import ABCMeta, abstractmethod
from typing import Dict, Tuple, Type, Union, cast
from .limits import RateLimitItem
from .storage import MovingWindowSupport, Storage, StorageTypes
class RateLimiter(metaclass=ABCMeta):
def __init__(self, storage: StorageTypes):
assert isinstance(storage, Storage)
self.storage: Storage = weakref.proxy(storage)
@abstractmethod
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
raise NotImplementedError
@abstractmethod
def test(self, item: RateLimitItem, *identifiers: str) -> bool:
"""
Check the rate limit without consuming from it.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
"""
raise NotImplementedError
@abstractmethod
def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> Tuple[int, int]:
"""
Query the reset time and remaining amount for the limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
raise NotImplementedError
def clear(self, item: RateLimitItem, *identifiers: str) -> None:
return self.storage.clear(item.key_for(*identifiers))
class MovingWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:moving window`
"""
def __init__(self, storage: StorageTypes):
if not (
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
):
raise NotImplementedError(
"MovingWindowRateLimiting is not implemented for storage "
"of type %s" % storage.__class__
)
super().__init__(storage)
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
:return: (reset time, remaining)
"""
return cast(MovingWindowSupport, self.storage).acquire_entry(
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
)
def test(self, item: RateLimitItem, *identifiers: str) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
"""
return (
cast(MovingWindowSupport, self.storage).get_moving_window(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
)[1]
< item.amount
)
def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> Tuple[int, int]:
"""
returns the number of requests remaining within this limit.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: tuple (reset time, remaining)
"""
window_start, window_items = cast(
MovingWindowSupport, self.storage
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
reset = window_start + item.get_expiry()
return (reset, item.amount - window_items)
class FixedWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:fixed window`
"""
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
return (
self.storage.incr(
item.key_for(*identifiers),
item.get_expiry(),
elastic_expiry=False,
amount=cost,
)
<= item.amount
)
def test(self, item: RateLimitItem, *identifiers: str) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
"""
return self.storage.get(item.key_for(*identifiers)) < item.amount
def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> Tuple[int, int]:
"""
Query the reset time and remaining amount for the limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
remaining = max(0, item.amount - self.storage.get(item.key_for(*identifiers)))
reset = self.storage.get_expiry(item.key_for(*identifiers))
return (reset, remaining)
class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
"""
Reference: :ref:`strategies:fixed window with elastic expiry`
"""
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
return (
self.storage.incr(
item.key_for(*identifiers),
item.get_expiry(),
elastic_expiry=True,
amount=cost,
)
<= item.amount
)
KnownStrategy = Union[
Type[FixedWindowRateLimiter],
Type[FixedWindowElasticExpiryRateLimiter],
Type[MovingWindowRateLimiter],
]
STRATEGIES: Dict[str, KnownStrategy] = {
"fixed-window": FixedWindowRateLimiter,
"fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
"moving-window": MovingWindowRateLimiter,
}