feat: 初始化零工后端代码
This commit is contained in:
77
gig-poc/apps/api/app/services/weight_service.py
Normal file
77
gig-poc/apps/api/app/services/weight_service.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import Settings
|
||||
from app.core.logging import logger
|
||||
from app.domain.schemas import MatchBreakdown
|
||||
|
||||
|
||||
class MatchWeightService:
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.path: Path = settings.match_weights_path
|
||||
|
||||
def default_weights(self) -> dict[str, float]:
|
||||
return {
|
||||
"skill": self.settings.score_skill_weight,
|
||||
"region": self.settings.score_region_weight,
|
||||
"time": self.settings.score_time_weight,
|
||||
"experience": self.settings.score_experience_weight,
|
||||
"reliability": self.settings.score_reliability_weight,
|
||||
}
|
||||
|
||||
def get_weights(self) -> dict[str, float]:
|
||||
weights = self.default_weights()
|
||||
if not self.path.exists():
|
||||
return self._normalize(weights)
|
||||
try:
|
||||
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
for key in weights:
|
||||
value = data.get(key)
|
||||
if isinstance(value, (int, float)):
|
||||
weights[key] = float(value)
|
||||
except Exception:
|
||||
logger.exception("failed to read learned ranking weights, fallback to defaults")
|
||||
return self._normalize(weights)
|
||||
|
||||
def score(self, breakdown: MatchBreakdown) -> float:
|
||||
weights = self.get_weights()
|
||||
return (
|
||||
weights["skill"] * breakdown.skill_score
|
||||
+ weights["region"] * breakdown.region_score
|
||||
+ weights["time"] * breakdown.time_score
|
||||
+ weights["experience"] * breakdown.experience_score
|
||||
+ weights["reliability"] * breakdown.reliability_score
|
||||
)
|
||||
|
||||
def update_from_feedback(self, breakdown: MatchBreakdown, accepted: bool) -> dict[str, float]:
|
||||
weights = self.get_weights()
|
||||
features = {
|
||||
"skill": breakdown.skill_score,
|
||||
"region": breakdown.region_score,
|
||||
"time": breakdown.time_score,
|
||||
"experience": breakdown.experience_score,
|
||||
"reliability": breakdown.reliability_score,
|
||||
}
|
||||
target = 1.0 if accepted else 0.0
|
||||
prediction = sum(weights[name] * value for name, value in features.items())
|
||||
error = target - prediction
|
||||
lr = self.settings.ranking_learning_rate
|
||||
updated = {name: max(0.0, weights[name] + lr * error * value) for name, value in features.items()}
|
||||
normalized = self._normalize(updated)
|
||||
self._save_weights(normalized)
|
||||
return normalized
|
||||
|
||||
def _save_weights(self, weights: dict[str, float]) -> None:
|
||||
self.settings.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.path.write_text(json.dumps(weights, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
def _normalize(self, weights: dict[str, float]) -> dict[str, float]:
|
||||
total = sum(max(value, 0.0) for value in weights.values())
|
||||
if total <= 0:
|
||||
fallback = self.default_weights()
|
||||
total = sum(fallback.values())
|
||||
return {key: value / total for key, value in fallback.items()}
|
||||
return {key: max(value, 0.0) / total for key, value in weights.items()}
|
||||
Reference in New Issue
Block a user