You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					165 lines
				
				5.0 KiB
			
		
		
			
		
	
	
					165 lines
				
				5.0 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import threading
							 | 
						||
| 
								 | 
							
								import time
							 | 
						||
| 
								 | 
							
								from collections import Counter
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import limits.typing
							 | 
						||
| 
								 | 
							
								from limits.storage.base import MovingWindowSupport, Storage
							 | 
						||
| 
								 | 
							
								from limits.typing import Dict, List, Optional, Tuple
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class LockableEntry(threading._RLock):
							 | 
						||
| 
								 | 
							
								    def __init__(self, expiry: float) -> None:
							 | 
						||
| 
								 | 
							
								        self.atime = time.time()
							 | 
						||
| 
								 | 
							
								        self.expiry = self.atime + expiry
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class MemoryStorage(Storage, MovingWindowSupport):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    rate limit storage using :class:`collections.Counter`
							 | 
						||
| 
								 | 
							
								    as an in memory storage for fixed and elastic window strategies,
							 | 
						||
| 
								 | 
							
								    and a simple list to implement moving window strategy.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    STORAGE_SCHEME = ["memory"]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, uri: Optional[str] = None, **_: str):
							 | 
						||
| 
								 | 
							
								        self.storage: limits.typing.Counter[str] = Counter()
							 | 
						||
| 
								 | 
							
								        self.expirations: Dict[str, float] = {}
							 | 
						||
| 
								 | 
							
								        self.events: Dict[str, List[LockableEntry]] = {}
							 | 
						||
| 
								 | 
							
								        self.timer = threading.Timer(0.01, self.__expire_events)
							 | 
						||
| 
								 | 
							
								        self.timer.start()
							 | 
						||
| 
								 | 
							
								        super().__init__(uri, **_)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __expire_events(self) -> None:
							 | 
						||
| 
								 | 
							
								        for key in list(self.events.keys()):
							 | 
						||
| 
								 | 
							
								            for event in list(self.events[key]):
							 | 
						||
| 
								 | 
							
								                with event:
							 | 
						||
| 
								 | 
							
								                    if event.expiry <= time.time() and event in self.events[key]:
							 | 
						||
| 
								 | 
							
								                        self.events[key].remove(event)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for key in list(self.expirations.keys()):
							 | 
						||
| 
								 | 
							
								            if self.expirations[key] <= time.time():
							 | 
						||
| 
								 | 
							
								                self.storage.pop(key, None)
							 | 
						||
| 
								 | 
							
								                self.expirations.pop(key, None)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __schedule_expiry(self) -> None:
							 | 
						||
| 
								 | 
							
								        if not self.timer.is_alive():
							 | 
						||
| 
								 | 
							
								            self.timer = threading.Timer(0.01, self.__expire_events)
							 | 
						||
| 
								 | 
							
								            self.timer.start()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def incr(
							 | 
						||
| 
								 | 
							
								        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
							 | 
						||
| 
								 | 
							
								    ) -> int:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        increments the counter for a given rate limit key
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        :param key: the key to increment
							 | 
						||
| 
								 | 
							
								        :param expiry: amount in seconds for the key to expire in
							 | 
						||
| 
								 | 
							
								        :param elastic_expiry: whether to keep extending the rate limit
							 | 
						||
| 
								 | 
							
								         window every hit.
							 | 
						||
| 
								 | 
							
								        :param amount: the number to increment by
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        self.get(key)
							 | 
						||
| 
								 | 
							
								        self.__schedule_expiry()
							 | 
						||
| 
								 | 
							
								        self.storage[key] += amount
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if elastic_expiry or self.storage[key] == amount:
							 | 
						||
| 
								 | 
							
								            self.expirations[key] = time.time() + expiry
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return self.storage.get(key, 0)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get(self, key: str) -> int:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param key: the key to get the counter value for
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if self.expirations.get(key, 0) <= time.time():
							 | 
						||
| 
								 | 
							
								            self.storage.pop(key, None)
							 | 
						||
| 
								 | 
							
								            self.expirations.pop(key, None)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return self.storage.get(key, 0)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def clear(self, key: str) -> None:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param key: the key to clear rate limits for
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        self.storage.pop(key, None)
							 | 
						||
| 
								 | 
							
								        self.expirations.pop(key, None)
							 | 
						||
| 
								 | 
							
								        self.events.pop(key, None)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param key: rate limit key to acquire an entry in
							 | 
						||
| 
								 | 
							
								        :param limit: amount of entries allowed
							 | 
						||
| 
								 | 
							
								        :param expiry: expiry of the entry
							 | 
						||
| 
								 | 
							
								        :param amount: the number of entries to acquire
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        self.events.setdefault(key, [])
							 | 
						||
| 
								 | 
							
								        self.__schedule_expiry()
							 | 
						||
| 
								 | 
							
								        timestamp = time.time()
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            entry = self.events[key][limit - amount]
							 | 
						||
| 
								 | 
							
								        except IndexError:
							 | 
						||
| 
								 | 
							
								            entry = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if entry and entry.atime >= timestamp - expiry:
							 | 
						||
| 
								 | 
							
								            return False
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            self.events[key][:0] = [LockableEntry(expiry) for _ in range(amount)]
							 | 
						||
| 
								 | 
							
								            return True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_expiry(self, key: str) -> int:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        :param key: the key to get the expiry for
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return int(self.expirations.get(key, time.time()))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_num_acquired(self, key: str, expiry: int) -> int:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        returns the number of entries already acquired
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        :param key: rate limit key to acquire an entry in
							 | 
						||
| 
								 | 
							
								        :param expiry: expiry of the entry
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        timestamp = time.time()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return (
							 | 
						||
| 
								 | 
							
								            len([k for k in self.events[key] if k.atime >= timestamp - expiry])
							 | 
						||
| 
								 | 
							
								            if self.events.get(key)
							 | 
						||
| 
								 | 
							
								            else 0
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        returns the starting point and the number of entries in the moving
							 | 
						||
| 
								 | 
							
								        window
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        :param key: rate limit key
							 | 
						||
| 
								 | 
							
								        :param expiry: expiry of entry
							 | 
						||
| 
								 | 
							
								        :return: (start of window, number of acquired entries)
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        timestamp = time.time()
							 | 
						||
| 
								 | 
							
								        acquired = self.get_num_acquired(key, expiry)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for item in self.events.get(key, []):
							 | 
						||
| 
								 | 
							
								            if item.atime >= timestamp - expiry:
							 | 
						||
| 
								 | 
							
								                return int(item.atime), acquired
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return int(timestamp), acquired
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def check(self) -> bool:
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        check if storage is healthy
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def reset(self) -> Optional[int]:
							 | 
						||
| 
								 | 
							
								        self.storage.clear()
							 | 
						||
| 
								 | 
							
								        self.expirations.clear()
							 | 
						||
| 
								 | 
							
								        self.events.clear()
							 | 
						||
| 
								 | 
							
								        return None
							 |