94 lines
3.8 KiB
Python
94 lines
3.8 KiB
Python
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from app.model_router import ModelRouter
|
|
from app.task_manager import TaskManager
|
|
from app.utils.files import write_json
|
|
from app.utils.logger import build_logger
|
|
|
|
|
|
class GPUWorker:
    """Background consumer that runs queued generation tasks one at a time.

    A single asyncio task (named ``gpu-worker-loop``) takes task ids from
    ``task_manager.queue``. For each task it:

    1. routes the request to a backend via ``router.route(quality_mode)``;
    2. runs the blocking ``backend.generate`` call in a worker thread;
    3. writes a metadata JSON file;
    4. records the outcome (running / progress / succeeded / failed) on the
       task manager.
    """

    def __init__(self, task_manager: TaskManager, router: ModelRouter, log_level: str = "INFO"):
        """Create an idle worker; call :meth:`start` to begin consuming.

        Args:
            task_manager: Provides the task-id queue and task records, and
                receives every task state transition.
            router: Maps a request's ``quality_mode`` to a backend object that
                exposes ``backend_name``, ``model_name`` and ``generate``.
            log_level: Level passed to ``build_logger`` for the worker logger
                and for every per-task logger.
        """
        self.task_manager = task_manager
        self.router = router
        self.log_level = log_level
        self._runner: asyncio.Task | None = None
        # Set by stop(); cleared again by start() so a stopped worker can be restarted.
        self._stopped = asyncio.Event()
        self.logger = build_logger("gpu_worker", log_level=log_level)

    async def start(self) -> None:
        """Launch the consume loop; no-op if it is already running."""
        if self._runner is not None and not self._runner.done():
            return
        # Without this reset, start() after stop() would spawn a loop that
        # exits immediately because the stop flag is still set.
        self._stopped.clear()
        self._runner = asyncio.create_task(self._run_loop(), name="gpu-worker-loop")

    async def stop(self) -> None:
        """Cancel the consume loop and wait for it to finish.

        The loop is cancelled, not drained. A task in flight is interrupted at
        its current ``await``. A ``backend.generate`` call already running via
        ``asyncio.to_thread`` keeps running in its thread, because threads
        cannot be cancelled.
        """
        self._stopped.set()
        runner = self._runner
        if runner is None:
            return
        runner.cancel()
        try:
            await runner
        except asyncio.CancelledError:
            pass
        except Exception:
            # The loop had already died with an error; report it rather than
            # raising it out of a shutdown path.
            self.logger.exception("GPU worker loop had exited with an error")

    async def _run_loop(self) -> None:
        """Consume task ids from the queue until stopped or cancelled."""
        while not self._stopped.is_set():
            task_id = await self.task_manager.queue.get()
            try:
                await self._process(task_id)
            except Exception:
                # _process marks the task failed for errors inside its main try
                # block. Anything that still escapes (e.g. the task record cannot
                # be loaded) is logged here so one bad task cannot end the loop.
                self.logger.exception("Unhandled error while processing task %s", task_id)
            finally:
                self.task_manager.queue.task_done()

    async def _process(self, task_id: str) -> None:
        """Run one task end to end and record its outcome on the task manager.

        Any ``Exception`` raised once the per-task logger exists marks the task
        failed with ``str(exc)``. This includes a missing or unroutable
        ``quality_mode``. Errors while loading the task record or building its
        log path/logger propagate to the caller.

        ``asyncio.CancelledError`` is not an ``Exception``, so a cancelled task
        (e.g. via stop()) is not marked failed here.
        """
        task = self.task_manager.get_task_record(task_id)
        log_path = self.task_manager.build_log_path(task)
        # NOTE(review): a new logger with a log_file is built for every task. If
        # build_logger attaches a file handler, confirm it is closed somewhere;
        # otherwise one file handle leaks per processed task.
        task_logger = build_logger(f"task.{task_id}", log_level=self.log_level, log_file=log_path)
        resolved_log_path = str(Path(log_path).resolve())

        try:
            req = task.request_json
            backend = self.router.route(req["quality_mode"])

            # NOTE(review): mark_running / mark_succeeded / mark_failed run on the
            # event loop, while mark_progress runs via to_thread. If these do
            # blocking I/O, consider making them consistent.
            self.task_manager.mark_running(task_id, backend.backend_name, backend.model_name)
            task_logger.info("Task started with backend=%s model=%s", backend.backend_name, backend.model_name)

            await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.3)
            result = await asyncio.to_thread(backend.generate, task_id, req, task.output_dir)
            await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.8)

            # Read the required outputs first, so a malformed result fails the
            # task before a "SUCCEEDED" metadata file is written to disk.
            video_path = result["video_path"]
            first_frame_path = result["first_frame_path"]

            metadata_path = self.task_manager.build_metadata_path(task)
            # Re-read the record to pick up started_at as set by mark_running.
            current = self.task_manager.get_task_record(task_id)
            metadata = self._build_metadata(task, current, backend, req, video_path)
            await asyncio.to_thread(write_json, metadata_path, metadata)

            self.task_manager.mark_succeeded(
                task_id=task_id,
                video_path=video_path,
                first_frame_path=first_frame_path,
                metadata_path=str(Path(metadata_path).resolve()),
                log_path=resolved_log_path,
            )
            # default=str: a non-JSON value in `result` must not raise here, after
            # mark_succeeded, and flip the task to FAILED in the handler below.
            task_logger.info("Task succeeded: %s", json.dumps(result, ensure_ascii=False, default=str))
        except Exception as exc:
            task_logger.exception("Task failed")
            self.task_manager.mark_failed(task_id, str(exc), log_path=resolved_log_path)

    @staticmethod
    def _build_metadata(task, current, backend, req, video_path) -> dict:
        """Assemble the metadata document recorded for a succeeded task.

        Args:
            task: Task record as loaded at the start of processing (supplies
                ``task_id`` and ``created_at``).
            current: Task record re-read after generation (supplies ``started_at``).
            backend: Backend that ran the task (``backend_name``, ``model_name``).
            req: The task's request mapping; optional fields come back as None.
            video_path: Value of ``result["video_path"]`` from the backend.
        """
        return {
            "task_id": task.task_id,
            "backend": backend.backend_name,
            "model_name": backend.model_name,
            "prompt": req.get("prompt"),
            "negative_prompt": req.get("negative_prompt"),
            "seed": req.get("seed"),
            "width": req.get("width"),
            "height": req.get("height"),
            "fps": req.get("fps"),
            "steps": req.get("steps"),
            "duration_sec": req.get("duration_sec"),
            "status": "SUCCEEDED",
            "created_at": task.created_at,
            "started_at": current.started_at,
            "finished_at": datetime.now(timezone.utc).isoformat(),
            "video_path": video_path,
        }
|