# Repository browser header captured with this file:
#   Path: AI_A4000/video_worker/app/gpu_worker.py
#   Timestamp: 2026-04-07 00:37:39 +08:00
#   Size: 94 lines, 3.8 KiB (Python)

import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path
from app.model_router import ModelRouter
from app.task_manager import TaskManager
from app.utils.files import write_json
from app.utils.logger import build_logger
class GPUWorker:
    """Background consumer that drains the task queue and runs video generation.

    Task ids are pulled one at a time from ``task_manager.queue``. Each task is
    routed to a backend via ``router.route(quality_mode)``, the backend's
    blocking ``generate`` call is offloaded to a worker thread, and the outcome
    is recorded through the ``TaskManager`` (running -> progress -> succeeded /
    failed), together with a JSON metadata file and a per-task log file.
    """

    def __init__(self, task_manager: TaskManager, router: ModelRouter, log_level: str = "INFO"):
        """Store collaborators; processing does not begin until ``start()``.

        Args:
            task_manager: Owner of the task queue and of the task records.
            router: Maps a request's ``quality_mode`` to a generation backend.
            log_level: Level passed to ``build_logger`` for the worker logger
                and for every per-task logger.
        """
        self.task_manager = task_manager
        self.router = router
        self.log_level = log_level
        # Handle of the background loop task; None until start() is called.
        self._runner: asyncio.Task | None = None
        # Set by stop() so the loop exits; cleared again by start().
        self._stopped = asyncio.Event()
        self.logger = build_logger("gpu_worker", log_level=log_level)

    async def start(self) -> None:
        """Launch the processing loop as a background task.

        Does nothing if the loop is already running. The stop flag is cleared
        first so that a previously stopped worker can be restarted. Otherwise
        the new loop would see the flag still set and exit immediately.
        """
        if self._runner and not self._runner.done():
            return
        self._stopped.clear()
        self._runner = asyncio.create_task(self._run_loop(), name="gpu-worker-loop")

    async def stop(self) -> None:
        """Signal the loop to stop, cancel it, and wait until it has finished.

        Cancellation is needed because the loop normally sits blocked in
        ``queue.get()`` and would never re-check the stop flag. A task that is
        mid-processing is interrupted too. ``_process`` does not handle
        ``CancelledError``, so that task's record is left in whatever state it
        had reached.
        NOTE(review): per asyncio docs, cancelling an ``asyncio.to_thread`` await
        does not stop the underlying thread, so an in-flight ``generate`` call
        may keep running in the background.
        """
        self._stopped.set()
        if self._runner:
            self._runner.cancel()
            try:
                await self._runner
            except asyncio.CancelledError:
                pass

    async def _run_loop(self) -> None:
        """Process queued task ids one at a time until ``stop()`` is called.

        Exceptions escaping ``_process`` are logged and swallowed, so a single
        bad task cannot kill the loop and silently stall everything queued
        behind it. ``asyncio.CancelledError`` is not an ``Exception`` subclass,
        so it still propagates; that is how ``stop()`` ends the loop.
        """
        while not self._stopped.is_set():
            task_id = await self.task_manager.queue.get()
            try:
                await self._process(task_id)
            except Exception:
                self.logger.exception("Unhandled error while processing task %s", task_id)
            finally:
                # Pair every successful get() with task_done(); presumably an
                # asyncio.Queue, whose join() relies on this bookkeeping.
                self.task_manager.queue.task_done()

    async def _process(self, task_id: str) -> None:
        """Run one task end to end and record its outcome.

        Steps: mark the task running, call the routed backend's ``generate`` in
        a worker thread, write a JSON metadata file, and mark the task
        succeeded. Once the per-task logger exists, any ``Exception`` is logged
        and recorded via ``mark_failed`` instead of propagating. This includes
        routing errors such as a missing or unknown ``quality_mode``.
        Failures before that point (looking up the task record, building the
        log path or logger) propagate to the caller.

        Args:
            task_id: Identifier of a task known to ``task_manager``.
        """
        task = self.task_manager.get_task_record(task_id)
        log_path = self.task_manager.build_log_path(task)
        task_logger = build_logger(f"task.{task_id}", log_level=self.log_level, log_file=log_path)
        try:
            req = task.request_json
            # Routing sits inside the try so a bad quality_mode fails this task
            # rather than escaping to the worker loop.
            backend = self.router.route(req["quality_mode"])
            self.task_manager.mark_running(task_id, backend.backend_name, backend.model_name)
            task_logger.info(
                "Task started with backend=%s model=%s", backend.backend_name, backend.model_name
            )
            await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.3)
            # generate() is a blocking call; run it off the event loop thread.
            result = await asyncio.to_thread(backend.generate, task_id, req, task.output_dir)
            await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.8)

            metadata_path = self.task_manager.build_metadata_path(task)
            # Re-read the record to obtain started_at (presumably populated by
            # mark_running above).
            current = self.task_manager.get_task_record(task_id)
            finished_at = datetime.now(timezone.utc).isoformat()
            metadata = {
                "task_id": task.task_id,
                "backend": backend.backend_name,
                "model_name": backend.model_name,
                "prompt": req.get("prompt"),
                "negative_prompt": req.get("negative_prompt"),
                "seed": req.get("seed"),
                "width": req.get("width"),
                "height": req.get("height"),
                "fps": req.get("fps"),
                "steps": req.get("steps"),
                "duration_sec": req.get("duration_sec"),
                "status": "SUCCEEDED",
                "created_at": task.created_at,
                "started_at": current.started_at,
                "finished_at": finished_at,
                "video_path": result["video_path"],
            }
            await asyncio.to_thread(write_json, metadata_path, metadata)
            self.task_manager.mark_succeeded(
                task_id=task_id,
                video_path=result["video_path"],
                first_frame_path=result["first_frame_path"],
                metadata_path=str(Path(metadata_path).resolve()),
                log_path=str(Path(log_path).resolve()),
            )
            task_logger.info("Task succeeded: %s", json.dumps(result, ensure_ascii=False))
        except Exception as exc:
            task_logger.exception("Task failed")
            self.task_manager.mark_failed(task_id, str(exc), log_path=str(Path(log_path).resolve()))
        finally:
            self._release_task_logger(task_logger)

    @staticmethod
    def _release_task_logger(task_logger) -> None:
        """Close and detach the handlers owned by a per-task logger (best effort).

        Each task gets its own logger name (``task.<task_id>``). If
        ``build_logger`` attaches handlers to it, such as a file handler for
        ``log_file``, they would otherwise stay open for the life of the
        process.
        NOTE(review): assumes ``build_logger`` returns a ``logging.Logger``
        whose handlers are not shared with other loggers; confirm against
        ``app.utils.logger``.
        """
        for handler in list(getattr(task_logger, "handlers", ())):
            handler.close()
            task_logger.removeHandler(handler)