from __future__ import annotations import json import time from functools import lru_cache from threading import Lock from typing import Any, Protocol from app.core.config import get_settings from app.core.logging import logger try: from redis import Redis except Exception: # pragma: no cover Redis = None # type: ignore[assignment] class Cache(Protocol): def get(self, key: str): ... def set(self, key: str, value: Any) -> None: ... def delete(self, key: str) -> None: ... def clear(self) -> None: ... def stats(self) -> dict[str, int | float | str]: ... class TTLCache: def __init__(self, ttl_seconds: int): self.ttl_seconds = ttl_seconds self._store: dict[str, tuple[float, Any]] = {} self._lock = Lock() self._hits = 0 self._misses = 0 def get(self, key: str): now = time.time() with self._lock: item = self._store.get(key) if item is None: self._misses += 1 return None expires_at, value = item if expires_at < now: self._store.pop(key, None) self._misses += 1 return None self._hits += 1 return value def set(self, key: str, value: Any) -> None: expires_at = time.time() + self.ttl_seconds with self._lock: self._store[key] = (expires_at, value) def delete(self, key: str) -> None: with self._lock: self._store.pop(key, None) def clear(self) -> None: with self._lock: self._store.clear() def stats(self) -> dict[str, int | float | str]: with self._lock: requests = self._hits + self._misses hit_rate = (self._hits / requests) if requests else 0.0 return { "backend": "memory", "size": len(self._store), "hits": self._hits, "misses": self._misses, "hit_rate": round(hit_rate, 4), } class RedisCache: def __init__(self, url: str, prefix: str, ttl_seconds: int): if Redis is None: raise RuntimeError("redis package is not installed") self.client = Redis.from_url(url, decode_responses=True) self.prefix = prefix self.ttl_seconds = ttl_seconds self._hits = 0 self._misses = 0 self._lock = Lock() def get(self, key: str): raw = self.client.get(self._key(key)) with self._lock: if raw is None: self._misses += 1 return None self._hits += 1 return json.loads(raw) def set(self, key: str, value: Any) -> None: self.client.set(self._key(key), json.dumps(value, ensure_ascii=False), ex=self.ttl_seconds) def delete(self, key: str) -> None: self.client.delete(self._key(key)) def clear(self) -> None: pattern = f"{self.prefix}:*" cursor = 0 while True: cursor, keys = self.client.scan(cursor=cursor, match=pattern, count=200) if keys: self.client.delete(*keys) if cursor == 0: break def stats(self) -> dict[str, int | float | str]: with self._lock: requests = self._hits + self._misses hit_rate = (self._hits / requests) if requests else 0.0 return { "backend": "redis", "size": int(self.client.dbsize()), "hits": self._hits, "misses": self._misses, "hit_rate": round(hit_rate, 4), } def _key(self, key: str) -> str: return f"{self.prefix}:{key}" def _build_cache(namespace: str, ttl_seconds: int) -> Cache: settings = get_settings() if settings.cache_backend == "redis": try: return RedisCache(settings.redis_url, f"{settings.redis_prefix}:{namespace}", ttl_seconds=ttl_seconds) except Exception: logger.exception("failed to init redis cache namespace=%s fallback to memory cache", namespace) return TTLCache(ttl_seconds=ttl_seconds) @lru_cache def get_match_cache() -> Cache: settings = get_settings() return _build_cache("match", settings.match_cache_ttl_seconds) @lru_cache def get_query_cache() -> Cache: settings = get_settings() return _build_cache("query", settings.query_cache_ttl_seconds)