You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

210 lines
6.3 KiB

"""
Asynchronous rate limiting strategies
"""
import weakref
from abc import ABC, abstractmethod
from typing import Tuple, cast
from ..limits import RateLimitItem
from ..storage import StorageTypes
from .storage import MovingWindowSupport, Storage
class RateLimiter(ABC):
    """
    Abstract base class for the asynchronous rate limiting strategies.

    Keeps a weak proxy to the storage backend and declares the operations
    that every concrete strategy has to implement.
    """

    def __init__(self, storage: StorageTypes):
        """
        :param storage: the storage backend, expected to be an instance of
         :class:`Storage`
        """
        assert isinstance(storage, Storage)
        # Hold the storage through a weak proxy rather than a strong reference.
        self.storage: Storage = weakref.proxy(storage)

    @abstractmethod
    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: the cost of this hit, default 1
        """
        raise NotImplementedError

    @abstractmethod
    async def test(self, item: RateLimitItem, *identifiers: str) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        """
        raise NotImplementedError

    @abstractmethod
    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> Tuple[int, int]:
        """
        Query the reset time and remaining amount for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining)
        """
        raise NotImplementedError

    async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
        """
        Clear the storage entry kept for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        """
        key = item.key_for(*identifiers)
        return await self.storage.clear(key)
class MovingWindowRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:moving window`
    """

    def __init__(self, storage: StorageTypes) -> None:
        """
        :param storage: the storage backend; it must provide both
         ``acquire_entry`` and ``get_moving_window``
        :raises NotImplementedError: if the storage lacks either method
        """
        # Both methods are required: ``hit`` calls ``acquire_entry`` while
        # ``test`` and ``get_window_stats`` call ``get_moving_window``. A
        # storage offering only one of them must be rejected here instead of
        # failing later with an AttributeError.
        if not (
            hasattr(storage, "acquire_entry")
            and hasattr(storage, "get_moving_window")
        ):
            raise NotImplementedError(
                "MovingWindowRateLimiting is not implemented for storage "
                "of type %s" % storage.__class__
            )
        super().__init__(storage)

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        :return: the result of the storage's ``acquire_entry`` call
        """
        return await cast(MovingWindowSupport, self.storage).acquire_entry(
            item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
        )

    async def test(self, item: RateLimitItem, *identifiers: str) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: ``True`` when the number of entries reported for the window
         is below ``item.amount``
        """
        res = await cast(MovingWindowSupport, self.storage).get_moving_window(
            item.key_for(*identifiers),
            item.amount,
            item.get_expiry(),
        )
        # The second element of the storage's answer is the entry count.
        amount = res[1]
        return amount < item.amount

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> Tuple[int, int]:
        """
        Query the reset time and remaining amount for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining)
        """
        window_start, window_items = await cast(
            MovingWindowSupport, self.storage
        ).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
        # The reset time is the window start shifted by the item's expiry.
        reset = window_start + item.get_expiry()
        return reset, item.amount - window_items
class FixedWindowRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:fixed window`
    """

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: the cost of this hit, default 1
        :return: ``True`` when the counter returned by the storage after the
         increment does not exceed ``item.amount``
        """
        counter = await self.storage.incr(
            item.key_for(*identifiers),
            item.get_expiry(),
            elastic_expiry=False,
            amount=cost,
        )
        return counter <= item.amount

    async def test(self, item: RateLimitItem, *identifiers: str) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: ``True`` when the stored counter is below ``item.amount``
        """
        used = await self.storage.get(item.key_for(*identifiers))
        return used < item.amount

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> Tuple[int, int]:
        """
        Query the reset time and remaining amount for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: reset time, remaining
        """
        key = item.key_for(*identifiers)
        used = await self.storage.get(key)
        # Never report a negative remaining amount.
        remaining = max(0, item.amount - used)
        reset = await self.storage.get_expiry(key)
        return reset, remaining
class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
    """
    Reference: :ref:`strategies:fixed window with elastic expiry`
    """

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        Same as the parent's ``hit`` except that the storage counter is
        incremented with ``elastic_expiry=True``.

        :param item: a :class:`limits.limits.RateLimitItem` instance
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: the cost of this hit, default 1
        :return: ``True`` when the counter returned by the storage after the
         increment does not exceed ``item.amount``
        """
        return (
            await self.storage.incr(
                item.key_for(*identifiers),
                item.get_expiry(),
                elastic_expiry=True,
                amount=cost,
            )
            <= item.amount
        )
# Registry mapping each strategy name to the rate limiter class in this
# module that implements it.
STRATEGIES = {
"fixed-window": FixedWindowRateLimiter,
"fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
"moving-window": MovingWindowRateLimiter,
}