Files
Airtep/gig-poc/apps/api/app/services/cache_service.py
2026-04-01 14:19:25 +08:00

147 lines
4.4 KiB
Python

from __future__ import annotations
import json
import time
from functools import lru_cache
from threading import Lock
from typing import Any, Protocol
from app.core.config import get_settings
from app.core.logging import logger
try:
from redis import Redis
except Exception: # pragma: no cover
Redis = None # type: ignore[assignment]
class Cache(Protocol):
    """Structural interface implemented by both TTLCache and RedisCache.

    ``get`` returns the cached value or None on a miss; ``set`` stores a
    value under the backend's configured TTL; ``stats`` reports counters
    for observability.
    """

    def get(self, key: str): ...
    def set(self, key: str, value: Any) -> None: ...
    def delete(self, key: str) -> None: ...
    def clear(self) -> None: ...
    def stats(self) -> dict[str, int | float | str]: ...
class TTLCache:
    """Thread-safe in-process cache whose entries expire after a fixed TTL.

    Expired entries are evicted lazily, on the next ``get`` of that key;
    nothing purges idle keys, so the store can grow until stale keys are
    read again (or ``clear`` is called).
    """

    def __init__(self, ttl_seconds: int):
        self.ttl_seconds = ttl_seconds
        # key -> (absolute expiry on the monotonic clock, cached value)
        self._store: dict[str, tuple[float, Any]] = {}
        self._lock = Lock()
        self._hits = 0
        self._misses = 0

    def get(self, key: str):
        """Return the cached value for *key*, or None if missing or expired."""
        # time.monotonic() instead of time.time(): expiry must not be
        # affected by wall-clock adjustments (NTP steps, manual changes),
        # which would expire entries early or keep them alive past the TTL.
        now = time.monotonic()
        with self._lock:
            item = self._store.get(key)
            if item is None:
                self._misses += 1
                return None
            expires_at, value = item
            if expires_at < now:
                # Lazy eviction: drop the stale entry on access.
                self._store.pop(key, None)
                self._misses += 1
                return None
            self._hits += 1
            return value

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key*, (re)starting its TTL."""
        expires_at = time.monotonic() + self.ttl_seconds
        with self._lock:
            self._store[key] = (expires_at, value)

    def delete(self, key: str) -> None:
        """Remove *key* if present; a no-op for unknown keys."""
        with self._lock:
            self._store.pop(key, None)

    def clear(self) -> None:
        """Drop every entry. Hit/miss counters are intentionally preserved."""
        with self._lock:
            self._store.clear()

    def stats(self) -> dict[str, int | float | str]:
        """Return backend name, current size, and hit/miss counters."""
        with self._lock:
            requests = self._hits + self._misses
            hit_rate = (self._hits / requests) if requests else 0.0
            return {
                "backend": "memory",
                "size": len(self._store),
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate": round(hit_rate, 4),
            }
class RedisCache:
    """Cache backed by a Redis server, one logical namespace per key prefix.

    Values are stored as JSON strings with a per-key TTL (``ex``). Hit/miss
    counters are tracked locally in this process only.
    """

    def __init__(self, url: str, prefix: str, ttl_seconds: int):
        if Redis is None:
            raise RuntimeError("redis package is not installed")
        self.client = Redis.from_url(url, decode_responses=True)
        self.prefix = prefix
        self.ttl_seconds = ttl_seconds
        self._hits = 0
        self._misses = 0
        # Guards only the local hit/miss counters; Redis commands are
        # already atomic on the server side.
        self._lock = Lock()

    def get(self, key: str):
        """Fetch and JSON-decode the value stored under *key*; None on a miss."""
        payload = self.client.get(self._key(key))
        with self._lock:
            if payload is None:
                self._misses += 1
                return None
            self._hits += 1
        return json.loads(payload)

    def set(self, key: str, value: Any) -> None:
        """JSON-encode *value* and store it with this cache's TTL."""
        encoded = json.dumps(value, ensure_ascii=False)
        self.client.set(self._key(key), encoded, ex=self.ttl_seconds)

    def delete(self, key: str) -> None:
        """Remove *key* from Redis; a no-op for unknown keys."""
        self.client.delete(self._key(key))

    def clear(self) -> None:
        """Delete every key under this cache's prefix using incremental SCAN."""
        # SCAN (not KEYS) to avoid blocking the server on large keyspaces;
        # a returned cursor of 0 marks the end of the iteration.
        cursor = 0
        while True:
            cursor, batch = self.client.scan(
                cursor=cursor, match=f"{self.prefix}:*", count=200
            )
            if batch:
                self.client.delete(*batch)
            if not cursor:
                break

    def stats(self) -> dict[str, int | float | str]:
        """Return backend name, server-wide key count, and local counters.

        NOTE(review): ``dbsize`` counts the whole Redis DB, not just this
        prefix — "size" is an overcount when the DB is shared.
        """
        with self._lock:
            total = self._hits + self._misses
            ratio = self._hits / total if total else 0.0
            return {
                "backend": "redis",
                "size": int(self.client.dbsize()),
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate": round(ratio, 4),
            }

    def _key(self, key: str) -> str:
        """Prepend the namespace prefix to a logical cache key."""
        return f"{self.prefix}:{key}"
def _build_cache(namespace: str, ttl_seconds: int) -> Cache:
    """Construct the cache for *namespace*, preferring Redis when configured.

    If the configured backend is Redis but initialisation fails for any
    reason, the error is logged and an in-process TTLCache is returned
    instead — cache setup must never take the service down.
    """
    settings = get_settings()
    if settings.cache_backend == "redis":
        namespaced_prefix = f"{settings.redis_prefix}:{namespace}"
        try:
            return RedisCache(settings.redis_url, namespaced_prefix, ttl_seconds=ttl_seconds)
        except Exception:
            logger.exception("failed to init redis cache namespace=%s fallback to memory cache", namespace)
    return TTLCache(ttl_seconds=ttl_seconds)
@lru_cache
def get_match_cache() -> Cache:
    """Return the process-wide match cache (created once via lru_cache)."""
    ttl = get_settings().match_cache_ttl_seconds
    return _build_cache("match", ttl)
@lru_cache
def get_query_cache() -> Cache:
    """Return the process-wide query cache (created once via lru_cache)."""
    ttl = get_settings().query_cache_ttl_seconds
    return _build_cache("query", ttl)