from __future__ import annotations

import hashlib
import json
import math
import uuid
from collections import defaultdict

from qdrant_client import QdrantClient, models

from app.core.config import Settings
from app.core.logging import logger
from app.domain.schemas import JobCard, QueryFilters, WorkerCard


class LightRAGAdapter:
    """Store job and worker cards in a Qdrant collection and search them.

    Documents are embedded with a hashed bag-of-tokens vector (see
    ``_vectorize``), and a bidirectional skill-relation graph loaded from
    ``skill_relations.json`` is used to expand skill lists.
    """

    # Full-width / CJK punctuation treated as whitespace when splitting words.
    _PUNCTUATION_TO_SPACE = str.maketrans({",": " ", "、": " ", "。": " "})

    # Character n-gram lengths added on top of the whitespace-delimited words.
    _NGRAM_SIZES = (2, 3)

    def __init__(self, settings: Settings):
        """Create the Qdrant client and load the skill graph from disk."""
        self.settings = settings
        self.client = QdrantClient(url=settings.qdrant_url)
        self.skill_graph = self._load_skill_graph()

    def ensure_ready(self) -> None:
        """Create the configured collection if it does not exist yet."""
        name = self.settings.qdrant_collection
        existing = {item.name for item in self.client.get_collections().collections}
        if name in existing:
            return
        self.client.create_collection(
            collection_name=name,
            vectors_config=models.VectorParams(
                size=self.settings.vector_size,
                distance=models.Distance.COSINE,
            ),
        )

    def health(self) -> str:
        """Return ``"ok"`` when the collection can be fetched.

        Calls ``ensure_ready`` first, so the collection is created as a side
        effect if it is missing. Any client error propagates to the caller.
        """
        self.ensure_ready()
        self.client.get_collection(self.settings.qdrant_collection)
        return "ok"

    def upsert_job(self, job: JobCard) -> None:
        """Index (or re-index) a job card under a point ID derived from its ID."""
        document = self._serialize_job(job)
        payload = {
            "entity_type": "job",
            "entity_id": job.job_id,
            "city": job.city,
            "region": job.region,
            "category": job.category,
            "skills": job.skills,
            "tags": job.tags,
            "document": document,
        }
        self._upsert_point(job.job_id, document, payload)

    def upsert_worker(self, worker: WorkerCard) -> None:
        """Index (or re-index) a worker card.

        Only the first city, region and experience tag are stored in the
        scalar ``city`` / ``region`` / ``category`` payload fields; the full
        lists are still part of the serialized document text.
        """
        document = self._serialize_worker(worker)
        payload = {
            "entity_type": "worker",
            "entity_id": worker.worker_id,
            "city": worker.cities[0] if worker.cities else "",
            "region": worker.regions[0] if worker.regions else "",
            "category": worker.experience_tags[0] if worker.experience_tags else "",
            "skills": [item.name for item in worker.skills],
            "tags": worker.experience_tags,
            "document": document,
        }
        self._upsert_point(worker.worker_id, document, payload)

    def search(self, query_text: str, filters: QueryFilters, limit: int) -> list[str]:
        """Return up to ``limit`` entity IDs ranked by similarity to the query.

        ``filters.entity_type`` is always applied; ``filters.city`` and
        ``filters.region`` are applied when truthy. All conditions are sent to
        Qdrant so that ``limit`` counts only matching points. Filtering by
        region after the search could return fewer than ``limit`` hits even
        when enough matching points exist.
        """
        self.ensure_ready()
        must = [
            models.FieldCondition(
                key="entity_type",
                match=models.MatchValue(value=filters.entity_type),
            )
        ]
        for key, value in (("city", filters.city), ("region", filters.region)):
            if value:
                must.append(
                    models.FieldCondition(key=key, match=models.MatchValue(value=value))
                )

        # NOTE(review): newer qdrant-client releases deprecate `search` in
        # favour of `query_points`; confirm against the pinned client version.
        results = self.client.search(
            collection_name=self.settings.qdrant_collection,
            query_vector=self._vectorize(query_text),
            query_filter=models.Filter(must=must),
            limit=limit,
            with_payload=True,
        )
        # Fall back to the raw point ID when the payload has no entity_id.
        return [
            str((point.payload or {}).get("entity_id", point.id)) for point in results
        ]

    def expand_skills(self, skills: list[str]) -> set[str]:
        """Return ``skills`` plus every skill directly related to one of them."""
        expanded = set(skills)
        for skill in skills:
            # `.get` avoids inserting empty entries into the defaultdict graph.
            expanded.update(self.skill_graph.get(skill, ()))
        return expanded

    def _upsert_point(
        self, entity_id: str, document: str, payload: dict[str, object]
    ) -> None:
        """Ensure the collection exists, then write one point for ``entity_id``."""
        self.ensure_ready()
        self.client.upsert(
            collection_name=self.settings.qdrant_collection,
            points=[
                models.PointStruct(
                    id=self._point_id(entity_id),
                    vector=self._vectorize(document),
                    payload=payload,
                )
            ],
        )

    def _load_skill_graph(self) -> dict[str, set[str]]:
        """Load ``skill_relations.json`` into an undirected adjacency map.

        The file maps a source skill to a collection of target skills. Every
        relation is stored in both directions. A missing file yields an empty
        graph.
        """
        graph: dict[str, set[str]] = defaultdict(set)
        relations_path = self.settings.sample_data_dir / "skill_relations.json"
        if not relations_path.exists():
            return graph
        data = json.loads(relations_path.read_text(encoding="utf-8"))
        for source, targets in data.items():
            graph[source].update(targets)
            for target in targets:
                graph[target].add(source)
        return graph

    def _serialize_job(self, job: JobCard) -> str:
        """Flatten a job card into the space-separated text that gets embedded."""
        parts = [
            job.title,
            job.category,
            job.city,
            job.region,
            *job.skills,
            *job.tags,
            job.description,
        ]
        return " ".join(parts)

    def _serialize_worker(self, worker: WorkerCard) -> str:
        """Flatten a worker card into the space-separated text that gets embedded."""
        parts = [
            worker.name,
            *worker.cities,
            *worker.regions,
            *(item.name for item in worker.skills),
            *worker.experience_tags,
            worker.description,
        ]
        return " ".join(parts)

    def _vectorize(self, text: str) -> list[float]:
        """Embed ``text`` as an L2-normalised hashed bag-of-tokens vector.

        Each token increments one of ``settings.vector_size`` buckets. Text
        that yields no tokens produces the all-zero vector.
        """
        size = self.settings.vector_size
        vector = [0.0] * size
        for token in self._tokenize(text):
            vector[self._bucket(token, size)] += 1.0
        norm = math.sqrt(sum(value * value for value in vector)) or 1.0
        return [value / norm for value in vector]

    @staticmethod
    def _bucket(token: str, size: int) -> int:
        """Map ``token`` to a bucket index that is stable across processes.

        The built-in ``hash()`` is salted per interpreter process for ``str``,
        so vectors stored by one process would not line up with query vectors
        computed by another. A fixed digest keeps the mapping reproducible.

        NOTE: points indexed with the previous ``hash()``-based mapping must be
        re-upserted to be comparable with vectors produced here.
        """
        digest = hashlib.blake2b(
            token.encode("utf-8", "surrogatepass"), digest_size=8
        ).digest()
        return int.from_bytes(digest, "big") % size

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` into lower-cased words plus character 2- and 3-grams.

        Words and n-grams are both taken from the lower-cased text so that
        matching is case-insensitive for both kinds of token. N-grams are
        stripped of surrounding whitespace and dropped when empty.
        """
        normalized = text.lower()
        tokens = normalized.translate(self._PUNCTUATION_TO_SPACE).split()
        for size in self._NGRAM_SIZES:
            for start in range(max(len(normalized) - size + 1, 0)):
                chunk = normalized[start : start + size].strip()
                if chunk:
                    tokens.append(chunk)
        return tokens

    def _point_id(self, entity_id: str) -> str:
        """Derive a deterministic UUID string point ID from ``entity_id``."""
        # Qdrant v1.14 accepts point IDs as UUID or unsigned int.
        return str(uuid.uuid5(uuid.NAMESPACE_URL, entity_id))