# SPDX-License-Identifier: MIT from __future__ import annotations import asyncio import time from collections import deque from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Optional from disnake.abc import PrivateChannel from disnake.enums import Enum from .errors import MaxConcurrencyReached if TYPE_CHECKING: from typing_extensions import Self from ...message import Message __all__ = ( "BucketType", "Cooldown", "CooldownMapping", "DynamicCooldownMapping", "MaxConcurrency", ) class BucketType(Enum): default = 0 user = 1 guild = 2 channel = 3 member = 4 category = 5 role = 6 def get_key(self, msg: Message) -> Any: if self is BucketType.user: return msg.author.id elif self is BucketType.guild: return (msg.guild or msg.author).id elif self is BucketType.channel: return msg.channel.id elif self is BucketType.member: return ((msg.guild and msg.guild.id), msg.author.id) elif self is BucketType.category: return (msg.channel.category or msg.channel).id # type: ignore elif self is BucketType.role: # we return the channel id of a private-channel as there are only roles in guilds # and that yields the same result as for a guild with only the @everyone role # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore def __call__(self, msg: Message) -> Any: return self.get_key(msg) class Cooldown: """Represents a cooldown for a command. Attributes ---------- rate: :class:`int` The total number of tokens available per :attr:`per` seconds. per: :class:`float` The length of the cooldown period in seconds. """ __slots__ = ("rate", "per", "_window", "_tokens", "_last") def __init__(self, rate: float, per: float) -> None: self.rate: int = int(rate) self.per: float = float(per) self._window: float = 0.0 self._tokens: int = self.rate self._last: float = 0.0 def get_tokens(self, current: Optional[float] = None) -> int: """Returns the number of available tokens before rate limiting is applied. Parameters ---------- current: Optional[:class:`float`] The time in seconds since Unix epoch to calculate tokens at. If not supplied then :func:`time.time()` is used. Returns ------- :class:`int` The number of tokens available before the cooldown is to be applied. """ if not current: current = time.time() tokens = self._tokens if current > self._window + self.per: tokens = self.rate return tokens def get_retry_after(self, current: Optional[float] = None) -> float: """Returns the time in seconds until the cooldown will be reset. Parameters ---------- current: Optional[:class:`float`] The current time in seconds since Unix epoch. If not supplied, then :func:`time.time()` is used. Returns ------- :class:`float` The number of seconds to wait before this cooldown will be reset. """ current = current or time.time() tokens = self.get_tokens(current) if tokens == 0: return self.per - (current - self._window) return 0.0 def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: """Updates the cooldown rate limit. Parameters ---------- current: Optional[:class:`float`] The time in seconds since Unix epoch to update the rate limit at. If not supplied, then :func:`time.time()` is used. Returns ------- Optional[:class:`float`] The retry-after time in seconds if rate limited. """ current = current or time.time() self._last = current self._tokens = self.get_tokens(current) # first token used means that we start a new rate limit window if self._tokens == self.rate: self._window = current # check if we are rate limited if self._tokens == 0: return self.per - (current - self._window) # we're not so decrement our tokens self._tokens -= 1 def reset(self) -> None: """Reset the cooldown to its initial state.""" self._tokens = self.rate self._last = 0.0 def copy(self) -> Cooldown: """Creates a copy of this cooldown. Returns ------- :class:`Cooldown` A new instance of this cooldown. """ return Cooldown(self.rate, self.per) def __repr__(self) -> str: return f"" class CooldownMapping: def __init__( self, original: Optional[Cooldown], type: Callable[[Message], Any], ) -> None: if not callable(type): raise TypeError("Cooldown type must be a BucketType or callable") self._cache: Dict[Any, Cooldown] = {} self._cooldown: Optional[Cooldown] = original self._type: Callable[[Message], Any] = type def copy(self) -> CooldownMapping: ret = CooldownMapping(self._cooldown, self._type) ret._cache = self._cache.copy() return ret @property def valid(self) -> bool: return self._cooldown is not None @property def type(self) -> Callable[[Message], Any]: return self._type @classmethod def from_cooldown(cls, rate: float, per: float, type) -> Self: return cls(Cooldown(rate, per), type) def _bucket_key(self, msg: Message) -> Any: return self._type(msg) def _verify_cache_integrity(self, current: Optional[float] = None) -> None: # we want to delete all cache objects that haven't been used # in a cooldown window. e.g. if we have a command that has a # cooldown of 60s and it has not been used in 60s then that key should be deleted current = current or time.time() dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per] for k in dead_keys: del self._cache[k] def _is_default(self) -> bool: # This method can be overridden in subclasses return self._type is BucketType.default def create_bucket(self, message: Message) -> Cooldown: return self._cooldown.copy() # type: ignore def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown: if self._is_default(): return self._cooldown # type: ignore self._verify_cache_integrity(current) key = self._bucket_key(message) if key not in self._cache: bucket = self.create_bucket(message) if bucket is not None: self._cache[key] = bucket else: bucket = self._cache[key] return bucket def update_rate_limit( self, message: Message, current: Optional[float] = None ) -> Optional[float]: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) class DynamicCooldownMapping(CooldownMapping): def __init__( self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any] ) -> None: super().__init__(None, type) self._factory: Callable[[Message], Cooldown] = factory def copy(self) -> DynamicCooldownMapping: ret = DynamicCooldownMapping(self._factory, self._type) ret._cache = self._cache.copy() return ret @property def valid(self) -> bool: return True def _is_default(self) -> bool: # In dynamic mappings even default bucket types may have custom behavior return False def create_bucket(self, message: Message) -> Cooldown: return self._factory(message) class _Semaphore: """A custom version of a semaphore. If you're wondering why asyncio.Semaphore isn't being used, it's because it doesn't expose the internal value. This internal value is necessary because I need to support both `wait=True` and `wait=False`. An asyncio.Queue could have been used to do this as well -- but it is not as inefficient since internally that uses two queues and is a bit overkill for what is basically a counter. """ __slots__ = ("value", "loop", "_waiters") def __init__(self, number: int) -> None: self.value: int = number self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() self._waiters: Deque[asyncio.Future] = deque() def __repr__(self) -> str: return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>" def locked(self) -> bool: return self.value == 0 def is_active(self) -> bool: return len(self._waiters) > 0 def wake_up(self) -> None: while self._waiters: future = self._waiters.popleft() if not future.done(): future.set_result(None) return async def acquire(self, *, wait: bool = False) -> bool: if not wait and self.value <= 0: # signal that we're not acquiring return False while self.value <= 0: future = self.loop.create_future() self._waiters.append(future) try: await future except Exception: future.cancel() if self.value > 0 and not future.cancelled(): self.wake_up() raise self.value -= 1 return True def release(self) -> None: self.value += 1 self.wake_up() class MaxConcurrency: __slots__ = ("number", "per", "wait", "_mapping") def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: self._mapping: Dict[Any, _Semaphore] = {} self.per: BucketType = per self.number: int = number self.wait: bool = wait if number <= 0: raise ValueError("max_concurrency 'number' cannot be less than 1") if not isinstance(per, BucketType): raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") def copy(self) -> Self: return self.__class__(self.number, per=self.per, wait=self.wait) def __repr__(self) -> str: return f"" def get_key(self, message: Message) -> Any: return self.per.get_key(message) async def acquire(self, message: Message) -> None: key = self.get_key(message) try: sem = self._mapping[key] except KeyError: self._mapping[key] = sem = _Semaphore(self.number) acquired = await sem.acquire(wait=self.wait) if not acquired: raise MaxConcurrencyReached(self.number, self.per) async def release(self, message: Message) -> None: # Technically there's no reason for this function to be async # But it might be more useful in the future key = self.get_key(message) try: sem = self._mapping[key] except KeyError: # ...? peculiar return else: sem.release() if sem.value >= self.number and not sem.is_active(): del self._mapping[key]