Files
Airtep/gig-poc/apps/api/app/services/ai_guard.py
2026-04-01 14:19:25 +08:00

88 lines
3.1 KiB
Python

from __future__ import annotations
import time
from dataclasses import dataclass
from threading import Lock
from app.core.config import Settings
@dataclass
class EndpointState:
    """Mutable bookkeeping kept for a single endpoint key.

    One instance holds both the per-minute request quota counters and the
    circuit-breaker counters for that endpoint.
    """

    # Minute bucket (epoch seconds floor-divided by 60) that minute_count refers to.
    current_minute: int = 0
    # Requests admitted so far inside current_minute.
    minute_count: int = 0
    # Failures recorded since the last success or the last circuit trip.
    consecutive_failures: int = 0
    # Epoch timestamp until which requests are refused; 0.0 means the circuit is closed.
    circuit_open_until: float = 0.0
class AIGuard:
    """Per-endpoint rate limiter and circuit breaker with shared usage counters.

    The lock, the endpoint table and the counters are class attributes, so
    every ``AIGuard`` instance in the process reads and updates the same
    state; only ``settings`` is stored per instance. All access to the shared
    state happens while holding ``_lock``.
    """

    _lock = Lock()
    _endpoint_states: dict[str, EndpointState] = {}
    _metrics = {
        "requests_total": 0,
        "success_total": 0,
        "fail_total": 0,
        "fallback_total": 0,
        "rate_limited_total": 0,
        "circuit_open_total": 0,
        "endpoint_failover_total": 0,
    }

    def __init__(self, settings: Settings):
        """Keep the settings that supply the rate limit and breaker thresholds."""
        self.settings = settings

    def _state_for(self, endpoint: str) -> EndpointState:
        """Return the tracked state for *endpoint*, creating it on first use.

        Callers must already hold ``_lock``.
        """
        return self._endpoint_states.setdefault(endpoint, EndpointState())

    def _bump(self, counter: str) -> None:
        """Add one to the named counter. Callers must already hold ``_lock``."""
        self._metrics[counter] += 1

    def allow_request(self, endpoint: str) -> tuple[bool, str]:
        """Decide whether a call to *endpoint* may proceed right now.

        Returns:
            ``(True, "ok")`` when the call is admitted, otherwise ``False``
            paired with ``"circuit_open"`` or ``"rate_limited"``.
        """
        timestamp = time.time()
        minute_bucket = int(timestamp // 60)
        with self._lock:
            entry = self._state_for(endpoint)

            # An open circuit is checked first and does not consume any quota.
            if entry.circuit_open_until > timestamp:
                self._bump("circuit_open_total")
                return False, "circuit_open"

            # Entering a new minute bucket restarts the quota count.
            if entry.current_minute != minute_bucket:
                entry.current_minute, entry.minute_count = minute_bucket, 0

            quota = self.settings.ai_rate_limit_per_minute
            if entry.minute_count >= quota:
                self._bump("rate_limited_total")
                return False, "rate_limited"

            # Only admitted requests are counted in requests_total.
            entry.minute_count += 1
            self._bump("requests_total")
            return True, "ok"

    def record_success(self, endpoint: str) -> None:
        """Register a successful call: clear the failure streak and close the circuit."""
        with self._lock:
            entry = self._state_for(endpoint)
            entry.consecutive_failures = 0
            entry.circuit_open_until = 0.0
            self._bump("success_total")

    def record_failure(self, endpoint: str) -> None:
        """Register a failed call and open the circuit once the threshold is reached."""
        with self._lock:
            entry = self._state_for(endpoint)
            entry.consecutive_failures += 1
            self._bump("fail_total")

            threshold = self.settings.ai_circuit_breaker_fail_threshold
            if entry.consecutive_failures < threshold:
                return

            # Threshold reached: refuse requests for the cooldown period and
            # restart the streak so the next trip needs a full run of failures.
            entry.circuit_open_until = time.time() + self.settings.ai_circuit_breaker_cooldown_seconds
            entry.consecutive_failures = 0

    def record_fallback(self) -> None:
        """Count one use of a fallback path."""
        with self._lock:
            self._bump("fallback_total")

    def record_failover(self) -> None:
        """Count one failover from one endpoint to another."""
        with self._lock:
            self._bump("endpoint_failover_total")

    def snapshot(self) -> dict:
        """Return a copy of the counters plus derived rates.

        Each rate is the matching counter divided by ``requests_total`` and
        rounded to four decimals; all three are ``0.0`` while no request has
        been admitted.
        """
        with self._lock:
            report = dict(self._metrics)
            admitted = report["requests_total"]

            def share(count):
                # Avoid dividing by zero before the first admitted request.
                return round(count / admitted, 4) if admitted else 0.0

            report.update(
                fallback_hit_rate=share(report["fallback_total"]),
                success_rate=share(report["success_total"]),
                failure_rate=share(report["fail_total"]),
            )
            return report