# Listing metadata from the code viewer: 237 lines, 9.6 KiB, Python.
from __future__ import annotations
|
||
|
||
from datetime import datetime, timedelta, timezone
|
||
from enum import Enum
|
||
|
||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||
|
||
|
||
# Closed set of salary settlement bases. Members are also `str` instances,
# so they compare equal to their plain string values.
class SalaryType(str, Enum):
    daily = "daily"
    hourly = "hourly"
    monthly = "monthly"
    task = "task"
|
||
|
||
|
||
# Direction of a match: from a job towards workers, or from a worker
# towards jobs. Members are also `str` instances.
class SourceType(str, Enum):
    job_to_worker = "job_to_worker"
    worker_to_job = "worker_to_job"
|
||
|
||
|
||
# Pay terms attached to a job. Every field has a default
# (daily / 0 / "CNY"), so an empty payload is accepted.
class Salary(BaseModel):
    type: SalaryType = Field(default=SalaryType.daily, description="薪资类型:daily/hourly/monthly/task")
    amount: float = Field(default=0, description="薪资金额")
    currency: str = Field(default="CNY", description="货币类型,默认 CNY")
|
||
|
||
|
||
# A named skill paired with a proficiency score constrained to [0, 1].
class SkillScore(BaseModel):
    name: str = Field(description="技能名称")
    score: float = Field(ge=0, le=1, description="技能熟练度,范围 0~1")
|
||
|
||
|
||
# Structured description of a job posting. `skills` and `tags` default to
# empty lists; all other fields are required. `start_time` is normalized by
# the validator below.
class JobCard(BaseModel):
    job_id: str = Field(description="岗位唯一 ID")
    title: str = Field(description="岗位标题")
    category: str = Field(description="岗位类别")
    description: str = Field(description="岗位描述")
    skills: list[str] = Field(default_factory=list, description="岗位技能要求列表")
    city: str = Field(description="城市")
    region: str = Field(description="区域")
    location_detail: str = Field(description="详细地点描述")
    start_time: datetime = Field(description="岗位开始时间,ISO-8601")
    duration_hours: float = Field(gt=0, description="工时(小时),必须大于 0")
    headcount: int = Field(gt=0, description="招聘人数,必须大于 0")
    salary: Salary = Field(description="薪资信息")
    work_mode: str = Field(description="工作模式,如兼职、全职、活动")
    tags: list[str] = Field(default_factory=list, description="业务标签列表")
    confidence: float = Field(ge=0, le=1, description="数据置信度,范围 0~1")

    @field_validator("start_time", mode="after")
    @classmethod
    def normalize_start_time(cls, value: datetime) -> datetime:
        """Normalize ``start_time`` to a UTC+8 timestamp at minute precision.

        Args:
            value: The already-parsed datetime (the validator runs in
                ``after`` mode).

        Returns:
            A timezone-aware datetime with a fixed +08:00 offset and with
            seconds and microseconds set to zero.
        """
        shanghai_tz = timezone(timedelta(hours=8))
        if value.tzinfo is None:
            # Naive input is taken to already be UTC+8 wall-clock time:
            # the offset is attached without shifting the clock reading.
            value = value.replace(tzinfo=shanghai_tz)
        else:
            # Aware input is converted, so the instant is preserved.
            value = value.astimezone(shanghai_tz)
        return value.replace(second=0, microsecond=0)
|
||
|
||
|
||
# Structured description of a worker profile. All list fields default to
# empty lists; the three score fields are required and constrained to [0, 1].
class WorkerCard(BaseModel):
    worker_id: str = Field(description="工人唯一 ID")
    name: str = Field(description="工人姓名或昵称")
    description: str = Field(description="工人自我描述")
    skills: list[SkillScore] = Field(default_factory=list, description="技能及熟练度列表")
    cities: list[str] = Field(default_factory=list, description="可接单城市列表")
    regions: list[str] = Field(default_factory=list, description="可接单区域列表")
    availability: list[str] = Field(default_factory=list, description="可上岗时间描述")
    experience_tags: list[str] = Field(default_factory=list, description="经验标签列表")
    reliability_score: float = Field(ge=0, le=1, description="履约可靠性分,范围 0~1")
    profile_completion: float = Field(ge=0, le=1, description="档案完善度,范围 0~1")
    confidence: float = Field(ge=0, le=1, description="数据置信度,范围 0~1")
|
||
|
||
|
||
# Per-dimension components of a match score; each is required and
# constrained to [0, 1].
class MatchBreakdown(BaseModel):
    skill_score: float = Field(ge=0, le=1, description="技能匹配分,范围 0~1")
    region_score: float = Field(ge=0, le=1, description="地域匹配分,范围 0~1")
    time_score: float = Field(ge=0, le=1, description="时间匹配分,范围 0~1")
    experience_score: float = Field(ge=0, le=1, description="经验匹配分,范围 0~1")
    reliability_score: float = Field(ge=0, le=1, description="可靠性匹配分,范围 0~1")
|
||
|
||
|
||
# One scored pairing between a source entity and a target entity, with the
# overall score, its per-dimension breakdown, and textual reasons.
class MatchResult(BaseModel):
    match_id: str = Field(description="匹配记录 ID")
    source_type: SourceType = Field(description="匹配方向:job_to_worker 或 worker_to_job")
    source_id: str = Field(description="源实体 ID")
    target_id: str = Field(description="目标实体 ID")
    match_score: float = Field(ge=0, le=1, description="综合匹配分,范围 0~1")
    breakdown: MatchBreakdown = Field(description="多维打分拆解")
    # NOTE(review): the empty-list default does not itself satisfy
    # min_length=3. pydantic skips validation of defaults unless
    # validate_default is enabled, so omitting `reasons` presumably bypasses
    # the "at least 3" rule - confirm whether the field should be required.
    reasons: list[str] = Field(default_factory=list, min_length=3, description="匹配理由,至少 3 条")
|
||
|
||
|
||
# Request body carrying free-form text to extract from (minimum 5
# characters). `model_config` attaches a sample payload to the JSON schema.
class ExtractTextRequest(BaseModel):
    text: str = Field(min_length=5, description="待抽取的自然语言文本,最少 5 个字符")
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "text": "明天下午南山会展中心需要2个签到协助,5小时,150/人,女生优先",
            }
        }
    )
|
||
|
||
|
||
# Request body wrapping a single JobCard to ingest.
class IngestJobRequest(BaseModel):
    job: JobCard = Field(description="岗位卡片对象")
|
||
|
||
|
||
# Request body wrapping a single WorkerCard to ingest.
class IngestWorkerRequest(BaseModel):
    worker: WorkerCard = Field(description="工人卡片对象")
|
||
|
||
|
||
# Request to find workers for a job. The job is identified either by
# `job_id` or by an inline `job` object; `top_n` caps the result count
# (1..50, default 10).
class MatchWorkersRequest(BaseModel):
    job_id: str | None = Field(default=None, description="岗位 ID(与 job 二选一)")
    job: JobCard | None = Field(default=None, description="内联岗位对象(与 job_id 二选一)")
    top_n: int = Field(default=10, ge=1, le=50, description="返回条数,范围 1~50")

    @model_validator(mode="after")
    def validate_source(self) -> "MatchWorkersRequest":
        """Reject requests that supply neither ``job_id`` nor ``job``.

        Only the "neither" case is rejected; a request carrying both
        values passes this check. An empty-string ``job_id`` is falsy and
        therefore counts as not provided.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If both ``job_id`` and ``job`` are falsy.
        """
        if not self.job_id and not self.job:
            raise ValueError("job_id 或 job 至少需要提供一个")
        return self
|
||
|
||
|
||
# Request to find jobs for a worker. The worker is identified either by
# `worker_id` or by an inline `worker` object; `top_n` caps the result
# count (1..50, default 10).
class MatchJobsRequest(BaseModel):
    worker_id: str | None = Field(default=None, description="工人 ID(与 worker 二选一)")
    worker: WorkerCard | None = Field(default=None, description="内联工人对象(与 worker_id 二选一)")
    top_n: int = Field(default=10, ge=1, le=50, description="返回条数,范围 1~50")

    @model_validator(mode="after")
    def validate_source(self) -> "MatchJobsRequest":
        """Reject requests that supply neither ``worker_id`` nor ``worker``.

        Only the "neither" case is rejected; a request carrying both
        values passes this check. An empty-string ``worker_id`` is falsy
        and therefore counts as not provided.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If both ``worker_id`` and ``worker`` are falsy.
        """
        if not self.worker_id and not self.worker:
            raise ValueError("worker_id 或 worker 至少需要提供一个")
        return self
|
||
|
||
|
||
# Result of an extraction call: a success flag, the extracted card (job or
# worker, or None), plus lists of errors and missing fields that default to
# empty.
class ExtractResponse(BaseModel):
    success: bool = Field(description="抽取是否成功")
    data: JobCard | WorkerCard | None = Field(default=None, description="抽取结果对象,可能为空")
    errors: list[str] = Field(default_factory=list, description="错误信息列表")
    missing_fields: list[str] = Field(default_factory=list, description="缺失字段列表")
|
||
|
||
|
||
# Integer counts reported by a bootstrap/import run; all fields are required.
class BootstrapResponse(BaseModel):
    jobs: int = Field(description="导入岗位数量")
    workers: int = Field(description="导入工人数量")
    skills: int = Field(description="技能词条数量")
    categories: int = Field(description="类目数量")
    regions: int = Field(description="区域数量")
|
||
|
||
|
||
# Health report: free-form status strings for the service, database and RAG
# component, plus a server timestamp. The status values are plain `str`;
# the "ok"/"error" vocabulary appears only in the field descriptions.
class HealthStatus(BaseModel):
    service: str = Field(description="服务状态,通常为 ok")
    database: str = Field(description="数据库状态:ok 或 error")
    rag: str = Field(description="RAG 组件状态:ok 或 error")
    timestamp: datetime = Field(description="服务端当前时间")
|
||
|
||
|
||
# Generic list envelope: untyped dict items plus a total count.
class ListResponse(BaseModel):
    items: list[dict] = Field(description="列表项")
    total: int = Field(description="总数")
|
||
|
||
|
||
# Envelope for a list of match results.
class MatchResponse(BaseModel):
    items: list[MatchResult] = Field(description="匹配结果列表")
|
||
|
||
|
||
# Envelope for the details of a single match result.
class ExplainResponse(BaseModel):
    match: MatchResult = Field(description="单条匹配结果详情")
|
||
|
||
|
||
# Accept/reject feedback on a previously produced match, keyed by match_id.
class MatchFeedbackRequest(BaseModel):
    match_id: str = Field(description="匹配记录 ID")
    accepted: bool = Field(description="反馈是否接受该推荐")
|
||
|
||
|
||
# Ranking weights (name -> float) together with the online-learning flag.
class MatchWeightResponse(BaseModel):
    weights: dict[str, float] = Field(description="当前生效的排序权重")
    learning_enabled: bool = Field(description="是否开启在线学习")
|
||
|
||
|
||
# Observability metrics for AI calls, as a name -> number mapping.
class AIObservabilityResponse(BaseModel):
    metrics: dict[str, float | int] = Field(description="AI 调用观测指标")
|
||
|
||
|
||
# Acknowledgement for an asynchronous ingest: task id and a status string.
class IngestAsyncResponse(BaseModel):
    task_id: str = Field(description="异步任务 ID")
    status: str = Field(description="任务状态")
|
||
|
||
|
||
# Counters describing a task queue: currently queued, processed, and failed.
class QueueStatusResponse(BaseModel):
    queued: int = Field(description="当前队列中任务数量")
    processed: int = Field(description="历史处理成功数量")
    failed: int = Field(description="历史处理失败数量")
|
||
|
||
|
||
# Async variant of the worker-matching request: `job_id` is required (no
# inline job object) and `top_n` is limited to 1..50 (default 10).
class MatchAsyncWorkersRequest(BaseModel):
    job_id: str = Field(description="岗位 ID")
    top_n: int = Field(default=10, ge=1, le=50, description="返回条数,范围 1~50")
|
||
|
||
|
||
# Async variant of the job-matching request: `worker_id` is required (no
# inline worker object) and `top_n` is limited to 1..50 (default 10).
class MatchAsyncJobsRequest(BaseModel):
    worker_id: str = Field(description="工人 ID")
    top_n: int = Field(default=10, ge=1, le=50, description="返回条数,范围 1~50")
|
||
|
||
|
||
# Status of an asynchronous match task. `items` defaults to None and holds
# the match results once they are available.
class MatchAsyncResponse(BaseModel):
    task_id: str = Field(description="异步任务 ID")
    status: str = Field(description="任务状态")
    items: list[MatchResult] | None = Field(default=None, description="任务完成后返回的匹配结果")
|
||
|
||
|
||
# Operational snapshot: traffic and cache metric mappings plus the status
# of the ingest and match queues.
class SystemOpsResponse(BaseModel):
    traffic: dict[str, float | int] = Field(description="全局流量护栏与错误窗口指标")
    cache: dict[str, float | int | str] = Field(description="缓存命中与大小")
    ingest_queue: QueueStatusResponse = Field(description="异步入库队列状态")
    match_queue: QueueStatusResponse = Field(description="异步匹配队列状态")
|
||
|
||
|
||
# Pairs a dict payload (`content`) with a text string (`raw_text`); both
# fields are required.
# NOTE(review): presumably the parsed and raw forms of a prompt/LLM
# response - confirm against the producer.
class PromptOutput(BaseModel):
    content: dict
    raw_text: str
|
||
|
||
|
||
# Filter criteria for a query. `entity_type` is required and must be "job"
# or "worker"; city/region are optional and the list filters default to
# empty lists.
class QueryFilters(BaseModel):
    entity_type: str
    city: str | None = None
    region: str | None = None
    categories: list[str] = Field(default_factory=list)
    tags: list[str] = Field(default_factory=list)
    skills: list[str] = Field(default_factory=list)

    @field_validator("entity_type")
    @classmethod
    def validate_entity_type(cls, value: str) -> str:
        """Ensure ``entity_type`` is exactly ``"job"`` or ``"worker"``.

        Args:
            value: The candidate entity type.

        Returns:
            ``value`` unchanged when it is one of the two accepted strings.

        Raises:
            ValueError: If ``value`` is anything else (the comparison is
                case-sensitive).
        """
        if value not in {"job", "worker"}:
            raise ValueError("entity_type must be job or worker")
        return value
|