191 lines
7.7 KiB
Python
191 lines
7.7 KiB
Python
from __future__ import annotations

import hashlib
import json
import math
import uuid
from collections import defaultdict

from qdrant_client import QdrantClient, models

from app.core.config import Settings
from app.core.logging import logger
from app.domain.schemas import JobCard, QueryFilters, WorkerCard
from app.services.llm_client import LLMClient
||
class LightRAGAdapter:
    """Lightweight RAG-style retrieval adapter backed by a Qdrant collection."""

    def __init__(self, settings: Settings):
        """Wire up the Qdrant client, LLM client and skill graph from *settings*."""
        self.settings = settings
        self.client = QdrantClient(url=settings.qdrant_url)
        self.llm_client = LLMClient(settings)
        self.skill_graph = self._load_skill_graph()
        # Resolved lazily by ensure_ready(); stays None until the collection
        # has been created or inspected.
        self.collection_vector_size: int | None = None
def ensure_ready(self) -> None:
|
||
collections = {item.name for item in self.client.get_collections().collections}
|
||
expected_size = self._configured_vector_size()
|
||
if self.settings.qdrant_collection not in collections:
|
||
self.client.create_collection(
|
||
collection_name=self.settings.qdrant_collection,
|
||
vectors_config=models.VectorParams(size=expected_size, distance=models.Distance.COSINE),
|
||
)
|
||
self.collection_vector_size = expected_size
|
||
return
|
||
info = self.client.get_collection(self.settings.qdrant_collection)
|
||
configured_size = info.config.params.vectors.size
|
||
self.collection_vector_size = int(configured_size)
|
||
if self.collection_vector_size != expected_size:
|
||
logger.warning(
|
||
"qdrant vector size mismatch, collection=%s expected=%s actual=%s; using actual size",
|
||
self.settings.qdrant_collection,
|
||
expected_size,
|
||
self.collection_vector_size,
|
||
)
|
||
|
||
def health(self) -> str:
|
||
self.ensure_ready()
|
||
self.client.get_collection(self.settings.qdrant_collection)
|
||
return "ok"
|
||
|
||
def upsert_job(self, job: JobCard) -> None:
|
||
self.ensure_ready()
|
||
payload = {
|
||
"entity_type": "job",
|
||
"entity_id": job.job_id,
|
||
"city": job.city,
|
||
"region": job.region,
|
||
"category": job.category,
|
||
"skills": job.skills,
|
||
"tags": job.tags,
|
||
"document": self._serialize_job(job),
|
||
}
|
||
self.client.upsert(
|
||
collection_name=self.settings.qdrant_collection,
|
||
points=[
|
||
models.PointStruct(
|
||
id=self._point_id(job.job_id),
|
||
vector=self._vectorize(payload["document"]),
|
||
payload=payload,
|
||
)
|
||
],
|
||
)
|
||
|
||
def upsert_worker(self, worker: WorkerCard) -> None:
|
||
self.ensure_ready()
|
||
payload = {
|
||
"entity_type": "worker",
|
||
"entity_id": worker.worker_id,
|
||
"city": worker.cities[0] if worker.cities else "",
|
||
"region": worker.regions[0] if worker.regions else "",
|
||
"category": worker.experience_tags[0] if worker.experience_tags else "",
|
||
"skills": [item.name for item in worker.skills],
|
||
"tags": worker.experience_tags,
|
||
"document": self._serialize_worker(worker),
|
||
}
|
||
self.client.upsert(
|
||
collection_name=self.settings.qdrant_collection,
|
||
points=[
|
||
models.PointStruct(
|
||
id=self._point_id(worker.worker_id),
|
||
vector=self._vectorize(payload["document"]),
|
||
payload=payload,
|
||
)
|
||
],
|
||
)
|
||
|
||
def search(self, query_text: str, filters: QueryFilters, limit: int) -> list[str]:
|
||
self.ensure_ready()
|
||
must = [models.FieldCondition(key="entity_type", match=models.MatchValue(value=filters.entity_type))]
|
||
if filters.city:
|
||
must.append(models.FieldCondition(key="city", match=models.MatchValue(value=filters.city)))
|
||
query_filter = models.Filter(must=must)
|
||
results = self.client.search(
|
||
collection_name=self.settings.qdrant_collection,
|
||
query_vector=self._vectorize(query_text),
|
||
query_filter=query_filter,
|
||
limit=limit,
|
||
with_payload=True,
|
||
)
|
||
ids = []
|
||
for point in results:
|
||
payload = point.payload or {}
|
||
if filters.region and payload.get("region") != filters.region:
|
||
continue
|
||
ids.append(str(payload.get("entity_id", point.id)))
|
||
return ids
|
||
|
||
def expand_skills(self, skills: list[str]) -> set[str]:
|
||
expanded = set(skills)
|
||
for skill in skills:
|
||
expanded.update(self.skill_graph.get(skill, []))
|
||
return expanded
|
||
|
||
def _load_skill_graph(self) -> dict[str, set[str]]:
|
||
relations_path = self.settings.sample_data_dir / "skill_relations.json"
|
||
if not relations_path.exists():
|
||
return defaultdict(set)
|
||
data = json.loads(relations_path.read_text(encoding="utf-8"))
|
||
graph: dict[str, set[str]] = defaultdict(set)
|
||
for source, targets in data.items():
|
||
graph[source].update(targets)
|
||
for target in targets:
|
||
graph[target].add(source)
|
||
return graph
|
||
|
||
def _serialize_job(self, job: JobCard) -> str:
|
||
return " ".join([job.title, job.category, job.city, job.region, *job.skills, *job.tags, job.description])
|
||
|
||
def _serialize_worker(self, worker: WorkerCard) -> str:
|
||
return " ".join(
|
||
[worker.name, *worker.cities, *worker.regions, *[item.name for item in worker.skills], *worker.experience_tags, worker.description]
|
||
)
|
||
|
||
def _vectorize(self, text: str) -> list[float]:
|
||
if self.settings.embedding_enabled and self.settings.embedding_backend == "openai_compatible":
|
||
try:
|
||
embedding = self.llm_client.embedding(text)
|
||
if embedding:
|
||
return self._normalize_embedding(embedding)
|
||
except Exception:
|
||
logger.exception("embedding request failed, fallback to hash vector")
|
||
target_size = self._active_vector_size()
|
||
vector = [0.0 for _ in range(target_size)]
|
||
tokens = self._tokenize(text)
|
||
for token in tokens:
|
||
index = hash(token) % target_size
|
||
vector[index] += 1.0
|
||
norm = math.sqrt(sum(item * item for item in vector)) or 1.0
|
||
return [item / norm for item in vector]
|
||
|
||
def _normalize_embedding(self, embedding: list[float]) -> list[float]:
|
||
target_size = self._active_vector_size()
|
||
vector = embedding[:target_size]
|
||
if len(vector) < target_size:
|
||
vector.extend([0.0] * (target_size - len(vector)))
|
||
norm = math.sqrt(sum(item * item for item in vector)) or 1.0
|
||
return [item / norm for item in vector]
|
||
|
||
def _active_vector_size(self) -> int:
|
||
if self.collection_vector_size:
|
||
return self.collection_vector_size
|
||
return self._configured_vector_size()
|
||
|
||
def _configured_vector_size(self) -> int:
|
||
if self.settings.embedding_enabled and self.settings.embedding_backend == "openai_compatible":
|
||
return self.settings.embedding_vector_size
|
||
return self.settings.vector_size
|
||
|
||
def _tokenize(self, text: str) -> list[str]:
|
||
cleaned = [part.strip().lower() for part in text.replace(",", " ").replace("、", " ").replace("。", " ").split()]
|
||
tokens = [part for part in cleaned if part]
|
||
for size in (2, 3):
|
||
for index in range(max(len(text) - size + 1, 0)):
|
||
chunk = text[index : index + size].strip()
|
||
if chunk:
|
||
tokens.append(chunk)
|
||
return tokens
|
||
|
||
def _point_id(self, entity_id: str) -> str:
|
||
# Qdrant v1.14 accepts point IDs as UUID or unsigned int.
|
||
return str(uuid.uuid5(uuid.NAMESPACE_URL, entity_id))
|