# 106 lines · 3.6 KiB · Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from queue import Empty, Full, Queue
|
|
from threading import Event, Lock, Thread
|
|
from typing import Any
|
|
|
|
from app.core.config import Settings
|
|
from app.core.logging import logger
|
|
from app.db.session import SessionLocal
|
|
from app.domain.schemas import JobCard, WorkerCard
|
|
from app.services.ingest_service import IngestService
|
|
from app.utils.ids import generate_id
|
|
|
|
|
|
@dataclass
class QueueTask:
    """One unit of work held in the ingest queue."""

    # Identifier under which the task's status is tracked.
    task_id: str
    # Task type; the queue worker handles "job" and "worker".
    kind: str
    # Keyword arguments used to rebuild the card when the task is processed.
    payload: dict[str, Any]
|
|
|
|
|
|
class IngestQueue:
    """Bounded in-process queue that ingests job and worker cards on a background thread.

    Cards are serialised into ``QueueTask`` payloads by ``enqueue_job`` /
    ``enqueue_worker`` and consumed by a daemon thread started via ``start()``.
    Each task's lifecycle is recorded in ``_status`` as one of ``"queued"``,
    ``"processing"``, ``"done"`` or ``"failed"``.

    NOTE(review): entries for finished tasks are never removed from
    ``_status``, so the map grows with the number of tasks handled.
    """

    def __init__(self, settings: Settings):
        """Create an idle queue sized by ``settings.ingest_queue_max_size``.

        No thread is started here; call ``start()`` to begin processing.
        """
        self.settings = settings
        self.queue: Queue[QueueTask] = Queue(maxsize=settings.ingest_queue_max_size)
        self._stop_event = Event()
        self._thread: Thread | None = None
        # Guards _status, _processed and _failed, which are touched by both
        # the calling threads and the worker thread.
        self._lock = Lock()
        self._status: dict[str, str] = {}
        self._processed = 0
        self._failed = 0

    def start(self) -> None:
        """Start the worker thread if async ingest is enabled and none is running.

        Safe to call again after ``stop()``: the stop flag is reset so the new
        worker does not exit immediately.
        """
        if not self.settings.ingest_async_enabled:
            return
        if self._thread and self._thread.is_alive():
            return
        # stop() leaves the event set; without clearing it a restarted worker
        # would see the flag on its first loop check and terminate at once.
        self._stop_event.clear()
        self._thread = Thread(target=self._run, daemon=True, name="ingest-queue-worker")
        self._thread.start()
        logger.info("ingest queue worker started")

    def stop(self) -> None:
        """Ask the worker to stop and wait up to 3 seconds for it to finish.

        Tasks still sitting in the queue are left there with status
        ``"queued"``.
        """
        self._stop_event.set()
        worker = self._thread
        if worker and worker.is_alive():
            worker.join(timeout=3)
            if worker.is_alive():
                # join() timed out, most likely because a task is still being
                # ingested; the worker exits once that task completes.
                logger.warning("ingest queue worker did not stop within 3 seconds")

    def enqueue_job(self, card: JobCard) -> str:
        """Queue a job card for ingestion and return the new task id.

        Raises:
            RuntimeError: if the queue is full.
        """
        return self._enqueue("job", card.model_dump(mode="json"))

    def enqueue_worker(self, card: WorkerCard) -> str:
        """Queue a worker card for ingestion and return the new task id.

        Raises:
            RuntimeError: if the queue is full.
        """
        return self._enqueue("worker", card.model_dump(mode="json"))

    def task_status(self, task_id: str) -> str:
        """Return the recorded status for ``task_id``, or ``"not_found"``."""
        with self._lock:
            return self._status.get(task_id, "not_found")

    def stats(self) -> dict[str, int]:
        """Return the current queue depth and the processed / failed counters."""
        with self._lock:
            return {
                "queued": self.queue.qsize(),
                "processed": self._processed,
                "failed": self._failed,
            }

    def _enqueue(self, kind: str, payload: dict[str, Any]) -> str:
        """Wrap ``payload`` in a ``QueueTask`` and put it on the queue without blocking.

        Returns:
            The generated task id.

        Raises:
            RuntimeError: if the queue is full.
        """
        task_id = generate_id("queue")
        task = QueueTask(task_id=task_id, kind=kind, payload=payload)
        # Record "queued" before the put so the worker's later "processing"
        # update can never be overwritten by this one.
        with self._lock:
            self._status[task_id] = "queued"
        try:
            self.queue.put_nowait(task)
        except Full as exc:
            # The caller never receives this task_id, so keeping a status
            # entry for it would only leak memory while the queue is saturated.
            with self._lock:
                self._status.pop(task_id, None)
            raise RuntimeError("ingest queue is full") from exc
        return task_id

    def _run(self) -> None:
        """Worker loop: pull tasks until the stop event is set, ingesting each one."""
        while not self._stop_event.is_set():
            try:
                # Short timeout so the stop flag is re-checked regularly.
                task = self.queue.get(timeout=0.5)
            except Empty:
                continue
            try:
                with self._lock:
                    self._status[task.task_id] = "processing"
                self._ingest(task)
            except Exception:
                # Boundary of the worker thread: log and keep serving the queue.
                logger.exception("ingest queue task failed task_id=%s kind=%s", task.task_id, task.kind)
                with self._lock:
                    self._status[task.task_id] = "failed"
                    self._failed += 1
            else:
                with self._lock:
                    self._status[task.task_id] = "done"
                    self._processed += 1
            finally:
                self.queue.task_done()

    def _ingest(self, task: QueueTask) -> None:
        """Rebuild the card from ``task.payload`` and hand it to ``IngestService``.

        Raises:
            ValueError: if ``task.kind`` is neither ``"job"`` nor ``"worker"``.
        """
        with SessionLocal() as db:
            service = IngestService(db)
            if task.kind == "job":
                service.ingest_job(JobCard(**task.payload))
            elif task.kind == "worker":
                service.ingest_worker(WorkerCard(**task.payload))
            else:
                raise ValueError(f"unknown task kind {task.kind}")
|