147 lines
4.4 KiB
Python
147 lines
4.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from functools import lru_cache
|
|
from threading import Lock
|
|
from typing import Any, Protocol
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.logging import logger
|
|
|
|
try:
|
|
from redis import Redis
|
|
except Exception: # pragma: no cover
|
|
Redis = None # type: ignore[assignment]
|
|
|
|
|
|
class Cache(Protocol):
|
|
def get(self, key: str): ...
|
|
def set(self, key: str, value: Any) -> None: ...
|
|
def delete(self, key: str) -> None: ...
|
|
def clear(self) -> None: ...
|
|
def stats(self) -> dict[str, int | float | str]: ...
|
|
|
|
|
|
class TTLCache:
|
|
def __init__(self, ttl_seconds: int):
|
|
self.ttl_seconds = ttl_seconds
|
|
self._store: dict[str, tuple[float, Any]] = {}
|
|
self._lock = Lock()
|
|
self._hits = 0
|
|
self._misses = 0
|
|
|
|
def get(self, key: str):
|
|
now = time.time()
|
|
with self._lock:
|
|
item = self._store.get(key)
|
|
if item is None:
|
|
self._misses += 1
|
|
return None
|
|
expires_at, value = item
|
|
if expires_at < now:
|
|
self._store.pop(key, None)
|
|
self._misses += 1
|
|
return None
|
|
self._hits += 1
|
|
return value
|
|
|
|
def set(self, key: str, value: Any) -> None:
|
|
expires_at = time.time() + self.ttl_seconds
|
|
with self._lock:
|
|
self._store[key] = (expires_at, value)
|
|
|
|
def delete(self, key: str) -> None:
|
|
with self._lock:
|
|
self._store.pop(key, None)
|
|
|
|
def clear(self) -> None:
|
|
with self._lock:
|
|
self._store.clear()
|
|
|
|
def stats(self) -> dict[str, int | float | str]:
|
|
with self._lock:
|
|
requests = self._hits + self._misses
|
|
hit_rate = (self._hits / requests) if requests else 0.0
|
|
return {
|
|
"backend": "memory",
|
|
"size": len(self._store),
|
|
"hits": self._hits,
|
|
"misses": self._misses,
|
|
"hit_rate": round(hit_rate, 4),
|
|
}
|
|
|
|
|
|
class RedisCache:
|
|
def __init__(self, url: str, prefix: str, ttl_seconds: int):
|
|
if Redis is None:
|
|
raise RuntimeError("redis package is not installed")
|
|
self.client = Redis.from_url(url, decode_responses=True)
|
|
self.prefix = prefix
|
|
self.ttl_seconds = ttl_seconds
|
|
self._hits = 0
|
|
self._misses = 0
|
|
self._lock = Lock()
|
|
|
|
def get(self, key: str):
|
|
raw = self.client.get(self._key(key))
|
|
with self._lock:
|
|
if raw is None:
|
|
self._misses += 1
|
|
return None
|
|
self._hits += 1
|
|
return json.loads(raw)
|
|
|
|
def set(self, key: str, value: Any) -> None:
|
|
self.client.set(self._key(key), json.dumps(value, ensure_ascii=False), ex=self.ttl_seconds)
|
|
|
|
def delete(self, key: str) -> None:
|
|
self.client.delete(self._key(key))
|
|
|
|
def clear(self) -> None:
|
|
pattern = f"{self.prefix}:*"
|
|
cursor = 0
|
|
while True:
|
|
cursor, keys = self.client.scan(cursor=cursor, match=pattern, count=200)
|
|
if keys:
|
|
self.client.delete(*keys)
|
|
if cursor == 0:
|
|
break
|
|
|
|
def stats(self) -> dict[str, int | float | str]:
|
|
with self._lock:
|
|
requests = self._hits + self._misses
|
|
hit_rate = (self._hits / requests) if requests else 0.0
|
|
return {
|
|
"backend": "redis",
|
|
"size": int(self.client.dbsize()),
|
|
"hits": self._hits,
|
|
"misses": self._misses,
|
|
"hit_rate": round(hit_rate, 4),
|
|
}
|
|
|
|
def _key(self, key: str) -> str:
|
|
return f"{self.prefix}:{key}"
|
|
|
|
|
|
def _build_cache(namespace: str, ttl_seconds: int) -> Cache:
    """Create the cache backend for *namespace* according to application settings.

    If ``settings.cache_backend == "redis"``, returns a ``RedisCache`` keyed
    under ``"<redis_prefix>:<namespace>"``. Any other value selects the
    in-process ``TTLCache``.

    If the Redis backend cannot be constructed, or the server does not
    answer a PING, the error is logged and the in-memory cache is returned
    instead. The probe runs only when the cache is built; Redis failures
    after that point are not handled here.
    """
    settings = get_settings()
    if settings.cache_backend == "redis":
        try:
            cache = RedisCache(settings.redis_url, f"{settings.redis_prefix}:{namespace}", ttl_seconds=ttl_seconds)
            # Redis.from_url() connects lazily. Without an explicit round-trip, an
            # unreachable server would pass construction, fail only on first use,
            # and the memory fallback below would never trigger.
            cache.client.ping()
        except Exception:
            logger.exception("failed to init redis cache namespace=%s fallback to memory cache", namespace)
        else:
            return cache
    return TTLCache(ttl_seconds=ttl_seconds)
|
|
|
|
|
|
@lru_cache
def get_match_cache() -> Cache:
    """Return the cache for the ``"match"`` namespace.

    Memoised with ``lru_cache``: the backend is chosen on the first call and
    that same instance is returned for every later call in the process.
    """
    ttl = get_settings().match_cache_ttl_seconds
    return _build_cache("match", ttl)
|
|
|
|
|
|
@lru_cache
def get_query_cache() -> Cache:
    """Return the cache for the ``"query"`` namespace.

    Memoised with ``lru_cache``: the backend is chosen on the first call and
    that same instance is returned for every later call in the process.
    """
    ttl = get_settings().query_cache_ttl_seconds
    return _build_cache("query", ttl)
|