You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							439 lines
						
					
					
						
							14 KiB
						
					
					
				
			
		
		
	
	
							439 lines
						
					
					
						
							14 KiB
						
					
					
import time
import urllib
import urllib.parse
from typing import TYPE_CHECKING, cast

from deprecated.sphinx import versionadded
from packaging.version import Version

from limits.aio.storage.base import MovingWindowSupport, Storage
from limits.errors import ConfigurationError
from limits.typing import AsyncRedisClient, Dict, Optional, Tuple, Union
from limits.util import get_package_data

if TYPE_CHECKING:
    import coredis
    import coredis.commands
 | 
						|
 | 
						|
 | 
						|
class RedisInteractor:
    """
    Shared implementation of the async redis storage operations.

    Holds the lua script sources and implementations that take the redis
    client to operate on as an explicit ``connection`` argument, so that
    concrete storages can choose which client (e.g. primary or replica)
    each operation uses.
    """

    RES_DIR = "resources/redis/lua_scripts"

    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

    # Registered script handles. Only declared here; a concrete storage must
    # assign them before the methods that use them are called.
    lua_moving_window: "coredis.commands.Script[bytes]"
    lua_acquire_window: "coredis.commands.Script[bytes]"
    lua_clear_keys: "coredis.commands.Script[bytes]"
    lua_incr_expire: "coredis.commands.Script[bytes]"

    async def _incr(
        self,
        key: str,
        expiry: int,
        connection: AsyncRedisClient,
        elastic_expiry: bool = False,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param connection: Redis connection
        :param elastic_expiry: if ``True`` the expiry is reset on every
         increment instead of only when the counter starts from zero
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        value = await connection.incrby(key, amount)

        # A result equal to ``amount`` means the counter started from zero
        # (typically a freshly created key), so it needs an expiry set.
        if elastic_expiry or value == amount:
            await connection.expire(key, expiry)

        return value

    async def _get(self, key: str, connection: AsyncRedisClient) -> int:
        """
        :param key: the key to get the counter value for
        :param connection: Redis connection
        :return: the stored counter, or ``0`` if the key is missing
        """
        return int(await connection.get(key) or 0)

    async def _clear(self, key: str, connection: AsyncRedisClient) -> None:
        """
        :param key: the key to clear rate limits for
        :param connection: Redis connection
        """
        await connection.delete([key])

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: amount of entries allowed; forwarded to the moving
         window lua script
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries); when the
         script returns nothing, ``(current timestamp, 0)``
        """
        timestamp = int(time.time())
        window = await self.lua_moving_window.execute(
            [key], [int(timestamp - expiry), limit]
        )

        if window:
            return tuple(window)  # type: ignore

        return timestamp, 0

    async def _acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        connection: AsyncRedisClient,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param connection: Redis connection
        :param amount: the number of entries to acquire
        :return: ``True`` if the acquire window script reported success
        """
        timestamp = time.time()
        acquired = await self.lua_acquire_window.execute(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> int:
        """
        :param key: the key to get the expiry for
        :param connection: Redis connection
        :return: the expiry as a unix timestamp; a negative TTL (missing key
         or no expiry set) is clamped to ``0``, yielding the current time
        """
        return int(max(await connection.ttl(key), 0) + time.time())

    async def _check(self, connection: AsyncRedisClient) -> bool:
        """
        check if storage is healthy

        :param connection: Redis connection
        :return: ``True`` if ``ping`` succeeded, ``False`` if it raised
        """
        try:
            await connection.ping()
        except Exception:
            # Any client / connection error means "unhealthy". BaseException
            # subclasses (e.g. asyncio.CancelledError on Python 3.8+) are
            # deliberately not caught so task cancellation still propagates.
            return False

        return True
 | 
						|
 | 
						|
 | 
						|
@versionadded(version="2.1")
class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis", "async+rediss", "async+redis+unix"]
    """
    The storage schemes for redis to be used in an async context
    """
    DEPENDENCIES = {"coredis": Version("3.4.0")}

    def __init__(
        self,
        uri: str,
        connection_pool: Optional["coredis.ConnectionPool"] = None,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form:

         - ``async+redis://[:password]@host:port``
         - ``async+redis://[:password]@host:port/db``
         - ``async+rediss://[:password]@host:port``
         - ``async+redis+unix:///path/to/sock`` etc...

         This uri is passed directly to :meth:`coredis.Redis.from_url` with
         the initial ``async+`` removed, except for the case of
         ``async+redis+unix`` where it is replaced with ``unix``.
        :param connection_pool: if provided, the redis client is initialized with
         the connection pool and any other params passed as :paramref:`options`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.Redis`
        :raise ConfigurationError: when the redis library is not available
        """
        # "async+redis..." -> "redis..." (also turns "async+rediss" into
        # "rediss"), then "redis+unix" -> "unix".
        uri = uri.replace("async+redis", "redis", 1)
        uri = uri.replace("redis+unix", "unix")

        super().__init__(uri, **options)

        self.dependency = self.dependencies["coredis"].module

        if connection_pool is not None:
            self.storage = self.dependency.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.Redis.from_url(uri, **options)

        self.initialize_storage(uri)

    def initialize_storage(self, _uri: str) -> None:
        """
        Register the lua scripts used by this storage on :attr:`storage`.

        ``register_script`` is called synchronously; it is the ``execute``
        method of the returned script objects that must be awaited.

        :param _uri: unused
        """
        self.lua_moving_window = self.storage.register_script(self.SCRIPT_MOVING_WINDOW)
        self.lua_acquire_window = self.storage.register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.storage.register_script(self.SCRIPT_CLEAR_KEYS)
        # Looked up via ``self`` (like the others) so subclasses can override.
        self.lua_incr_expire = self.storage.register_script(self.SCRIPT_INCR_EXPIRE)

    async def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if ``True`` the expiry is reset on every increment
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        if elastic_expiry:
            return await super()._incr(
                key, expiry, self.storage, elastic_expiry, amount
            )

        # Non-elastic path: increment and expiry are handled by one lua script.
        return cast(
            int, await self.lua_incr_expire.execute([key], [expiry, amount])
        )

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await super()._get(key, self.storage)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        return await super()._clear(key, self.storage)

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        return await super()._acquire_entry(key, limit, expiry, self.storage, amount)

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        """
        return await super()._get_expiry(key, self.storage)

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling :meth:`coredis.Redis.ping`
        """
        return await super()._check(self.storage)

    async def reset(self) -> Optional[int]:
        """
        This function calls a Lua Script to delete keys prefixed with 'LIMITER'
        in blocks of 5000.

        .. warning:: This operation was designed to be fast, but was not tested
           on a large production based system. Be careful with its usage as it
           could be slow on very large data sets.

        :return: the value returned by the clear keys script (presumably the
         number of deleted keys — confirm against ``clear_keys.lua``)
        """
        return cast(int, await self.lua_clear_keys.execute(["LIMITER*"]))
 | 
						|
 | 
						|
 | 
						|
@versionadded(version="2.1")
class RedisClusterStorage(RedisStorage):
    """
    Rate limit storage with redis cluster as backend

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis+cluster"]
    """
    The storage schemes for redis cluster to be used in an async context
    """

    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
        "max_connections": 1000,
    }
    "Default options passed to :class:`coredis.RedisCluster`"

    def __init__(self, uri: str, **options: Union[float, str, bool]) -> None:
        """
        :param uri: url of the form
         ``async+redis+cluster://[:password]@host:port,host:port``
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.RedisCluster`.
         Explicit ``username`` / ``password`` options take precedence over
         credentials embedded in ``uri``.
        :raise ConfigurationError: when the coredis library is not
         available or if the redis host cannot be pinged.
        """
        parsed = urllib.parse.urlparse(uri)
        parsed_auth: Dict[str, Union[float, str, bool]] = {}

        if parsed.username:
            parsed_auth["username"] = parsed.username

        if parsed.password:
            parsed_auth["password"] = parsed.password

        # Strip any "user:password@" prefix so only "host:port" pairs remain.
        # rpartition matches how urlparse itself separates the userinfo.
        hosts = parsed.netloc.rpartition("@")[2]
        cluster_hosts = []

        for loc in hosts.split(","):
            host, port = loc.split(":")
            cluster_hosts.append({"host": host, "port": int(port)})

        # Skip RedisStorage.__init__ (it would build a non-cluster client).
        super(RedisStorage, self).__init__(uri, **options)

        self.dependency = self.dependencies["coredis"].module

        self.storage: "coredis.RedisCluster[str]" = self.dependency.RedisCluster(
            startup_nodes=cluster_hosts,
            **{**self.DEFAULT_OPTIONS, **parsed_auth, **options},
        )
        self.initialize_storage(uri)

    async def reset(self) -> Optional[int]:
        """
        Redis Clusters are sharded and deleting across shards
        can't be done atomically. Because of this, this reset loops over all
        keys that are prefixed with 'LIMITER' and calls delete on them, one at
        a time.

        .. warning:: This operation was not tested with extremely large data sets.
           On a large production based system, care should be taken with its
           usage as it could be slow on very large data sets

        :return: the total of the counts returned by each ``delete`` call
        """
        keys = await self.storage.keys("LIMITER*")
        count = 0

        for key in keys:
            count += await self.storage.delete([key])

        return count
 | 
						|
 | 
						|
 | 
						|
@versionadded(version="2.1")
class RedisSentinelStorage(RedisStorage):
    """
    Rate limit storage with redis sentinel as backend

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis+sentinel"]
    """The storage scheme for redis accessed via a redis sentinel installation"""

    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
        "stream_timeout": 0.2,
    }
    "Default options passed to :class:`~coredis.sentinel.Sentinel`"

    DEPENDENCIES = {"coredis.sentinel": Version("3.4.0")}

    def __init__(
        self,
        uri: str,
        service_name: Optional[str] = None,
        use_replicas: bool = True,
        sentinel_kwargs: Optional[Dict[str, Union[float, str, bool]]] = None,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: url of the form
         ``async+redis+sentinel://host:port,host:port/service_name``
        :param service_name, optional: sentinel service name
         (if not provided in `uri`)
        :param use_replicas: Whether to use replicas for read only operations
        :param sentinel_kwargs, optional: kwargs to pass as
         ``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.sentinel.Sentinel`
         (merged over :attr:`DEFAULT_OPTIONS`)
        :raise ConfigurationError: when the coredis library is not available,
         when no service name is available from ``uri`` or ``service_name``,
         or if the redis primary host cannot be pinged.
        """
        parsed = urllib.parse.urlparse(uri)
        sentinel_configuration = []
        connection_options: Dict[str, Union[float, str, bool]] = {
            **self.DEFAULT_OPTIONS,
            **options,
        }
        sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
        parsed_auth: Dict[str, Union[float, str, bool]] = {}

        if parsed.username:
            parsed_auth["username"] = parsed.username

        if parsed.password:
            parsed_auth["password"] = parsed.password

        # Strip any "user:password@" prefix so only "host:port" pairs remain.
        # rpartition matches how urlparse itself separates the userinfo.
        hosts = parsed.netloc.rpartition("@")[2]

        for loc in hosts.split(","):
            host, port = loc.split(":")
            sentinel_configuration.append((host, int(port)))

        # An empty path, or a bare "/", carries no service name: fall back to
        # the explicit argument instead of silently using "".
        self.service_name = parsed.path.replace("/", "") or service_name

        if not self.service_name:
            raise ConfigurationError("'service_name' not provided")

        # Skip RedisStorage.__init__ (it would build a non-sentinel client).
        super(RedisStorage, self).__init__(uri, **options)

        self.dependency = self.dependencies["coredis.sentinel"].module

        self.sentinel = self.dependency.Sentinel(
            sentinel_configuration,
            sentinel_kwargs={**parsed_auth, **sentinel_options},
            **{**parsed_auth, **connection_options},
        )
        self.storage = self.sentinel.primary_for(self.service_name)
        self.storage_replica = self.sentinel.replica_for(self.service_name)
        self.use_replicas = use_replicas
        self.initialize_storage(uri)

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await super()._get(
            key, self.storage_replica if self.use_replicas else self.storage
        )

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        """
        return await super()._get_expiry(
            key, self.storage_replica if self.use_replicas else self.storage
        )

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling ``ping`` on the replica when
        ``use_replicas`` is enabled, otherwise on the primary.
        """
        return await super()._check(
            self.storage_replica if self.use_replicas else self.storage
        )
 |