stormbrigade_sheriff/sbsheriff/Lib/site-packages/disnake/ext/commands/cooldowns.py

# SPDX-License-Identifier: MIT

from __future__ import annotations

import asyncio
import time
from collections import deque
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Optional

from disnake.abc import PrivateChannel
from disnake.enums import Enum

from .errors import MaxConcurrencyReached

if TYPE_CHECKING:
    from typing_extensions import Self

    from ...message import Message

__all__ = (
    "BucketType",
    "Cooldown",
    "CooldownMapping",
    "DynamicCooldownMapping",
    "MaxConcurrency",
)


class BucketType(Enum):
    default = 0
    user = 1
    guild = 2
    channel = 3
    member = 4
    category = 5
    role = 6

    def get_key(self, msg: Message) -> Any:
        if self is BucketType.user:
            return msg.author.id
        elif self is BucketType.guild:
            return (msg.guild or msg.author).id
        elif self is BucketType.channel:
            return msg.channel.id
        elif self is BucketType.member:
            return ((msg.guild and msg.guild.id), msg.author.id)
        elif self is BucketType.category:
            return (msg.channel.category or msg.channel).id  # type: ignore
        elif self is BucketType.role:
            # we return the channel id of a private-channel as there are only roles in guilds
            # and that yields the same result as for a guild with only the @everyone role
            # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
            # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do have one
            return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id  # type: ignore

    def __call__(self, msg: Message) -> Any:
        return self.get_key(msg)
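
# Usage sketch (assumes a received `msg: Message` from a guild): each BucketType
# maps a message to a hashable key, and cooldown state is tracked per key, e.g.:
#
#     BucketType.user(msg)    # -> msg.author.id                   (per user)
#     BucketType.guild(msg)   # -> msg.guild.id                    (per guild)
#     BucketType.member(msg)  # -> (msg.guild.id, msg.author.id)   (per member)
#
# In DMs the guild-dependent types fall back to user/channel ids as handled above.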


class Cooldown:
    """Represents a cooldown for a command.

    Attributes
    ----------
    rate: :class:`int`
        The total number of tokens available per :attr:`per` seconds.
    per: :class:`float`
        The length of the cooldown period in seconds.
    """

    __slots__ = ("rate", "per", "_window", "_tokens", "_last")

    def __init__(self, rate: float, per: float) -> None:
        self.rate: int = int(rate)
        self.per: float = float(per)
        self._window: float = 0.0
        self._tokens: int = self.rate
        self._last: float = 0.0

    def get_tokens(self, current: Optional[float] = None) -> int:
        """Returns the number of available tokens before rate limiting is applied.

        Parameters
        ----------
        current: Optional[:class:`float`]
            The time in seconds since Unix epoch to calculate tokens at.
            If not supplied, then :func:`time.time()` is used.

        Returns
        -------
        :class:`int`
            The number of tokens available before the cooldown is to be applied.
        """
        if not current:
            current = time.time()

        tokens = self._tokens

        if current > self._window + self.per:
            tokens = self.rate

        return tokens

    def get_retry_after(self, current: Optional[float] = None) -> float:
        """Returns the time in seconds until the cooldown will be reset.

        Parameters
        ----------
        current: Optional[:class:`float`]
            The current time in seconds since Unix epoch.
            If not supplied, then :func:`time.time()` is used.

        Returns
        -------
        :class:`float`
            The number of seconds to wait before this cooldown will be reset.
        """
        current = current or time.time()
        tokens = self.get_tokens(current)

        if tokens == 0:
            return self.per - (current - self._window)

        return 0.0

    def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]:
        """Updates the cooldown rate limit.

        Parameters
        ----------
        current: Optional[:class:`float`]
            The time in seconds since Unix epoch to update the rate limit at.
            If not supplied, then :func:`time.time()` is used.

        Returns
        -------
        Optional[:class:`float`]
            The retry-after time in seconds if rate limited.
        """
        current = current or time.time()
        self._last = current

        self._tokens = self.get_tokens(current)

        # first token used means that we start a new rate limit window
        if self._tokens == self.rate:
            self._window = current

        # check if we are rate limited
        if self._tokens == 0:
            return self.per - (current - self._window)

        # we're not so decrement our tokens
        self._tokens -= 1

    def reset(self) -> None:
        """Reset the cooldown to its initial state."""
        self._tokens = self.rate
        self._last = 0.0

    def copy(self) -> Cooldown:
        """Creates a copy of this cooldown.

        Returns
        -------
        :class:`Cooldown`
            A new instance of this cooldown.
        """
        return Cooldown(self.rate, self.per)

    def __repr__(self) -> str:
        return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"


class CooldownMapping:
    def __init__(
        self,
        original: Optional[Cooldown],
        type: Callable[[Message], Any],
    ) -> None:
        if not callable(type):
            raise TypeError("Cooldown type must be a BucketType or callable")

        self._cache: Dict[Any, Cooldown] = {}
        self._cooldown: Optional[Cooldown] = original
        self._type: Callable[[Message], Any] = type

    def copy(self) -> CooldownMapping:
        ret = CooldownMapping(self._cooldown, self._type)
        ret._cache = self._cache.copy()
        return ret

    @property
    def valid(self) -> bool:
        return self._cooldown is not None

    @property
    def type(self) -> Callable[[Message], Any]:
        return self._type

    @classmethod
    def from_cooldown(cls, rate: float, per: float, type) -> Self:
        return cls(Cooldown(rate, per), type)

    def _bucket_key(self, msg: Message) -> Any:
        return self._type(msg)

    def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
        # we want to delete all cache objects that haven't been used
        # in a cooldown window. e.g. if we have a command that has a
        # cooldown of 60s and it has not been used in 60s then that key should be deleted
        current = current or time.time()
        dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per]
        for k in dead_keys:
            del self._cache[k]

    def _is_default(self) -> bool:
        # This method can be overridden in subclasses
        return self._type is BucketType.default

    def create_bucket(self, message: Message) -> Cooldown:
        return self._cooldown.copy()  # type: ignore

    def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
        if self._is_default():
            return self._cooldown  # type: ignore

        self._verify_cache_integrity(current)
        key = self._bucket_key(message)
        if key not in self._cache:
            bucket = self.create_bucket(message)
            if bucket is not None:
                self._cache[key] = bucket
        else:
            bucket = self._cache[key]

        return bucket

    def update_rate_limit(
        self, message: Message, current: Optional[float] = None
    ) -> Optional[float]:
        bucket = self.get_bucket(message, current)
        return bucket.update_rate_limit(current)
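
# Usage sketch (assumes a received `msg: Message`): the mapping lazily creates
# one Cooldown bucket per key produced by its bucket type, e.g. one per user:
#
#     mapping = CooldownMapping.from_cooldown(1, 30.0, BucketType.user)
#     retry_after = mapping.update_rate_limit(msg)
#     if retry_after:
#         ...  # this author already used the command within the last 30 seconds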


class DynamicCooldownMapping(CooldownMapping):
    def __init__(
        self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
    ) -> None:
        super().__init__(None, type)
        self._factory: Callable[[Message], Cooldown] = factory

    def copy(self) -> DynamicCooldownMapping:
        ret = DynamicCooldownMapping(self._factory, self._type)
        ret._cache = self._cache.copy()
        return ret

    @property
    def valid(self) -> bool:
        return True

    def _is_default(self) -> bool:
        # In dynamic mappings even default bucket types may have custom behavior
        return False

    def create_bucket(self, message: Message) -> Cooldown:
        return self._factory(message)
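
# Usage sketch: a DynamicCooldownMapping builds a (possibly different) Cooldown
# per message via its factory; the guild-owner check below is purely illustrative:
#
#     def factory(msg: Message) -> Cooldown:
#         if msg.guild and msg.author.id == msg.guild.owner_id:
#             return Cooldown(5, 10.0)   # looser limit for the guild owner
#         return Cooldown(1, 10.0)
#
#     mapping = DynamicCooldownMapping(factory, BucketType.user)
#     retry_after = mapping.update_rate_limit(msg)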


class _Semaphore:
    """A custom version of a semaphore.

    If you're wondering why asyncio.Semaphore isn't being used,
    it's because it doesn't expose the internal value. This internal
    value is necessary because I need to support both `wait=True` and
    `wait=False`.

    An asyncio.Queue could have been used to do this as well -- but it is
    not as efficient since internally that uses two queues and is a bit
    overkill for what is basically a counter.
    """

    __slots__ = ("value", "loop", "_waiters")

    def __init__(self, number: int) -> None:
        self.value: int = number
        self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._waiters: Deque[asyncio.Future] = deque()

    def __repr__(self) -> str:
        return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"

    def locked(self) -> bool:
        return self.value == 0

    def is_active(self) -> bool:
        return len(self._waiters) > 0

    def wake_up(self) -> None:
        while self._waiters:
            future = self._waiters.popleft()
            if not future.done():
                future.set_result(None)
                return

    async def acquire(self, *, wait: bool = False) -> bool:
        if not wait and self.value <= 0:
            # signal that we're not acquiring
            return False

        while self.value <= 0:
            future = self.loop.create_future()
            self._waiters.append(future)
            try:
                await future
            except Exception:
                future.cancel()
                if self.value > 0 and not future.cancelled():
                    self.wake_up()
                raise

        self.value -= 1
        return True

    def release(self) -> None:
        self.value += 1
        self.wake_up()
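
# Usage sketch (internal helper; must be created inside a running event loop
# since __init__ calls asyncio.get_running_loop()): unlike asyncio.Semaphore,
# acquire() can bail out immediately instead of waiting when no slot is free:
#
#     sem = _Semaphore(1)
#     await sem.acquire(wait=True)    # True: takes the only slot
#     await sem.acquire(wait=False)   # False: full, and we chose not to wait
#     sem.release()                   # frees the slot and wakes one waiter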


class MaxConcurrency:
    __slots__ = ("number", "per", "wait", "_mapping")

    def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
        self._mapping: Dict[Any, _Semaphore] = {}
        self.per: BucketType = per
        self.number: int = number
        self.wait: bool = wait

        if number <= 0:
            raise ValueError("max_concurrency 'number' cannot be less than 1")

        if not isinstance(per, BucketType):
            raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}")

    def copy(self) -> Self:
        return self.__class__(self.number, per=self.per, wait=self.wait)

    def __repr__(self) -> str:
        return f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"

    def get_key(self, message: Message) -> Any:
        return self.per.get_key(message)

    async def acquire(self, message: Message) -> None:
        key = self.get_key(message)

        try:
            sem = self._mapping[key]
        except KeyError:
            self._mapping[key] = sem = _Semaphore(self.number)

        acquired = await sem.acquire(wait=self.wait)
        if not acquired:
            raise MaxConcurrencyReached(self.number, self.per)

    async def release(self, message: Message) -> None:
        # Technically there's no reason for this function to be async
        # But it might be more useful in the future
        key = self.get_key(message)

        try:
            sem = self._mapping[key]
        except KeyError:
            # ...? peculiar
            return
        else:
            sem.release()

        if sem.value >= self.number and not sem.is_active():
            del self._mapping[key]
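
# Usage sketch (assumes a received `msg: Message`): at most `number` invocations
# may run concurrently per bucket key; with wait=False an excess call raises
# MaxConcurrencyReached instead of queueing:
#
#     concurrency = MaxConcurrency(1, per=BucketType.user, wait=False)
#
#     async def run_command(msg: Message) -> None:
#         await concurrency.acquire(msg)   # may raise MaxConcurrencyReached
#         try:
#             ...  # invoke the command body
#         finally:
#             await concurrency.release(msg)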