You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					439 lines
				
				14 KiB
			
		
		
			
		
	
	
					439 lines
				
				14 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import time
							 | 
						||
| 
								 | 
							
								import urllib
							 | 
						||
| 
								 | 
							
								from typing import TYPE_CHECKING, cast
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from deprecated.sphinx import versionadded
							 | 
						||
| 
								 | 
							
								from packaging.version import Version
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from limits.aio.storage.base import MovingWindowSupport, Storage
							 | 
						||
| 
								 | 
							
								from limits.errors import ConfigurationError
							 | 
						||
| 
								 | 
							
								from limits.typing import AsyncRedisClient, Dict, Optional, Tuple, Union
							 | 
						||
| 
								 | 
							
								from limits.util import get_package_data
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								if TYPE_CHECKING:
							 | 
						||
| 
								 | 
							
								    import coredis
							 | 
						||
| 
								 | 
							
								    import coredis.commands
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class RedisInteractor:
								    """
								    Low level redis operations shared by the async redis storages.

								    The helpers that issue plain redis commands receive the client through
								    their ``connection`` argument. The moving window helpers instead run the
								    ``lua_moving_window`` / ``lua_acquire_window`` script handles, which are
								    only declared here and must be assigned on the instance before use.
								    """

								    # Package relative directory holding the lua scripts loaded below.
								    RES_DIR = "resources/redis/lua_scripts"

								    # Script sources, loaded once while the class body is executed.
								    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
								    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
								        f"{RES_DIR}/acquire_moving_window.lua"
								    )
								    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
								    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

								    # Script handles; annotated only, no value is assigned in this class.
								    lua_moving_window: "coredis.commands.Script[bytes]"
								    lua_acquire_window: "coredis.commands.Script[bytes]"
								    lua_clear_keys: "coredis.commands.Script[bytes]"
								    lua_incr_expire: "coredis.commands.Script[bytes]"

								    async def _incr(
								        self,
								        key: str,
								        expiry: int,
								        connection: AsyncRedisClient,
								        elastic_expiry: bool = False,
								        amount: int = 1,
								    ) -> int:
								        """
								        increments the counter for a given rate limit key

								        :param connection: Redis connection
								        :param key: the key to increment
								        :param expiry: amount in seconds for the key to expire in
								        :param elastic_expiry: if ``True`` the expiry is (re)applied on every
								         call, otherwise only when the incremented value equals
								         :paramref:`amount`
								        :param amount: the number to increment by
								        :return: the value returned by the increment command
								        """
								        value = await connection.incrby(key, amount)

								        # ``value == amount`` is how a freshly created counter is detected.
								        if elastic_expiry or value == amount:
								            await connection.expire(key, expiry)

								        return value

								    async def _get(self, key: str, connection: AsyncRedisClient) -> int:
								        """
								        :param connection: Redis connection
								        :param key: the key to get the counter value for
								        :return: the stored value as an int, ``0`` when nothing (or a falsy
								         value) is stored
								        """

								        return int(await connection.get(key) or 0)

								    async def _clear(self, key: str, connection: AsyncRedisClient) -> None:
								        """
								        :param key: the key to clear rate limits for
								        :param connection: Redis connection
								        """
								        await connection.delete([key])

								    async def get_moving_window(
								        self, key: str, limit: int, expiry: int
								    ) -> Tuple[int, int]:
								        """
								        returns the starting point and the number of entries in the moving
								        window

								        :param key: rate limit key
								        :param limit: passed to the moving window script as its second
								         argument
								        :param expiry: expiry of entry
								        :return: (start of window, number of acquired entries); when the
								         script returns nothing the current timestamp and ``0`` are returned
								        """
								        timestamp = int(time.time())
								        window = await self.lua_moving_window.execute(
								            [key], [int(timestamp - expiry), limit]
								        )

								        if window:
								            return tuple(window)  # type: ignore

								        return timestamp, 0

								    async def _acquire_entry(
								        self,
								        key: str,
								        limit: int,
								        expiry: int,
								        connection: AsyncRedisClient,
								        amount: int = 1,
								    ) -> bool:
								        """
								        :param key: rate limit key to acquire an entry in
								        :param limit: amount of entries allowed
								        :param expiry: expiry of the entry
								        :param connection: Redis connection (not used by this
								         implementation, the script handle is executed instead)
								        :param amount: the number of entries to acquire
								        :return: truthiness of the script result
								        """
								        timestamp = time.time()
								        acquired = await self.lua_acquire_window.execute(
								            [key], [timestamp, limit, expiry, amount]
								        )

								        return bool(acquired)

								    async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> int:
								        """
								        :param key: the key to get the expiry for
								        :param connection: Redis connection
								        :return: current time plus the key's ttl, where a negative ttl is
								         treated as ``0``
								        """

								        return int(max(await connection.ttl(key), 0) + time.time())

								    async def _check(self, connection: AsyncRedisClient) -> bool:
								        """
								        check if storage is healthy

								        :param connection: Redis connection
								        :return: ``True`` if the ping completed, ``False`` if it raised
								        """
								        try:
								            await connection.ping()
								        except Exception:
								            # Ordinary failures mean "unhealthy". BaseException subclasses
								            # (task cancellation on Python 3.8+, KeyboardInterrupt,
								            # SystemExit) are intentionally not caught so they propagate.
								            return False

								        return True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@versionadded(version="2.1")
								class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
								    """
								    Rate limit storage with redis as backend.

								    Depends on :pypi:`coredis`
								    """

								    STORAGE_SCHEME = ["async+redis", "async+rediss", "async+redis+unix"]
								    """
								    The storage schemes for redis to be used in an async context
								    """
								    DEPENDENCIES = {"coredis": Version("3.4.0")}

								    def __init__(
								        self,
								        uri: str,
								        connection_pool: Optional["coredis.ConnectionPool"] = None,
								        **options: Union[float, str, bool],
								    ) -> None:
								        """
								        :param uri: uri of the form:

								         - ``async+redis://[:password]@host:port``
								         - ``async+redis://[:password]@host:port/db``
								         - ``async+rediss://[:password]@host:port``
								         - ``async+redis+unix:///path/to/sock`` etc...

								         This uri is passed directly to :meth:`coredis.Redis.from_url` with
								         the initial ``async`` removed, except for the case of ``async+redis+unix``
								         where it is replaced with ``unix``.
								        :param connection_pool: if provided, the redis client is initialized with
								         the connection pool and any other params passed as :paramref:`options`
								        :param options: all remaining keyword arguments are passed
								         directly to the constructor of :class:`coredis.Redis`
								        :raise ConfigurationError: when the redis library is not available
								        """
								        # Only the leading scheme is rewritten; text further inside the
								        # uri (socket path, password, query) that happens to contain
								        # "async+redis" or "redis+unix" must be left untouched.
								        if uri.startswith("async+redis"):
								            uri = uri[len("async+") :]

								        if uri.startswith("redis+unix"):
								            uri = "unix" + uri[len("redis+unix") :]

								        super().__init__(uri, **options)

								        self.dependency = self.dependencies["coredis"].module

								        if connection_pool:
								            # An explicit pool takes precedence over the uri.
								            self.storage = self.dependency.Redis(
								                connection_pool=connection_pool, **options
								            )
								        else:
								            self.storage = self.dependency.Redis.from_url(uri, **options)

								        self.initialize_storage(uri)

								    def initialize_storage(self, _uri: str) -> None:
								        """
								        Register the lua scripts with the client and keep the handles.

								        :param _uri: unused
								        """
								        # Registration itself is called synchronously; it is the returned
								        # script handles whose ``execute`` is awaited elsewhere in this
								        # class.
								        self.lua_moving_window = self.storage.register_script(self.SCRIPT_MOVING_WINDOW)
								        self.lua_acquire_window = self.storage.register_script(
								            self.SCRIPT_ACQUIRE_MOVING_WINDOW
								        )
								        self.lua_clear_keys = self.storage.register_script(self.SCRIPT_CLEAR_KEYS)
								        self.lua_incr_expire = self.storage.register_script(
								            RedisStorage.SCRIPT_INCR_EXPIRE
								        )

								    async def incr(
								        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
								    ) -> int:
								        """
								        increments the counter for a given rate limit key

								        :param key: the key to increment
								        :param expiry: amount in seconds for the key to expire in
								        :param elastic_expiry: when ``True`` the increment and expiry are
								         issued as separate commands, otherwise the ``incr_expire`` script
								         is executed
								        :param amount: the number to increment by
								        """

								        if elastic_expiry:
								            return await super()._incr(
								                key, expiry, self.storage, elastic_expiry, amount
								            )
								        else:
								            return cast(
								                int, await self.lua_incr_expire.execute([key], [expiry, amount])
								            )

								    async def get(self, key: str) -> int:
								        """
								        :param key: the key to get the counter value for
								        """

								        return await super()._get(key, self.storage)

								    async def clear(self, key: str) -> None:
								        """
								        :param key: the key to clear rate limits for
								        """

								        return await super()._clear(key, self.storage)

								    async def acquire_entry(
								        self, key: str, limit: int, expiry: int, amount: int = 1
								    ) -> bool:
								        """
								        :param key: rate limit key to acquire an entry in
								        :param limit: amount of entries allowed
								        :param expiry: expiry of the entry
								        :param amount: the number of entries to acquire
								        """

								        return await super()._acquire_entry(key, limit, expiry, self.storage, amount)

								    async def get_expiry(self, key: str) -> int:
								        """
								        :param key: the key to get the expiry for
								        """

								        return await super()._get_expiry(key, self.storage)

								    async def check(self) -> bool:
								        """
								        Check if storage is healthy by calling :meth:`coredis.Redis.ping`
								        """

								        return await super()._check(self.storage)

								    async def reset(self) -> Optional[int]:
								        """
								        This function calls a Lua Script to delete keys prefixed with 'LIMITER'
								        in block of 5000.

								        .. warning:: This operation was designed to be fast, but was not tested
								           on a large production based system. Be careful with its usage as it
								           could be slow on very large data sets.
								        """

								        return cast(int, await self.lua_clear_keys.execute(["LIMITER*"]))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@versionadded(version="2.1")
								class RedisClusterStorage(RedisStorage):
								    """
								    Rate limit storage with redis cluster as backend

								    Depends on :pypi:`coredis`
								    """

								    STORAGE_SCHEME = ["async+redis+cluster"]
								    """
								    The storage schemes for redis cluster to be used in an async context
								    """

								    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
								        "max_connections": 1000,
								    }
								    "Default options passed to :class:`coredis.RedisCluster`"

								    def __init__(self, uri: str, **options: Union[float, str, bool]) -> None:
								        """
								        :param uri: url of the form
								         ``async+redis+cluster://[:password]@host:port,host:port``
								        :param options: all remaining keyword arguments are passed
								         directly to the constructor of :class:`coredis.RedisCluster`.
								         They take precedence over credentials found in :paramref:`uri`,
								         which in turn take precedence over :attr:`DEFAULT_OPTIONS`.
								        :raise ConfigurationError: when the coredis library is not
								         available
								        :raise ValueError: when a node in :paramref:`uri` is not of the
								         form ``host:port`` with an integer port
								        """
								        parsed = urllib.parse.urlparse(uri)

								        # Credentials from the uri are handed to the client as keyword
								        # arguments instead of being left inside the node list.
								        parsed_auth: Dict[str, Union[float, str, bool]] = {}

								        if parsed.username:
								            parsed_auth["username"] = parsed.username

								        if parsed.password:
								            parsed_auth["password"] = parsed.password

								        # Drop the ``user:password@`` prefix (if any) so that only the
								        # comma separated ``host:port`` pairs are left to split.
								        node_list = parsed.netloc.rpartition("@")[2]
								        cluster_hosts = []

								        for loc in node_list.split(","):
								            host, port = loc.split(":")
								            cluster_hosts.append({"host": host, "port": int(port)})

								        # Deliberately skips RedisStorage.__init__, which would build a
								        # non-cluster client; the next initializer in the MRO runs instead.
								        super(RedisStorage, self).__init__(uri, **options)

								        self.dependency = self.dependencies["coredis"].module

								        self.storage: "coredis.RedisCluster[str]" = self.dependency.RedisCluster(
								            startup_nodes=cluster_hosts,
								            **{**self.DEFAULT_OPTIONS, **parsed_auth, **options},
								        )
								        self.initialize_storage(uri)

								    async def reset(self) -> Optional[int]:
								        """
								        Redis Clusters are sharded and deleting across shards
								        can't be done atomically. Because of this, this reset loops over all
								        keys that are prefixed with 'LIMITER' and calls delete on them, one at
								        a time.

								        .. warning:: This operation was not tested with extremely large data sets.
								           On a large production based system, care should be taken with its
								           usage as it could be slow on very large data sets

								        :return: the sum of the values returned by the individual deletes
								        """

								        keys = await self.storage.keys("LIMITER*")
								        count = 0

								        for key in keys:
								            count += await self.storage.delete([key])

								        return count
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@versionadded(version="2.1")
							 | 
						||
| 
								 | 
							
								class RedisSentinelStorage(RedisStorage):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Rate limit storage with redis sentinel as backend
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    Depends on :pypi:`coredis`
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    STORAGE_SCHEME = ["async+redis+sentinel"]
							 | 
						||
| 
								 | 
							
								    """The storage scheme for redis accessed via a redis sentinel installation"""
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
							 | 
						||
| 
								 | 
							
								        "stream_timeout": 0.2,
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    "Default options passed to :class:`~coredis.sentinel.Sentinel`"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    DEPENDENCIES = {"coredis.sentinel": Version("3.4.0")}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        uri: str,
							 | 
						||
| 
								 | 
							
								        service_name: Optional[str] = None,
							 | 
						||
| 
								 | 
							
								        use_replicas: bool = True,
							 | 
						||
| 
								 | 
							
								        sentinel_kwargs: Optional[Dict[str, Union[float, str, bool]]] = None,
							 | 
						||
| 
								 | 
							
								        **options: Union[float, str, bool],
							 | 
						||
| 
								 | 
							
								    ):
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param uri: url of the form
							 | 
						||
| 
								 | 
							
								         ``async+redis+sentinel://host:port,host:port/service_name``
							 | 
						||
| 
								 | 
							
								        :param service_name, optional: sentinel service name
							 | 
						||
| 
								 | 
							
								         (if not provided in `uri`)
							 | 
						||
| 
								 | 
							
								        :param use_replicas: Whether to use replicas for read only operations
							 | 
						||
| 
								 | 
							
								        :param sentinel_kwargs, optional: kwargs to pass as
							 | 
						||
| 
								 | 
							
								         ``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel`
							 | 
						||
| 
								 | 
							
								        :param options: all remaining keyword arguments are passed
							 | 
						||
| 
								 | 
							
								         directly to the constructor of :class:`coredis.sentinel.Sentinel`
							 | 
						||
| 
								 | 
							
								        :raise ConfigurationError: when the coredis library is not available
							 | 
						||
| 
								 | 
							
								         or if the redis primary host cannot be pinged.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        parsed = urllib.parse.urlparse(uri)
							 | 
						||
| 
								 | 
							
								        sentinel_configuration = []
							 | 
						||
| 
								 | 
							
								        connection_options = options.copy()
							 | 
						||
| 
								 | 
							
								        sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}
							 | 
						||
| 
								 | 
							
								        parsed_auth: Dict[str, Union[float, str, bool]] = {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if parsed.username:
							 | 
						||
| 
								 | 
							
								            parsed_auth["username"] = parsed.username
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if parsed.password:
							 | 
						||
| 
								 | 
							
								            parsed_auth["password"] = parsed.password
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        sep = parsed.netloc.find("@") + 1
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for loc in parsed.netloc[sep:].split(","):
							 | 
						||
| 
								 | 
							
								            host, port = loc.split(":")
							 | 
						||
| 
								 | 
							
								            sentinel_configuration.append((host, int(port)))
							 | 
						||
| 
								 | 
							
								        self.service_name = (
							 | 
						||
| 
								 | 
							
								            parsed.path.replace("/", "") if parsed.path else service_name
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if self.service_name is None:
							 | 
						||
| 
								 | 
							
								            raise ConfigurationError("'service_name' not provided")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        connection_options.setdefault("stream_timeout", 0.2)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        super(RedisStorage, self).__init__()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.dependency = self.dependencies["coredis.sentinel"].module
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.sentinel = self.dependency.Sentinel(
							 | 
						||
| 
								 | 
							
								            sentinel_configuration,
							 | 
						||
| 
								 | 
							
								            sentinel_kwargs={**parsed_auth, **sentinel_options},
							 | 
						||
| 
								 | 
							
								            **{**parsed_auth, **connection_options},
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self.storage = self.sentinel.primary_for(self.service_name)
							 | 
						||
| 
								 | 
							
								        self.storage_replica = self.sentinel.replica_for(self.service_name)
							 | 
						||
| 
								 | 
							
								        self.use_replicas = use_replicas
							 | 
						||
| 
								 | 
							
								        self.initialize_storage(uri)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def get(self, key: str) -> int:
								        """
								        Fetch the current counter value stored under ``key``.

								        The read is served by ``self.storage_replica`` when
								        ``self.use_replicas`` is truthy, otherwise by ``self.storage``.

								        :param key: the key to get the counter value for
								        """
								        if self.use_replicas:
								            connection = self.storage_replica
								        else:
								            connection = self.storage

								        return await super()._get(key, connection)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def get_expiry(self, key: str) -> int:
								        """
								        Fetch the expiry recorded for ``key``.

								        The read is served by ``self.storage_replica`` when
								        ``self.use_replicas`` is truthy, otherwise by ``self.storage``.

								        :param key: the key to get the expiry for
								        """
								        if self.use_replicas:
								            connection = self.storage_replica
								        else:
								            connection = self.storage

								        return await super()._get_expiry(key, connection)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def check(self) -> bool:
								        """
								        Check if storage is healthy by calling :meth:`coredis.StrictRedis.ping`
								        on the replica when ``self.use_replicas`` is truthy, otherwise on
								        the primary.
								        """
								        if self.use_replicas:
								            connection = self.storage_replica
								        else:
								            connection = self.storage

								        return await super()._check(connection)
							 |