"""Base interface for KV clients."""
from __future__ import annotations
import abc
import asyncio
import json
import math
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import Any
[docs]
@dataclass(frozen=True, slots=True)
class RateLimitResult:
"""Result of a rate limit check."""
allowed: bool
count: int
limit: int
remaining: int
retry_after: float | None
class KVClient(abc.ABC):
    """Byte-level async KV client.

    Concrete backends implement the abstract primitives (``get``, ``set``,
    ``set_nx``, ``incr``, ...). The non-abstract helpers defined here
    (:meth:`guarded_get`, :meth:`idempotent_execute`,
    :meth:`already_processed`, :meth:`rate_limit`) are built purely on top
    of those primitives.
    """

    # Capability flags; declared here, assigned by concrete subclasses.
    supports_ttl: bool
    supports_scan: bool
    supports_batch: bool

    @abc.abstractmethod
    async def connect(self) -> None:
        """Connect to the store."""

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Disconnect from the store."""

    @abc.abstractmethod
    async def get(self, key: bytes) -> bytes | None:
        """Fetch a value by key."""

    @abc.abstractmethod
    async def set(self, key: bytes, value: bytes, *, ttl: float | None = None) -> None:
        """Set a value by key."""

    @abc.abstractmethod
    async def delete(self, key: bytes) -> bool:
        """Delete a key."""

    @abc.abstractmethod
    async def exists(self, key: bytes) -> bool:
        """Check if a key exists."""

    @abc.abstractmethod
    async def mget(self, keys: Sequence[bytes]) -> Sequence[bytes | None]:
        """Fetch multiple keys."""

    @abc.abstractmethod
    async def mset(
        self, items: Sequence[tuple[bytes, bytes]], *, ttl: float | None = None
    ) -> None:
        """Set multiple key/value pairs."""

    @abc.abstractmethod
    async def delete_many(self, keys: Sequence[bytes]) -> int:
        """Delete multiple keys."""

    @abc.abstractmethod
    async def ttl(self, key: bytes) -> float | None:
        """Return remaining TTL in seconds, or None."""

    @abc.abstractmethod
    async def expire(self, key: bytes, ttl: float) -> bool:
        """Set TTL for a key. Returns False if missing."""

    @abc.abstractmethod
    async def set_nx(
        self, key: bytes, value: bytes, *, ttl: float | None = None
    ) -> bool:
        """Set a value only if the key does not exist. Returns True if set."""

    @abc.abstractmethod
    async def incr(self, key: bytes) -> int:
        """Increment a key by 1 and return the new value.

        Creates the key with value 1 if it does not exist.
        """

    @abc.abstractmethod
    async def scan(
        self, *, prefix: bytes | None = None, limit: int | None = None
    ) -> AsyncIterator[bytes]:
        """Iterate keys with optional prefix and limit."""
        # The ``yield`` makes this placeholder an async generator, matching
        # the shape concrete implementations are expected to have.
        yield b""  # pragma: no cover
        raise NotImplementedError  # pragma: no cover

    async def guarded_get(
        self,
        cache_key: bytes,
        *,
        compute: Callable[[], Awaitable[bytes]],
        ttl: float,
        lock_ttl: float = 2.0,
        retry_delay: float = 0.05,
    ) -> bytes:
        """Fetch from cache with stampede protection.

        On a cache miss, only one caller acquires a lock (stored under
        ``cache_key + b":lock"``) and computes the value. Other callers
        poll the cache until it is populated. If the polling budget is
        exhausted, the caller falls through and computes directly; that
        fall-through result is returned without being written to the cache.

        The wait budget is ``ceil(lock_ttl / retry_delay)`` polls, so
        waiters keep retrying for at least the full duration the lock
        could be held.

        Args:
            cache_key: The cache key to read/write.
            compute: Async callable that produces the value on cache miss.
            ttl: TTL in seconds for the cached value.
            lock_ttl: TTL in seconds for the lock key.
            retry_delay: Seconds to sleep between retry attempts. A
                non-positive value disables waiting entirely.

        Returns:
            The cached or freshly computed value.
        """
        cached = await self.get(cache_key)
        if cached is not None:
            return cached

        lock_key = cache_key + b":lock"
        acquired = await self.set_nx(lock_key, b"1", ttl=lock_ttl)
        if acquired:
            try:
                # Re-check under the lock: another caller may have filled the
                # cache between the first read and the lock acquisition.
                cached = await self.get(cache_key)
                if cached is not None:
                    return cached
                value = await compute()
                await self.set(cache_key, value, ttl=ttl)
                return value
            finally:
                # NOTE(review): the delete is unconditional. If ``compute``
                # outlives ``lock_ttl`` the lock key may by then belong to
                # another caller; the primitives here offer no
                # compare-and-delete to guard against that.
                await self.delete(lock_key)

        # Round *up* so the polling window is never shorter than lock_ttl,
        # and avoid dividing by zero when the caller passes retry_delay <= 0.
        max_retries = math.ceil(lock_ttl / retry_delay) if retry_delay > 0 else 0
        for _ in range(max_retries):
            await asyncio.sleep(retry_delay)
            cached = await self.get(cache_key)
            if cached is not None:
                return cached
        return await compute()

    async def idempotent_execute(
        self,
        *,
        key: str,
        compute: Callable[[], Awaitable[Any]],
        status_code: int = 200,
        ttl: float = 86400,
        key_prefix: str = "derp:idempotency",
    ) -> tuple[Any, int, bool]:
        """Execute idempotently: run ``compute`` once per key.

        On the first call for a given key, ``compute`` is invoked and
        the result is cached as a JSON document. Subsequent calls return
        the cached result without re-invoking ``compute``. Uses
        :meth:`guarded_get` for stampede protection.

        The returned body is always the JSON round-trip of the result
        (also on the first call), and values the JSON encoder cannot
        handle natively are converted with ``str``.

        Args:
            key: Idempotency key (typically from a client header).
            compute: Async callable producing a JSON-serializable result.
            status_code: HTTP status code to cache alongside the body.
            ttl: Cache TTL in seconds (default 24h).
            key_prefix: KV key prefix.

        Returns:
            ``(body, status_code, is_replay)`` — *body* is the
            deserialized result, *status_code* is the cached status,
            and *is_replay* is ``True`` when the cached value was used.
        """
        cache_key = f"{key_prefix}:{key}".encode()
        was_computed = False

        async def _compute() -> bytes:
            # Flag that this call actually ran ``compute`` so the caller can
            # tell a fresh execution from a replay.
            nonlocal was_computed
            was_computed = True
            result = await compute()
            payload = json.dumps(
                {"status_code": status_code, "body": result},
                default=str,
            )
            return payload.encode()

        raw = await self.guarded_get(cache_key, compute=_compute, ttl=ttl)
        parsed = json.loads(raw)
        return parsed["body"], parsed["status_code"], not was_computed

    async def already_processed(
        self,
        *,
        event_id: str,
        ttl: float = 86400,
        key_prefix: str = "derp:webhook",
    ) -> bool:
        """Check if an event has already been processed.

        Uses :meth:`set_nx` to atomically mark the event. Returns
        ``True`` if the event was already seen, ``False`` on first call.

        Args:
            event_id: Unique event identifier (e.g. Stripe event ID).
            ttl: How long to remember the event (default 24h).
            key_prefix: KV key prefix.

        Returns:
            ``True`` when the marker key already existed, else ``False``.
        """
        cache_key = f"{key_prefix}:{event_id}".encode()
        acquired = await self.set_nx(cache_key, b"1", ttl=ttl)
        return not acquired

    async def rate_limit(
        self,
        key: str,
        *,
        limit: int,
        window: float,
        key_prefix: str = "derp:ratelimit",
    ) -> RateLimitResult:
        """Fixed-window rate limit check.

        Increments a counter for the given key. The counter resets after
        ``window`` seconds. Returns a :class:`RateLimitResult` indicating
        whether the request is allowed.

        The expiry is armed when the counter is created. If a denied
        request finds the counter without a TTL (for example because the
        initial ``expire`` never ran), the window is re-armed so the key
        cannot stay throttled indefinitely.

        Args:
            key: Identifier to rate limit (e.g. ``f"checkout:{user_id}"``).
            limit: Maximum number of requests per window.
            window: Window duration in seconds.
            key_prefix: KV key prefix.

        Returns:
            The outcome of the check; ``retry_after`` is ``None`` when the
            request is allowed.
        """
        cache_key = f"{key_prefix}:{key}".encode()
        count = await self.incr(cache_key)
        if count == 1:
            # First hit in this window: start the expiry clock.
            await self.expire(cache_key, window)

        allowed = count <= limit
        remaining = max(0, limit - count)
        retry_after: float | None = None
        if not allowed:
            remaining_ttl = await self.ttl(cache_key)
            if remaining_ttl is None:
                # No expiry recorded for the counter: re-arm the window so
                # the limit eventually resets, and report a full window.
                await self.expire(cache_key, window)
                retry_after = window
            else:
                retry_after = math.ceil(remaining_ttl)

        return RateLimitResult(
            allowed=allowed,
            count=count,
            limit=limit,
            remaining=remaining,
            retry_after=retry_after,
        )