78 lines
3.1 KiB
Python
78 lines
3.1 KiB
Python
from __future__ import annotations

import contextlib
import json
import math
import uuid
from pathlib import Path

from app.core.config import Settings
from app.core.logging import logger
from app.domain.schemas import MatchBreakdown
|
|
|
|
|
|
class MatchWeightService:
    """Load, apply and learn the linear weights used to rank matches.

    A match is described by a ``MatchBreakdown`` carrying five per-feature
    scores (skill, region, time, experience, reliability); its ranking score
    is the weighted sum of those scores. Default weights come from
    ``Settings``; learned weights are persisted as a JSON object at
    ``settings.match_weights_path`` and override the defaults key by key.
    ``get_weights`` and ``update_from_feedback`` always return weights that
    are non-negative and sum to 1.
    """

    def __init__(self, settings: Settings):
        """Bind the service to *settings*.

        Args:
            settings: Application settings providing the default
                ``score_*_weight`` values, ``ranking_learning_rate``,
                ``data_dir`` and ``match_weights_path``.
        """
        self.settings = settings
        # Accept a plain string as well as a Path for the weights location.
        self.path: Path = Path(settings.match_weights_path)

    def default_weights(self) -> dict[str, float]:
        """Return the configured (not normalised) weight for each feature."""
        return {
            "skill": self.settings.score_skill_weight,
            "region": self.settings.score_region_weight,
            "time": self.settings.score_time_weight,
            "experience": self.settings.score_experience_weight,
            "reliability": self.settings.score_reliability_weight,
        }

    def get_weights(self) -> dict[str, float]:
        """Return the current normalised feature weights.

        Learned values stored at ``self.path`` override the configured
        defaults key by key; entries that are missing or unusable (not a
        number, a boolean, NaN/infinite, or too large for a float) keep their
        default. A missing file yields the defaults silently; an unreadable
        or malformed file is logged and also yields the defaults, so scoring
        never fails because of the weights file.

        Returns:
            Mapping of feature name to weight; values are >= 0 and sum to 1.
        """
        weights = self.default_weights()
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            # Nothing has been learned yet.
            return self._normalize(weights)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.exception("failed to read learned ranking weights, fallback to defaults")
            return self._normalize(weights)

        if not isinstance(data, dict):
            logger.warning("learned ranking weights file is not a JSON object, fallback to defaults")
            return self._normalize(weights)

        for key in weights:
            value = self._coerce_weight(data.get(key))
            if value is not None:
                weights[key] = value
        return self._normalize(weights)

    def score(self, breakdown: MatchBreakdown) -> float:
        """Return the weighted sum of *breakdown*'s feature scores.

        Weights are re-read via ``get_weights`` on every call, so updates
        written by ``update_from_feedback`` take effect immediately.
        """
        weights = self.get_weights()
        return sum(weights[name] * value for name, value in self._features(breakdown).items())

    def update_from_feedback(self, breakdown: MatchBreakdown, accepted: bool) -> dict[str, float]:
        """Apply one delta-rule (LMS) step from a single accept/reject signal.

        The target is 1.0 for an accepted match and 0.0 otherwise. Each
        weight moves by ``ranking_learning_rate * (target - prediction) *
        feature``, is clipped at 0, and the result is renormalised and
        persisted.

        Args:
            breakdown: Per-feature scores of the match the feedback is about.
            accepted: Whether the match was accepted.

        Returns:
            The new normalised weights, as written to ``self.path``.

        Raises:
            OSError: If the updated weights cannot be persisted.
        """
        weights = self.get_weights()
        features = self._features(breakdown)
        target = 1.0 if accepted else 0.0
        prediction = sum(weights[name] * value for name, value in features.items())
        error = target - prediction
        lr = self.settings.ranking_learning_rate
        updated = {name: max(0.0, weights[name] + lr * error * value) for name, value in features.items()}
        normalized = self._normalize(updated)
        self._save_weights(normalized)
        return normalized

    @staticmethod
    def _features(breakdown: MatchBreakdown) -> dict[str, float]:
        """Map *breakdown* to its per-feature scores, keyed like the weights."""
        return {
            "skill": breakdown.skill_score,
            "region": breakdown.region_score,
            "time": breakdown.time_score,
            "experience": breakdown.experience_score,
            "reliability": breakdown.reliability_score,
        }

    @staticmethod
    def _coerce_weight(value: object) -> float | None:
        """Return *value* as a finite float, or None if it is not a usable weight."""
        # bool is an int subclass, but JSON true/false is not a meaningful weight.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        try:
            number = float(value)
        except OverflowError:  # JSON integers too large to fit in a float
            return None
        # json.loads accepts NaN/Infinity, which would poison normalisation.
        return number if math.isfinite(number) else None

    def _save_weights(self, weights: dict[str, float]) -> None:
        """Persist *weights* as JSON at ``self.path``, replacing it atomically.

        The payload is written to a uniquely named sibling temp file that is
        then moved over the target with ``Path.replace`` (an atomic rename on
        POSIX), so a concurrent ``get_weights`` reads either the old or the
        new file rather than a half-written one.

        Raises:
            OSError: If the directory, temp file or rename fails.
        """
        Path(self.settings.data_dir).mkdir(parents=True, exist_ok=True)
        # The weights file is not required to live directly inside data_dir.
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(weights, ensure_ascii=False, indent=2)
        tmp_path = self.path.with_name(f".{self.path.name}.{uuid.uuid4().hex}.tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(self.path)
        except BaseException:
            # Do not leave orphaned temp files behind; re-raise the original error.
            with contextlib.suppress(OSError):
                tmp_path.unlink()
            raise

    def _normalize(self, weights: dict[str, float]) -> dict[str, float]:
        """Clip *weights* at 0 and scale them so they sum to 1.

        If the clipped weights carry no usable mass (total is 0, NaN or
        infinite), the configured defaults are normalised instead; if even
        those are unusable, every feature gets an equal share.
        """
        normalized = self._scale_to_unit(weights)
        if normalized is not None:
            return normalized
        fallback = self.default_weights()
        normalized = self._scale_to_unit(fallback)
        if normalized is not None:
            return normalized
        return {key: 1.0 / len(fallback) for key in fallback}

    @staticmethod
    def _scale_to_unit(weights: dict[str, float]) -> dict[str, float] | None:
        """Return *weights* clipped at 0 and divided by their total, or None if the total is unusable."""
        clipped = {key: max(value, 0.0) for key, value in weights.items()}
        total = sum(clipped.values())
        # A NaN total fails both checks; an infinite one would yield NaN/0 shares.
        if not (math.isfinite(total) and total > 0):
            return None
        return {key: value / total for key, value in clipped.items()}
|