feat: add LightRAG adapter service (Qdrant-backed RAG for job/worker matching)
This commit is contained in:
143
gig-poc/apps/api/app/services/rag/lightrag_adapter.py
Normal file
143
gig-poc/apps/api/app/services/rag/lightrag_adapter.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from collections import defaultdict
|
||||
|
||||
from qdrant_client import QdrantClient, models
|
||||
|
||||
from app.core.config import Settings
|
||||
from app.core.logging import logger
|
||||
from app.domain.schemas import JobCard, QueryFilters, WorkerCard
|
||||
|
||||
|
||||
class LightRAGAdapter:
|
||||
def __init__(self, settings: Settings):
    """Bind app settings, connect to Qdrant, and preload the skill graph.

    Args:
        settings: application configuration (Qdrant URL/collection,
            vector size, sample-data directory).
    """
    self.settings = settings
    self.client = QdrantClient(url=settings.qdrant_url)
    # Loaded once up front; expand_skills() reads it on every query.
    self.skill_graph = self._load_skill_graph()
|
||||
|
||||
def ensure_ready(self) -> None:
    """Create the configured Qdrant collection if it does not exist yet.

    Idempotent: safe to call before every operation.
    """
    existing = {collection.name for collection in self.client.get_collections().collections}
    if self.settings.qdrant_collection in existing:
        return
    self.client.create_collection(
        collection_name=self.settings.qdrant_collection,
        vectors_config=models.VectorParams(
            size=self.settings.vector_size,
            distance=models.Distance.COSINE,
        ),
    )
|
||||
|
||||
def health(self) -> str:
    """Report "ok" when the collection exists and is reachable.

    Raises whatever the Qdrant client raises on failure; callers treat
    an exception as unhealthy.
    """
    self.ensure_ready()
    # get_collection raises if the collection is missing/unreachable.
    self.client.get_collection(self.settings.qdrant_collection)
    return "ok"
|
||||
|
||||
def upsert_job(self, job: JobCard) -> None:
    """Index (insert or overwrite) one job card as a single Qdrant point.

    The point id is the job id, so re-upserting the same job replaces
    the previous point.
    """
    self.ensure_ready()
    document = self._serialize_job(job)
    payload = {
        "entity_type": "job",
        "entity_id": job.job_id,
        "city": job.city,
        "region": job.region,
        "category": job.category,
        "skills": job.skills,
        "tags": job.tags,
        "document": document,
    }
    point = models.PointStruct(
        id=job.job_id,
        vector=self._vectorize(document),
        payload=payload,
    )
    self.client.upsert(
        collection_name=self.settings.qdrant_collection,
        points=[point],
    )
|
||||
|
||||
def upsert_worker(self, worker: WorkerCard) -> None:
    """Index (insert or overwrite) one worker card as a single Qdrant point.

    Only the FIRST city/region/experience tag is stored in the flat
    payload fields used for filtering; the full lists still contribute
    to the serialized document text.
    """
    self.ensure_ready()
    document = self._serialize_worker(worker)

    def first_or_empty(values: list[str]) -> str:
        # Payload filter fields are scalar; fall back to "" for empty lists.
        return values[0] if values else ""

    payload = {
        "entity_type": "worker",
        "entity_id": worker.worker_id,
        "city": first_or_empty(worker.cities),
        "region": first_or_empty(worker.regions),
        "category": first_or_empty(worker.experience_tags),
        "skills": [skill.name for skill in worker.skills],
        "tags": worker.experience_tags,
        "document": document,
    }
    point = models.PointStruct(
        id=worker.worker_id,
        vector=self._vectorize(document),
        payload=payload,
    )
    self.client.upsert(
        collection_name=self.settings.qdrant_collection,
        points=[point],
    )
|
||||
|
||||
def search(self, query_text: str, filters: QueryFilters, limit: int) -> list[str]:
    """Vector-search the collection and return matching entity ids.

    Args:
        query_text: free text to vectorize and match against documents.
        filters: entity type plus optional city/region constraints.
        limit: maximum number of ids to return.

    Returns:
        Entity ids (as strings) of the best-matching points.

    Fix: the region constraint is now part of the server-side Qdrant
    filter, like the city constraint. The original fetched ``limit``
    points first and dropped non-matching regions afterwards, which
    silently returned fewer than ``limit`` results whenever any of the
    top points belonged to another region.
    """
    self.ensure_ready()
    must = [models.FieldCondition(key="entity_type", match=models.MatchValue(value=filters.entity_type))]
    if filters.city:
        must.append(models.FieldCondition(key="city", match=models.MatchValue(value=filters.city)))
    if filters.region:
        must.append(models.FieldCondition(key="region", match=models.MatchValue(value=filters.region)))
    results = self.client.search(
        collection_name=self.settings.qdrant_collection,
        query_vector=self._vectorize(query_text),
        query_filter=models.Filter(must=must),
        limit=limit,
        with_payload=True,
    )
    ids: list[str] = []
    for point in results:
        payload = point.payload or {}
        # Fall back to the point id for legacy points missing entity_id.
        ids.append(str(payload.get("entity_id", point.id)))
    return ids
|
||||
|
||||
def expand_skills(self, skills: list[str]) -> set[str]:
|
||||
expanded = set(skills)
|
||||
for skill in skills:
|
||||
expanded.update(self.skill_graph.get(skill, []))
|
||||
return expanded
|
||||
|
||||
def _load_skill_graph(self) -> dict[str, set[str]]:
|
||||
relations_path = self.settings.sample_data_dir / "skill_relations.json"
|
||||
if not relations_path.exists():
|
||||
return defaultdict(set)
|
||||
data = json.loads(relations_path.read_text(encoding="utf-8"))
|
||||
graph: dict[str, set[str]] = defaultdict(set)
|
||||
for source, targets in data.items():
|
||||
graph[source].update(targets)
|
||||
for target in targets:
|
||||
graph[target].add(source)
|
||||
return graph
|
||||
|
||||
def _serialize_job(self, job: JobCard) -> str:
|
||||
return " ".join([job.title, job.category, job.city, job.region, *job.skills, *job.tags, job.description])
|
||||
|
||||
def _serialize_worker(self, worker: WorkerCard) -> str:
|
||||
return " ".join(
|
||||
[worker.name, *worker.cities, *worker.regions, *[item.name for item in worker.skills], *worker.experience_tags, worker.description]
|
||||
)
|
||||
|
||||
def _vectorize(self, text: str) -> list[float]:
|
||||
vector = [0.0 for _ in range(self.settings.vector_size)]
|
||||
tokens = self._tokenize(text)
|
||||
for token in tokens:
|
||||
index = hash(token) % self.settings.vector_size
|
||||
vector[index] += 1.0
|
||||
norm = math.sqrt(sum(item * item for item in vector)) or 1.0
|
||||
return [item / norm for item in vector]
|
||||
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
cleaned = [part.strip().lower() for part in text.replace(",", " ").replace("、", " ").replace("。", " ").split()]
|
||||
tokens = [part for part in cleaned if part]
|
||||
for size in (2, 3):
|
||||
for index in range(max(len(text) - size + 1, 0)):
|
||||
chunk = text[index : index + size].strip()
|
||||
if chunk:
|
||||
tokens.append(chunk)
|
||||
return tokens
|
||||
Reference in New Issue
Block a user