Files
Airtep/gig-poc/apps/api/app/services/weight_service.py
2026-04-01 14:19:25 +08:00

78 lines
3.1 KiB
Python

from __future__ import annotations

import json
import math
import uuid
from pathlib import Path

from app.core.config import Settings
from app.core.logging import logger
from app.domain.schemas import MatchBreakdown
class MatchWeightService:
    """Manage the linear ranking weights used to score candidate matches.

    Weights start from the ``score_*_weight`` values in ``Settings`` and can be
    overridden, key by key, by the learned-weights JSON file at
    ``settings.match_weights_path``, which :meth:`update_from_feedback`
    rewrites after each feedback event. Every weight mapping this service
    returns is non-negative and sums to 1 (up to floating-point rounding).
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        # Location of the learned-weights JSON file (read and written by this class).
        self.path: Path = settings.match_weights_path

    def default_weights(self) -> dict[str, float]:
        """Return the raw, un-normalized weights configured in ``settings``."""
        return {
            "skill": self.settings.score_skill_weight,
            "region": self.settings.score_region_weight,
            "time": self.settings.score_time_weight,
            "experience": self.settings.score_experience_weight,
            "reliability": self.settings.score_reliability_weight,
        }

    def get_weights(self) -> dict[str, float]:
        """Return the current normalized weights.

        Each key of :meth:`default_weights` is overridden by the value stored in
        the weights file when that value is a finite real number (JSON booleans,
        ``NaN`` and ``Infinity`` are ignored and keep the default). A missing
        file silently yields the normalized defaults; an unreadable, malformed,
        or non-object file is logged and also yields the normalized defaults.
        """
        weights = self.default_weights()
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            # Nothing has been learned yet.
            return self._normalize(weights)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.exception("failed to read learned ranking weights, fallback to defaults")
            return self._normalize(weights)
        if not isinstance(data, dict):
            logger.warning("learned ranking weights file is not a JSON object, fallback to defaults")
            return self._normalize(weights)
        for key in weights:
            value = self._coerce_weight(data.get(key))
            if value is not None:
                weights[key] = value
        return self._normalize(weights)

    def score(self, breakdown: MatchBreakdown, weights: dict[str, float] | None = None) -> float:
        """Return the weighted sum of the breakdown's component scores.

        Args:
            breakdown: per-feature scores for one candidate match.
            weights: optional pre-fetched weights (e.g. from :meth:`get_weights`).
                Pass them when scoring many candidates so the weights file is
                read once instead of once per candidate. Defaults to
                :meth:`get_weights`.

        Returns:
            ``sum(weights[name] * score)`` over skill, region, time, experience
            and reliability.
        """
        if weights is None:
            weights = self.get_weights()
        return sum(weights[name] * value for name, value in self._features(breakdown).items())

    def update_from_feedback(self, breakdown: MatchBreakdown, accepted: bool) -> dict[str, float]:
        """Apply one online learning step from user feedback and persist the result.

        Acceptance is target 1.0 and rejection 0.0. Each weight moves by
        ``learning_rate * (target - prediction) * feature`` (the delta rule for
        a linear model), is clamped at zero, and the set is renormalized and
        saved to the weights file.

        Returns:
            The new normalized weights. If the breakdown produces a non-finite
            prediction error (e.g. a NaN score), the current weights are
            returned unchanged and nothing is written.
        """
        weights = self.get_weights()
        features = self._features(breakdown)
        target = 1.0 if accepted else 0.0
        prediction = sum(weights[name] * value for name, value in features.items())
        error = target - prediction
        if not math.isfinite(error):
            # A NaN/inf here would clamp every weight to 0 and silently reset
            # the learned weights to the defaults; refuse the update instead.
            logger.warning("non-finite match scores in feedback, ranking weights left unchanged")
            return weights
        lr = self.settings.ranking_learning_rate
        updated = {name: max(0.0, weights[name] + lr * error * value) for name, value in features.items()}
        normalized = self._normalize(updated)
        self._save_weights(normalized)
        return normalized

    @staticmethod
    def _features(breakdown: MatchBreakdown) -> dict[str, float]:
        """Map each weight name to the corresponding component score of *breakdown*."""
        return {
            "skill": breakdown.skill_score,
            "region": breakdown.region_score,
            "time": breakdown.time_score,
            "experience": breakdown.experience_score,
            "reliability": breakdown.reliability_score,
        }

    @staticmethod
    def _coerce_weight(value: object) -> float | None:
        """Return *value* as a finite float, or None if it is not a usable weight."""
        # bool is a subclass of int: JSON true/false must not become 1.0/0.0.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        try:
            number = float(value)
        except OverflowError:
            # Integers too large for a float.
            return None
        return number if math.isfinite(number) else None

    def _save_weights(self, weights: dict[str, float]) -> None:
        """Write *weights* as JSON to ``self.path`` via a temp file and rename.

        The payload is written to a uniquely named sibling file and then moved
        over the target with ``Path.replace``, so readers see either the old or
        the new complete file rather than a partially written one (the rename
        is atomic on POSIX filesystems). The target's parent directory is
        created if needed.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(weights, ensure_ascii=False, indent=2)
        # Same directory as the target, so replace() is a rename, not a copy.
        tmp_path = self.path.with_name(f".{self.path.name}.{uuid.uuid4().hex}.tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(self.path)
        except BaseException:
            # Don't leave stray temp files behind; the error still propagates.
            tmp_path.unlink(missing_ok=True)
            raise

    def _normalize(self, weights: dict[str, float]) -> dict[str, float]:
        """Scale *weights* so they are non-negative and sum to 1.

        Negative and non-finite values count as 0. If nothing positive remains,
        the configured defaults are normalized instead; if those are unusable
        too, every default key receives an equal share.
        """
        scaled = self._scaled(weights)
        if scaled is None:
            scaled = self._scaled(self.default_weights())
        if scaled is None:
            keys = list(self.default_weights())
            scaled = dict.fromkeys(keys, 1.0 / len(keys))
        return scaled

    @staticmethod
    def _scaled(values: dict[str, float]) -> dict[str, float] | None:
        """Clamp to finite non-negatives and divide by the total; None if the total is not positive."""
        cleaned = {key: value if math.isfinite(value) and value > 0 else 0.0 for key, value in values.items()}
        total = sum(cleaned.values())
        if not (total > 0 and math.isfinite(total)):
            return None
        return {key: value / total for key, value in cleaned.items()}