You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
268 lines
8.8 KiB
268 lines
8.8 KiB
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import calendar
|
|
import datetime
|
|
import time
|
|
from typing import Any
|
|
|
|
from deprecated.sphinx import versionadded
|
|
|
|
from limits.aio.storage.base import MovingWindowSupport, Storage
|
|
from limits.typing import Dict, Optional, ParamSpec, Tuple, TypeVar, Union
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R")
|
|
|
|
|
|
@versionadded(version="2.1")
class MongoDBStorage(Storage, MovingWindowSupport):
    """
    Rate limit storage with MongoDB as backend.

    Fixed window counters are kept in the ``counters`` collection and moving
    window entries in the ``windows`` collection of the configured database.

    Depends on :pypi:`motor`
    """

    STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
    """
    The storage scheme for MongoDB for use in an async context
    """

    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
        "serverSelectionTimeoutMS": 1000,
        "socketTimeoutMS": 1000,
        "connectTimeoutMS": 1000,
    }
    "Default options passed to :class:`~motor.motor_asyncio.AsyncIOMotorClient`"

    DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]

    def __init__(
        self,
        uri: str,
        database_name: str = "limits",
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
         This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :param database_name: The database to use for storing the rate limit
         collections.
        :param options: all remaining keyword arguments are merged with
         :data:`DEFAULT_OPTIONS` and passed to the constructor of
         :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
         not available
        """
        # Caller supplied options take precedence over the defaults.
        mongo_opts = {**self.DEFAULT_OPTIONS, **options}
        # Only the first occurrence is rewritten so that the client receives a
        # plain ``mongodb://`` / ``mongodb+srv://`` uri.
        uri = uri.replace("async+mongodb", "mongodb", 1)

        super().__init__(uri, **options)

        self.dependency = self.dependencies["motor.motor_asyncio"]
        self.proxy_dependency = self.dependencies["pymongo"]

        self.storage = self.dependency.module.AsyncIOMotorClient(uri, **mongo_opts)
        # TODO: Fix this hack. It was noticed when running a benchmark
        # with FastAPI - however - doesn't appear in unit tests or in an isolated
        # use. Reference: https://jira.mongodb.org/browse/MOTOR-822
        self.storage.get_io_loop = asyncio.get_running_loop

        self.__database_name = database_name
        self.__indices_created = False

    @staticmethod
    def _utcnow() -> datetime.datetime:
        """
        Current time as a timezone aware UTC datetime (replacement for the
        deprecated :meth:`datetime.datetime.utcnow`).
        """
        return datetime.datetime.now(datetime.timezone.utc)

    @property
    def database(self):  # type: ignore
        """
        The database (named by ``database_name``) holding the rate limit
        collections.
        """
        return self.storage.get_database(self.__database_name)

    async def create_indices(self) -> None:
        """
        Create the ``expireAt`` indices (``expireAfterSeconds=0``) on the
        ``counters`` and ``windows`` collections. The work is only done
        once per storage instance.
        """
        if not self.__indices_created:
            await asyncio.gather(
                self.database.counters.create_index("expireAt", expireAfterSeconds=0),
                self.database.windows.create_index("expireAt", expireAfterSeconds=0),
            )
        # Only reached when index creation did not raise.
        self.__indices_created = True

    async def reset(self) -> Optional[int]:
        """
        Delete all rate limit keys in the rate limit collections (counters, windows)

        :return: the number of documents that were counted in both
         collections before they were dropped
        """
        num_keys = sum(
            await asyncio.gather(
                self.database.counters.count_documents({}),
                self.database.windows.count_documents({}),
            )
        )
        await asyncio.gather(
            self.database.counters.drop(), self.database.windows.drop()
        )

        return num_keys

    async def clear(self, key: str) -> None:
        """
        Remove the entries for ``key`` from both the ``counters`` and the
        ``windows`` collection.

        :param key: the key to clear rate limits for
        """
        await asyncio.gather(
            self.database.counters.find_one_and_delete({"_id": key}),
            self.database.windows.find_one_and_delete({"_id": key}),
        )

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        :return: the ``expireAt`` value of the counter as a unix timestamp, or
         the current time when no counter exists for ``key``
        """
        counter = await self.database.counters.find_one({"_id": key})
        expiry = counter["expireAt"] if counter else self._utcnow()

        # ``timegm`` interprets the time tuple as UTC.
        return calendar.timegm(expiry.timetuple())

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the current count, or ``0`` when there is no unexpired
         counter for ``key``
        """
        # Counters whose ``expireAt`` is already in the past are ignored.
        counter = await self.database.counters.find_one(
            {"_id": key, "expireAt": {"$gte": self._utcnow()}},
            projection=["count"],
        )

        return counter["count"] if counter else 0

    async def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: whether to keep extending the rate limit
         window every hit.
        :param amount: the number to increment by
        :return: the value of the counter after the increment
        """
        await self.create_indices()

        expiration = self._utcnow() + datetime.timedelta(seconds=expiry)

        # Pipeline style update: when the stored ``expireAt`` is older than
        # ``$$NOW`` the counter starts over at ``amount`` with a fresh
        # expiration; otherwise ``amount`` is added and the expiration is
        # only replaced when ``elastic_expiry`` is set.
        response = await self.database.counters.find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "count": {
                            "$cond": {
                                "if": {"$lt": ["$expireAt", "$$NOW"]},
                                "then": amount,
                                "else": {"$add": ["$count", amount]},
                            }
                        },
                        "expireAt": {
                            "$cond": {
                                "if": {"$lt": ["$expireAt", "$$NOW"]},
                                "then": expiration,
                                "else": (expiration if elastic_expiry else "$expireAt"),
                            }
                        },
                    }
                },
            ],
            upsert=True,
            projection=["count"],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
        )

        return int(response["count"])

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling
        :meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`
        """
        try:
            await self.storage.server_info()
        except Exception:
            # Deliberately not a bare ``except``: task cancellation
            # (``asyncio.CancelledError``) and interpreter exit signals must
            # propagate instead of being reported as an unhealthy storage.
            return False

        return True

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param str key: rate limit key
        :param int limit: amount of entries allowed (not used by this
         implementation)
        :param int expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        timestamp = time.time()
        # Keep only the entries newer than ``timestamp - expiry`` and reduce
        # them to their largest value and their count.
        result = await self.database.windows.aggregate(
            [
                {"$match": {"_id": key}},
                {
                    "$project": {
                        "entries": {
                            "$filter": {
                                "input": "$entries",
                                "as": "entry",
                                "cond": {"$gte": ["$$entry", timestamp - expiry]},
                            }
                        }
                    }
                },
                {"$unwind": "$entries"},
                {
                    "$group": {
                        "_id": "$_id",
                        "max": {"$max": "$entries"},
                        "count": {"$sum": 1},
                    }
                },
            ]
        ).to_list(length=1)

        if result:
            return (int(result[0]["max"]), result[0]["count"])

        # No live entries for this key.
        return (int(timestamp), 0)

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: ``True`` when the entries were recorded, ``False`` when the
         request could not be satisfied
        """
        # More entries than the window can ever hold can never be acquired;
        # without this guard ``limit - amount`` below would become a negative
        # array index in the query.
        if amount > limit:
            return False

        await self.create_indices()

        timestamp = time.time()
        # New entries are pushed to the front and the array is capped at
        # ``limit`` elements.
        updates: Dict[str, Any] = {
            "$push": {
                "entries": {
                    "$each": [timestamp] * amount,
                    "$position": 0,
                    "$slice": limit,
                }
            },
            "$set": {
                "expireAt": self._utcnow() + datetime.timedelta(seconds=expiry)
            },
        }

        try:
            # The filter only matches when the entry at index
            # ``limit - amount`` is missing or older than the window, i.e.
            # when there is room for ``amount`` more entries.
            await self.database.windows.update_one(
                {
                    "_id": key,
                    "entries.%d"
                    % (limit - amount): {"$not": {"$gte": timestamp - expiry}},
                },
                updates,
                upsert=True,
            )
        except self.proxy_dependency.module.errors.DuplicateKeyError:
            # Presumably raised when the filter did not match an existing
            # document for ``key`` and the upsert then tried to insert a
            # second document with the same ``_id``.
            return False

        return True
|