Files
Airtep/gig-poc/apps/api/app/services/rag/lightrag_adapter.py
2026-04-01 14:19:25 +08:00

191 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import json
import math
import uuid
import zlib
from collections import defaultdict

from qdrant_client import QdrantClient, models

from app.core.config import Settings
from app.core.logging import logger
from app.domain.schemas import JobCard, QueryFilters, WorkerCard
from app.services.llm_client import LLMClient
class LightRAGAdapter:
    """Adapter that indexes job/worker cards in Qdrant and answers filtered
    similarity queries, with a local JSON skill graph for query expansion."""

    def __init__(self, settings: Settings):
        """Wire up the Qdrant client, the LLM client and the skill graph."""
        self.settings = settings
        self.client = QdrantClient(url=settings.qdrant_url)
        self.llm_client = LLMClient(settings)
        self.skill_graph = self._load_skill_graph()
        # Filled in lazily by ensure_ready(); None until the collection is checked.
        self.collection_vector_size: int | None = None
def ensure_ready(self) -> None:
collections = {item.name for item in self.client.get_collections().collections}
expected_size = self._configured_vector_size()
if self.settings.qdrant_collection not in collections:
self.client.create_collection(
collection_name=self.settings.qdrant_collection,
vectors_config=models.VectorParams(size=expected_size, distance=models.Distance.COSINE),
)
self.collection_vector_size = expected_size
return
info = self.client.get_collection(self.settings.qdrant_collection)
configured_size = info.config.params.vectors.size
self.collection_vector_size = int(configured_size)
if self.collection_vector_size != expected_size:
logger.warning(
"qdrant vector size mismatch, collection=%s expected=%s actual=%s; using actual size",
self.settings.qdrant_collection,
expected_size,
self.collection_vector_size,
)
def health(self) -> str:
self.ensure_ready()
self.client.get_collection(self.settings.qdrant_collection)
return "ok"
def upsert_job(self, job: JobCard) -> None:
self.ensure_ready()
payload = {
"entity_type": "job",
"entity_id": job.job_id,
"city": job.city,
"region": job.region,
"category": job.category,
"skills": job.skills,
"tags": job.tags,
"document": self._serialize_job(job),
}
self.client.upsert(
collection_name=self.settings.qdrant_collection,
points=[
models.PointStruct(
id=self._point_id(job.job_id),
vector=self._vectorize(payload["document"]),
payload=payload,
)
],
)
def upsert_worker(self, worker: WorkerCard) -> None:
self.ensure_ready()
payload = {
"entity_type": "worker",
"entity_id": worker.worker_id,
"city": worker.cities[0] if worker.cities else "",
"region": worker.regions[0] if worker.regions else "",
"category": worker.experience_tags[0] if worker.experience_tags else "",
"skills": [item.name for item in worker.skills],
"tags": worker.experience_tags,
"document": self._serialize_worker(worker),
}
self.client.upsert(
collection_name=self.settings.qdrant_collection,
points=[
models.PointStruct(
id=self._point_id(worker.worker_id),
vector=self._vectorize(payload["document"]),
payload=payload,
)
],
)
def search(self, query_text: str, filters: QueryFilters, limit: int) -> list[str]:
self.ensure_ready()
must = [models.FieldCondition(key="entity_type", match=models.MatchValue(value=filters.entity_type))]
if filters.city:
must.append(models.FieldCondition(key="city", match=models.MatchValue(value=filters.city)))
query_filter = models.Filter(must=must)
results = self.client.search(
collection_name=self.settings.qdrant_collection,
query_vector=self._vectorize(query_text),
query_filter=query_filter,
limit=limit,
with_payload=True,
)
ids = []
for point in results:
payload = point.payload or {}
if filters.region and payload.get("region") != filters.region:
continue
ids.append(str(payload.get("entity_id", point.id)))
return ids
def expand_skills(self, skills: list[str]) -> set[str]:
expanded = set(skills)
for skill in skills:
expanded.update(self.skill_graph.get(skill, []))
return expanded
def _load_skill_graph(self) -> dict[str, set[str]]:
relations_path = self.settings.sample_data_dir / "skill_relations.json"
if not relations_path.exists():
return defaultdict(set)
data = json.loads(relations_path.read_text(encoding="utf-8"))
graph: dict[str, set[str]] = defaultdict(set)
for source, targets in data.items():
graph[source].update(targets)
for target in targets:
graph[target].add(source)
return graph
def _serialize_job(self, job: JobCard) -> str:
return " ".join([job.title, job.category, job.city, job.region, *job.skills, *job.tags, job.description])
def _serialize_worker(self, worker: WorkerCard) -> str:
return " ".join(
[worker.name, *worker.cities, *worker.regions, *[item.name for item in worker.skills], *worker.experience_tags, worker.description]
)
def _vectorize(self, text: str) -> list[float]:
if self.settings.embedding_enabled and self.settings.embedding_backend == "openai_compatible":
try:
embedding = self.llm_client.embedding(text)
if embedding:
return self._normalize_embedding(embedding)
except Exception:
logger.exception("embedding request failed, fallback to hash vector")
target_size = self._active_vector_size()
vector = [0.0 for _ in range(target_size)]
tokens = self._tokenize(text)
for token in tokens:
index = hash(token) % target_size
vector[index] += 1.0
norm = math.sqrt(sum(item * item for item in vector)) or 1.0
return [item / norm for item in vector]
def _normalize_embedding(self, embedding: list[float]) -> list[float]:
target_size = self._active_vector_size()
vector = embedding[:target_size]
if len(vector) < target_size:
vector.extend([0.0] * (target_size - len(vector)))
norm = math.sqrt(sum(item * item for item in vector)) or 1.0
return [item / norm for item in vector]
def _active_vector_size(self) -> int:
if self.collection_vector_size:
return self.collection_vector_size
return self._configured_vector_size()
def _configured_vector_size(self) -> int:
if self.settings.embedding_enabled and self.settings.embedding_backend == "openai_compatible":
return self.settings.embedding_vector_size
return self.settings.vector_size
    def _tokenize(self, text: str) -> list[str]:
        """Split `text` into lowercased whitespace tokens plus character 2- and 3-grams.

        NOTE(review): the three `.replace("", " ")` calls look like mojibake —
        the file-view warning about ambiguous Unicode suggests the original
        literals were punctuation characters (plausibly CJK, e.g. full-width
        comma/period) that were lost in transit. As written, replacing the
        empty string inserts a space between every character, so `cleaned`
        degenerates to single characters. Confirm against the original file
        before relying on word-level tokens here.
        """
        cleaned = [part.strip().lower() for part in text.replace("", " ").replace("", " ").replace("", " ").split()]
        tokens = [part for part in cleaned if part]
        # Character n-grams — presumably to support unsegmented (CJK) text; confirm intent.
        for size in (2, 3):
            for index in range(max(len(text) - size + 1, 0)):
                chunk = text[index : index + size].strip()
                if chunk:
                    tokens.append(chunk)
        return tokens
def _point_id(self, entity_id: str) -> str:
# Qdrant v1.14 accepts point IDs as UUID or unsigned int.
return str(uuid.uuid5(uuid.NAMESPACE_URL, entity_id))