You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							268 lines
						
					
					
						
							8.8 KiB
						
					
					
				
			
		
		
	
	
							268 lines
						
					
					
						
							8.8 KiB
						
					
					
				from __future__ import annotations
 | 
						|
 | 
						|
import asyncio
 | 
						|
import calendar
 | 
						|
import datetime
 | 
						|
import time
 | 
						|
from typing import Any
 | 
						|
 | 
						|
from deprecated.sphinx import versionadded
 | 
						|
 | 
						|
from limits.aio.storage.base import MovingWindowSupport, Storage
 | 
						|
from limits.typing import Dict, Optional, ParamSpec, Tuple, TypeVar, Union
 | 
						|
 | 
						|
P = ParamSpec("P")
 | 
						|
R = TypeVar("R")
 | 
						|
 | 
						|
 | 
						|
@versionadded(version="2.1")
class MongoDBStorage(Storage, MovingWindowSupport):
    """
    Rate limit storage with MongoDB as backend.

    Fixed-window counters live in the ``counters`` collection and
    moving-window entries in the ``windows`` collection of :attr:`database`.

    Depends on :pypi:`motor`
    """

    STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
    """
    The storage scheme for MongoDB for use in an async context
    """

    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
        "serverSelectionTimeoutMS": 1000,
        "socketTimeoutMS": 1000,
        "connectTimeoutMS": 1000,
    }
    "Default options passed to :class:`~motor.motor_asyncio.AsyncIOMotorClient`"

    DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]

    def __init__(
        self,
        uri: str,
        database_name: str = "limits",
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
         This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :param database_name: The database to use for storing the rate limit
         collections.
        :param options: all remaining keyword arguments are merged with
         :data:`DEFAULT_OPTIONS` and passed to the constructor of
         :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
         not available
        """
        # Caller-supplied options take precedence over DEFAULT_OPTIONS.
        mongo_opts = {**self.DEFAULT_OPTIONS, **options}
        # The driver only understands the plain scheme, so strip the "async+"
        # prefix (``async+mongodb+srv`` becomes ``mongodb+srv``).
        uri = uri.replace("async+mongodb", "mongodb", 1)

        super().__init__(uri, **options)

        self.dependency = self.dependencies["motor.motor_asyncio"]
        self.proxy_dependency = self.dependencies["pymongo"]

        self.storage = self.dependency.module.AsyncIOMotorClient(uri, **mongo_opts)
        # TODO: Fix this hack. It was noticed when running a benchmark
        # with FastAPI - however - doesn't appear in unit tests or in an isolated
        # use. Reference: https://jira.mongodb.org/browse/MOTOR-822
        self.storage.get_io_loop = asyncio.get_running_loop

        self.__database_name = database_name
        self.__indices_created = False

    @property
    def database(self):  # type: ignore
        """The motor database that holds the rate limit collections."""
        return self.storage.get_database(self.__database_name)

    async def create_indices(self) -> None:
        """
        Create the ``expireAt`` TTL indices on the counters and windows
        collections. This runs only once per storage instance.

        ``expireAfterSeconds=0`` tells MongoDB to remove a document once the
        time stored in its ``expireAt`` field has passed.
        """
        if not self.__indices_created:
            await asyncio.gather(
                self.database.counters.create_index("expireAt", expireAfterSeconds=0),
                self.database.windows.create_index("expireAt", expireAfterSeconds=0),
            )
            self.__indices_created = True

    async def reset(self) -> Optional[int]:
        """
        Delete all rate limit keys in the rate limit collections (counters, windows)

        :return: the number of documents the two collections held just before
         they were dropped (the count and the drop are separate operations)
        """
        num_keys = sum(
            await asyncio.gather(
                self.database.counters.count_documents({}),
                self.database.windows.count_documents({}),
            )
        )
        await asyncio.gather(
            self.database.counters.drop(), self.database.windows.drop()
        )

        return num_keys

    async def clear(self, key: str) -> None:
        """
        Remove both the fixed-window counter and the moving-window entries
        stored under ``key``.

        :param key: the key to clear rate limits for
        """
        await asyncio.gather(
            self.database.counters.find_one_and_delete({"_id": key}),
            self.database.windows.find_one_and_delete({"_id": key}),
        )

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        :return: the counter's ``expireAt`` as a UTC epoch timestamp in whole
         seconds, or the current time if no counter exists for ``key``
        """
        counter = await self.database.counters.find_one({"_id": key})
        expiry = (
            counter["expireAt"]
            if counter
            else datetime.datetime.now(datetime.timezone.utc)
        )
        # ``timegm`` reads the struct_time as UTC. That holds both for the
        # naive UTC datetimes pymongo returns by default and for the aware
        # UTC datetime above.
        return calendar.timegm(expiry.timetuple())

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the current count, or 0 if the counter is missing or expired
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        counter = await self.database.counters.find_one(
            # TTL deletion is not instantaneous, so an expired document may
            # still exist. Filter it out explicitly.
            {"_id": key, "expireAt": {"$gte": now}},
            projection=["count"],
        )

        return counter["count"] if counter else 0

    async def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: whether to keep extending the rate limit
         window every hit.
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        await self.create_indices()

        expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
            seconds=expiry
        )
        # A single pipeline update keeps the read-modify-write atomic. If the
        # stored window has elapsed (or the document is being upserted and has
        # no ``expireAt`` yet), the count and expiry restart. Otherwise the
        # count grows, and the expiry only moves when ``elastic_expiry`` is set.
        window_over = {"$lt": ["$expireAt", "$$NOW"]}
        response = await self.database.counters.find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "count": {
                            "$cond": {
                                "if": window_over,
                                "then": amount,
                                "else": {"$add": ["$count", amount]},
                            }
                        },
                        "expireAt": {
                            "$cond": {
                                "if": window_over,
                                "then": expiration,
                                "else": (expiration if elastic_expiry else "$expireAt"),
                            }
                        },
                    }
                },
            ],
            upsert=True,
            projection=["count"],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
        )

        return int(response["count"])

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling
        :meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`

        :return: ``True`` if the server answered, ``False`` if the call raised
        """
        try:
            await self.storage.server_info()
        except Exception:
            # Catch ``Exception`` rather than everything. A bare ``except``
            # would also swallow ``asyncio.CancelledError`` (a
            # ``BaseException`` since Python 3.8) and break task cancellation.
            return False
        return True

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param str key: rate limit key
        :param int limit: amount of entries allowed (not used by this query)
        :param int expiry: expiry of entry
        :return: (start of window, number of acquired entries). The start is
         the timestamp of the oldest entry still inside the window, or the
         current time when the window is empty.
        """
        timestamp = time.time()
        result = await self.database.windows.aggregate(
            [
                {"$match": {"_id": key}},
                {
                    # Keep only entries that are still inside the window.
                    "$project": {
                        "entries": {
                            "$filter": {
                                "input": "$entries",
                                "as": "entry",
                                "cond": {"$gte": ["$$entry", timestamp - expiry]},
                            }
                        }
                    }
                },
                {"$unwind": "$entries"},
                {
                    "$group": {
                        "_id": "$_id",
                        # The window starts at its oldest live entry.
                        "min": {"$min": "$entries"},
                        "count": {"$sum": 1},
                    }
                },
            ]
        ).to_list(length=1)

        if result:
            return (int(result[0]["min"]), result[0]["count"])

        return (int(timestamp), 0)

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: ``True`` if the entries were recorded, ``False`` if doing so
         would exceed ``limit``
        """
        # A request larger than the limit can never fit. Without this guard
        # the filter key below would be ``entries.-N``, which never exists,
        # so the update would wrongly succeed.
        if amount > limit:
            return False

        await self.create_indices()

        timestamp = time.time()
        updates: Dict[str, Any] = {
            "$push": {
                "entries": {
                    # Newest entries go to the front and the array is capped
                    # at ``limit`` elements.
                    "$each": [timestamp] * amount,
                    "$position": 0,
                    "$slice": limit,
                }
            },
            "$set": {
                "expireAt": (
                    datetime.datetime.now(datetime.timezone.utc)
                    + datetime.timedelta(seconds=expiry)
                )
            },
        }
        try:
            await self.database.windows.update_one(
                {
                    "_id": key,
                    # Entries are newest-first. If the entry at index
                    # ``limit - amount`` is still inside the window, at least
                    # ``limit - amount + 1`` entries are live and adding
                    # ``amount`` more would exceed ``limit``.
                    "entries.%d"
                    % (limit - amount): {"$not": {"$gte": timestamp - expiry}},
                },
                updates,
                upsert=True,
            )
        except self.proxy_dependency.module.errors.DuplicateKeyError:
            # The filter did not match the existing document, so the upsert
            # tried to insert a second document with the same ``_id``. That
            # means the window is full.
            return False

        return True
 |