You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.2 KiB
137 lines
4.2 KiB
import time
|
|
import urllib.parse
|
|
|
|
from deprecated.sphinx import versionadded
|
|
|
|
from limits.aio.storage.base import Storage
|
|
from limits.typing import EmcacheClientP, Optional, Union
|
|
|
|
|
|
@versionadded(version="2.1")
|
|
class MemcachedStorage(Storage):
|
|
"""
|
|
Rate limit storage with memcached as backend.
|
|
|
|
Depends on :pypi:`emcache`
|
|
"""
|
|
|
|
STORAGE_SCHEME = ["async+memcached"]
|
|
"""The storage scheme for memcached to be used in an async context"""
|
|
|
|
DEPENDENCIES = ["emcache"]
|
|
|
|
def __init__(self, uri: str, **options: Union[float, str, bool]) -> None:
|
|
"""
|
|
:param uri: memcached location of the form
|
|
``async+memcached://host:port,host:port``
|
|
:param options: all remaining keyword arguments are passed
|
|
directly to the constructor of :class:`emcache.Client`
|
|
:raise ConfigurationError: when :pypi:`emcache` is not available
|
|
"""
|
|
parsed = urllib.parse.urlparse(uri)
|
|
self.hosts = []
|
|
|
|
for host, port in (
|
|
loc.split(":") for loc in parsed.netloc.strip().split(",") if loc.strip()
|
|
):
|
|
self.hosts.append((host, int(port)))
|
|
|
|
self._options = options
|
|
self._storage = None
|
|
super().__init__(uri, **options)
|
|
self.dependency = self.dependencies["emcache"].module
|
|
|
|
async def get_storage(self) -> EmcacheClientP:
|
|
if not self._storage:
|
|
self._storage = await self.dependency.create_client(
|
|
[self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
|
|
**self._options,
|
|
)
|
|
assert self._storage
|
|
return self._storage
|
|
|
|
async def get(self, key: str) -> int:
|
|
"""
|
|
:param key: the key to get the counter value for
|
|
"""
|
|
|
|
item = await (await self.get_storage()).get(key.encode("utf-8"))
|
|
|
|
return item and int(item.value) or 0
|
|
|
|
async def clear(self, key: str) -> None:
|
|
"""
|
|
:param key: the key to clear rate limits for
|
|
"""
|
|
await (await self.get_storage()).delete(key.encode("utf-8"))
|
|
|
|
async def incr(
|
|
self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
|
|
) -> int:
|
|
"""
|
|
increments the counter for a given rate limit key
|
|
|
|
:param key: the key to increment
|
|
:param expiry: amount in seconds for the key to expire in
|
|
:param elastic_expiry: whether to keep extending the rate limit
|
|
window every hit.
|
|
:param amount: the number to increment by
|
|
"""
|
|
storage = await self.get_storage()
|
|
limit_key = key.encode("utf-8")
|
|
expire_key = f"{key}/expires".encode()
|
|
added = True
|
|
try:
|
|
await storage.add(limit_key, f"{amount}".encode(), exptime=expiry)
|
|
except self.dependency.NotStoredStorageCommandError:
|
|
added = False
|
|
storage = await self.get_storage()
|
|
|
|
if not added:
|
|
value = await storage.increment(limit_key, amount) or amount
|
|
|
|
if elastic_expiry:
|
|
await storage.touch(limit_key, exptime=expiry)
|
|
await storage.set(
|
|
expire_key,
|
|
str(expiry + time.time()).encode("utf-8"),
|
|
exptime=expiry,
|
|
noreply=False,
|
|
)
|
|
|
|
return value
|
|
else:
|
|
await storage.set(
|
|
expire_key,
|
|
str(expiry + time.time()).encode("utf-8"),
|
|
exptime=expiry,
|
|
noreply=False,
|
|
)
|
|
|
|
return amount
|
|
|
|
async def get_expiry(self, key: str) -> int:
|
|
"""
|
|
:param key: the key to get the expiry for
|
|
"""
|
|
storage = await self.get_storage()
|
|
item = await storage.get(f"{key}/expires".encode())
|
|
|
|
return int(item and float(item.value) or time.time())
|
|
|
|
async def check(self) -> bool:
|
|
"""
|
|
Check if storage is healthy by calling the ``get`` command
|
|
on the key ``limiter-check``
|
|
"""
|
|
try:
|
|
storage = await self.get_storage()
|
|
await storage.get(b"limiter-check")
|
|
|
|
return True
|
|
except: # noqa
|
|
return False
|
|
|
|
async def reset(self) -> Optional[int]:
|
|
raise NotImplementedError
|