You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					189 lines
				
				5.9 KiB
			
		
		
			
		
	
	
					189 lines
				
				5.9 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import inspect
							 | 
						||
| 
								 | 
							
								import threading
							 | 
						||
| 
								 | 
							
								import time
							 | 
						||
| 
								 | 
							
								import urllib.parse
							 | 
						||
| 
								 | 
							
								from types import ModuleType
							 | 
						||
| 
								 | 
							
								from typing import cast
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from limits.errors import ConfigurationError
							 | 
						||
| 
								 | 
							
								from limits.storage.base import Storage
							 | 
						||
| 
								 | 
							
								from limits.typing import Callable, List, MemcachedClientP, Optional, P, R, Tuple, Union
							 | 
						||
| 
								 | 
							
								from limits.util import get_dependency
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class MemcachedStorage(Storage):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Rate limit storage with memcached as backend.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    Depends on :pypi:`pymemcache`.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    STORAGE_SCHEME = ["memcached"]
							 | 
						||
| 
								 | 
							
								    """The storage scheme for memcached"""
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(
							 | 
						||
| 
								 | 
							
								        self,
							 | 
						||
| 
								 | 
							
								        uri: str,
							 | 
						||
| 
								 | 
							
								        **options: Union[str, Callable[[], MemcachedClientP]],
							 | 
						||
| 
								 | 
							
								    ) -> None:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param uri: memcached location of the form
							 | 
						||
| 
								 | 
							
								         ``memcached://host:port,host:port``,
							 | 
						||
| 
								 | 
							
								         ``memcached:///var/tmp/path/to/sock``
							 | 
						||
| 
								 | 
							
								        :param options: all remaining keyword arguments are passed
							 | 
						||
| 
								 | 
							
								         directly to the constructor of :class:`pymemcache.client.base.PooledClient`
							 | 
						||
| 
								 | 
							
								         or :class:`pymemcache.client.hash.HashClient` (if there are more than
							 | 
						||
| 
								 | 
							
								         one hosts specified)
							 | 
						||
| 
								 | 
							
								        :raise ConfigurationError: when :pypi:`pymemcache` is not available
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        parsed = urllib.parse.urlparse(uri)
							 | 
						||
| 
								 | 
							
								        self.hosts = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for loc in parsed.netloc.strip().split(","):
							 | 
						||
| 
								 | 
							
								            if not loc:
							 | 
						||
| 
								 | 
							
								                continue
							 | 
						||
| 
								 | 
							
								            host, port = loc.split(":")
							 | 
						||
| 
								 | 
							
								            self.hosts.append((host, int(port)))
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            # filesystem path to UDS
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if parsed.path and not parsed.netloc and not parsed.port:
							 | 
						||
| 
								 | 
							
								                self.hosts = [parsed.path]  # type: ignore
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.library = str(options.pop("library", "pymemcache.client"))
							 | 
						||
| 
								 | 
							
								        self.cluster_library = str(
							 | 
						||
| 
								 | 
							
								            options.pop("cluster_library", "pymemcache.client.hash")
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self.client_getter = cast(
							 | 
						||
| 
								 | 
							
								            Callable[[ModuleType, List[Tuple[str, int]]], MemcachedClientP],
							 | 
						||
| 
								 | 
							
								            options.pop("client_getter", self.get_client),
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self.options = options
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not get_dependency(self.library):
							 | 
						||
| 
								 | 
							
								            raise ConfigurationError(
							 | 
						||
| 
								 | 
							
								                "memcached prerequisite not available."
							 | 
						||
| 
								 | 
							
								                " please install %s" % self.library
							 | 
						||
| 
								 | 
							
								            )  # pragma: no cover
							 | 
						||
| 
								 | 
							
								        self.local_storage = threading.local()
							 | 
						||
| 
								 | 
							
								        self.local_storage.storage = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_client(
							 | 
						||
| 
								 | 
							
								        self, module: ModuleType, hosts: List[Tuple[str, int]], **kwargs: str
							 | 
						||
| 
								 | 
							
								    ) -> MemcachedClientP:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        returns a memcached client.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        :param module: the memcached module
							 | 
						||
| 
								 | 
							
								        :param hosts: list of memcached hosts
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        return cast(
							 | 
						||
| 
								 | 
							
								            MemcachedClientP,
							 | 
						||
| 
								 | 
							
								            module.HashClient(hosts, **kwargs)
							 | 
						||
| 
								 | 
							
								            if len(hosts) > 1
							 | 
						||
| 
								 | 
							
								            else module.PooledClient(*hosts, **kwargs),
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def call_memcached_func(
							 | 
						||
| 
								 | 
							
								        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
							 | 
						||
| 
								 | 
							
								    ) -> R:
							 | 
						||
| 
								 | 
							
								        if "noreply" in kwargs:
							 | 
						||
| 
								 | 
							
								            argspec = inspect.getfullargspec(func)
							 | 
						||
| 
								 | 
							
								            if not ("noreply" in argspec.args or argspec.varkw):
							 | 
						||
| 
								 | 
							
								                kwargs.pop("noreply")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return func(*args, **kwargs)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @property
							 | 
						||
| 
								 | 
							
								    def storage(self) -> MemcachedClientP:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        lazily creates a memcached client instance using a thread local
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
							 | 
						||
| 
								 | 
							
								            dependency = get_dependency(
							 | 
						||
| 
								 | 
							
								                self.cluster_library if len(self.hosts) > 1 else self.library
							 | 
						||
| 
								 | 
							
								            )[0]
							 | 
						||
| 
								 | 
							
								            if not dependency:
							 | 
						||
| 
								 | 
							
								                raise ConfigurationError(f"Unable to import {self.cluster_library}")
							 | 
						||
| 
								 | 
							
								            self.local_storage.storage = self.client_getter(
							 | 
						||
| 
								 | 
							
								                dependency, self.hosts, **self.options
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return cast(MemcachedClientP, self.local_storage.storage)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get(self, key: str) -> int:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param key: the key to get the counter value for
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return int(self.storage.get(key) or 0)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def clear(self, key: str) -> None:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param key: the key to clear rate limits for
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        self.storage.delete(key)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def incr(
							 | 
						||
| 
								 | 
							
								        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
							 | 
						||
| 
								 | 
							
								    ) -> int:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        increments the counter for a given rate limit key
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        :param key: the key to increment
							 | 
						||
| 
								 | 
							
								        :param expiry: amount in seconds for the key to expire in
							 | 
						||
| 
								 | 
							
								        :param elastic_expiry: whether to keep extending the rate limit
							 | 
						||
| 
								 | 
							
								         window every hit.
							 | 
						||
| 
								 | 
							
								        :param amount: the number to increment by
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not self.call_memcached_func(
							 | 
						||
| 
								 | 
							
								            self.storage.add, key, amount, expiry, noreply=False
							 | 
						||
| 
								 | 
							
								        ):
							 | 
						||
| 
								 | 
							
								            value = self.storage.incr(key, amount) or amount
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if elastic_expiry:
							 | 
						||
| 
								 | 
							
								                self.call_memcached_func(self.storage.touch, key, expiry)
							 | 
						||
| 
								 | 
							
								                self.call_memcached_func(
							 | 
						||
| 
								 | 
							
								                    self.storage.set,
							 | 
						||
| 
								 | 
							
								                    key + "/expires",
							 | 
						||
| 
								 | 
							
								                    expiry + time.time(),
							 | 
						||
| 
								 | 
							
								                    expire=expiry,
							 | 
						||
| 
								 | 
							
								                    noreply=False,
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            return value
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            self.call_memcached_func(
							 | 
						||
| 
								 | 
							
								                self.storage.set,
							 | 
						||
| 
								 | 
							
								                key + "/expires",
							 | 
						||
| 
								 | 
							
								                expiry + time.time(),
							 | 
						||
| 
								 | 
							
								                expire=expiry,
							 | 
						||
| 
								 | 
							
								                noreply=False,
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return amount
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_expiry(self, key: str) -> int:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param key: the key to get the expiry for
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return int(float(self.storage.get(key + "/expires") or time.time()))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def check(self) -> bool:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        Check if storage is healthy by calling the ``get`` command
							 | 
						||
| 
								 | 
							
								        on the key ``limiter-check``
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            self.call_memcached_func(self.storage.get, "limiter-check")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            return True
							 | 
						||
| 
								 | 
							
								        except:  # noqa
							 | 
						||
| 
								 | 
							
								            return False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def reset(self) -> Optional[int]:
							 | 
						||
| 
								 | 
							
								        raise NotImplementedError
							 |