Files
Airtep/gig-poc/apps/api/app/services/rag/lightrag_adapter.py
Daniel c6fabe262c fix
2026-03-31 10:52:49 +08:00

149 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import json
import math
import uuid
import zlib
from collections import defaultdict

from qdrant_client import QdrantClient, models

from app.core.config import Settings
from app.core.logging import logger
from app.domain.schemas import JobCard, QueryFilters, WorkerCard
class LightRAGAdapter:
    """Thin LightRAG-style adapter over a single Qdrant collection.

    Jobs and workers are stored as one point each: a bag-of-tokens vector
    (hashed term counts, L2-normalized) plus a flat, filterable payload.
    Skills can be expanded through a simple undirected relation graph
    loaded from the sample-data directory.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        self.client = QdrantClient(url=settings.qdrant_url)
        self.skill_graph = self._load_skill_graph()

    def ensure_ready(self) -> None:
        """Create the configured collection if it does not exist yet."""
        collections = {item.name for item in self.client.get_collections().collections}
        if self.settings.qdrant_collection not in collections:
            self.client.create_collection(
                collection_name=self.settings.qdrant_collection,
                vectors_config=models.VectorParams(
                    size=self.settings.vector_size,
                    distance=models.Distance.COSINE,
                ),
            )

    def health(self) -> str:
        """Return "ok" when the collection exists and is reachable; raises otherwise."""
        self.ensure_ready()
        self.client.get_collection(self.settings.qdrant_collection)
        return "ok"

    def upsert_job(self, job: JobCard) -> None:
        """Index (or re-index) a job card as a single point."""
        self.ensure_ready()
        payload = {
            "entity_type": "job",
            "entity_id": job.job_id,
            "city": job.city,
            "region": job.region,
            "category": job.category,
            "skills": job.skills,
            "tags": job.tags,
            "document": self._serialize_job(job),
        }
        self._upsert(job.job_id, payload)

    def upsert_worker(self, worker: WorkerCard) -> None:
        """Index (or re-index) a worker card as a single point.

        Only the first city/region/experience tag is promoted into the flat
        filterable fields; the full lists still feed the document text.
        """
        self.ensure_ready()
        payload = {
            "entity_type": "worker",
            "entity_id": worker.worker_id,
            "city": worker.cities[0] if worker.cities else "",
            "region": worker.regions[0] if worker.regions else "",
            "category": worker.experience_tags[0] if worker.experience_tags else "",
            "skills": [item.name for item in worker.skills],
            "tags": worker.experience_tags,
            "document": self._serialize_worker(worker),
        }
        self._upsert(worker.worker_id, payload)

    def _upsert(self, entity_id: str, payload: dict) -> None:
        # Shared by upsert_job/upsert_worker: one point per entity, vector
        # derived from the serialized document text. Stable point ids make
        # repeated upserts idempotent.
        self.client.upsert(
            collection_name=self.settings.qdrant_collection,
            points=[
                models.PointStruct(
                    id=self._point_id(entity_id),
                    vector=self._vectorize(payload["document"]),
                    payload=payload,
                )
            ],
        )

    def search(self, query_text: str, filters: QueryFilters, limit: int) -> list[str]:
        """Return up to `limit` entity ids matching the query text and filters."""
        self.ensure_ready()
        must = [models.FieldCondition(key="entity_type", match=models.MatchValue(value=filters.entity_type))]
        if filters.city:
            must.append(models.FieldCondition(key="city", match=models.MatchValue(value=filters.city)))
        # Push the region constraint into the Qdrant filter rather than
        # post-filtering the result page: post-filtering after `limit` was
        # applied could silently return fewer hits than exist.
        if filters.region:
            must.append(models.FieldCondition(key="region", match=models.MatchValue(value=filters.region)))
        results = self.client.search(
            collection_name=self.settings.qdrant_collection,
            query_vector=self._vectorize(query_text),
            query_filter=models.Filter(must=must),
            limit=limit,
            with_payload=True,
        )
        # Fall back to the raw point id if a payload is missing its entity_id.
        return [str((point.payload or {}).get("entity_id", point.id)) for point in results]

    def expand_skills(self, skills: list[str]) -> set[str]:
        """Return `skills` plus every skill directly related in the graph."""
        expanded = set(skills)
        for skill in skills:
            expanded.update(self.skill_graph.get(skill, []))
        return expanded

    def _load_skill_graph(self) -> dict[str, set[str]]:
        """Load skill relations as an undirected adjacency map (empty if file absent)."""
        relations_path = self.settings.sample_data_dir / "skill_relations.json"
        if not relations_path.exists():
            return defaultdict(set)
        data = json.loads(relations_path.read_text(encoding="utf-8"))
        graph: dict[str, set[str]] = defaultdict(set)
        for source, targets in data.items():
            # Relations are treated as symmetric: record both directions.
            graph[source].update(targets)
            for target in targets:
                graph[target].add(source)
        return graph

    def _serialize_job(self, job: JobCard) -> str:
        """Flatten a job card into one space-joined document string."""
        return " ".join([job.title, job.category, job.city, job.region, *job.skills, *job.tags, job.description])

    def _serialize_worker(self, worker: WorkerCard) -> str:
        """Flatten a worker card into one space-joined document string."""
        return " ".join(
            [
                worker.name,
                *worker.cities,
                *worker.regions,
                *[item.name for item in worker.skills],
                *worker.experience_tags,
                worker.description,
            ]
        )

    def _vectorize(self, text: str) -> list[float]:
        """Hash-bucket token counts into a fixed-size, L2-normalized vector."""
        vector = [0.0] * self.settings.vector_size
        for token in self._tokenize(text):
            # zlib.crc32 instead of builtin hash(): str hashing is salted per
            # process (PYTHONHASHSEED), so vectors stored by one process would
            # never match query vectors computed after a restart.
            index = zlib.crc32(token.encode("utf-8")) % self.settings.vector_size
            vector[index] += 1.0
        # Guard against the all-zero vector (empty text) with `or 1.0`.
        norm = math.sqrt(sum(item * item for item in vector)) or 1.0
        return [item / norm for item in vector]

    def _tokenize(self, text: str) -> list[str]:
        """Split text into lowercase words plus character 2/3-grams (CJK-friendly)."""
        # NOTE(review): the checked-in source read `replace("", " ")` three
        # times — the punctuation characters were garbled away (the file was
        # flagged for ambiguous Unicode). Replacing the empty string would
        # interleave a space between every character, which cannot have been
        # intended; assume the usual CJK separators here — confirm against
        # the original file.
        normalized = text.replace("，", " ").replace("。", " ").replace("、", " ")
        tokens = [part for part in (piece.strip().lower() for piece in normalized.split()) if part]
        # Character n-grams cover unsegmented (e.g. Chinese) text; lowercased
        # so document and query grams agree with the word tokens above.
        lowered = text.lower()
        for size in (2, 3):
            for start in range(max(len(lowered) - size + 1, 0)):
                gram = lowered[start : start + size].strip()
                if gram:
                    tokens.append(gram)
        return tokens

    def _point_id(self, entity_id: str) -> str:
        # Qdrant v1.14 accepts point IDs as UUID or unsigned int; derive a
        # stable UUIDv5 from the entity id so re-upserts hit the same point.
        return str(uuid.uuid5(uuid.NAMESPACE_URL, entity_id))