feat: 新增代码
This commit is contained in:
93
video_worker/app/gpu_worker.py
Normal file
93
video_worker/app/gpu_worker.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from app.model_router import ModelRouter
|
||||
from app.task_manager import TaskManager
|
||||
from app.utils.files import write_json
|
||||
from app.utils.logger import build_logger
|
||||
|
||||
|
||||
class GPUWorker:
|
||||
def __init__(self, task_manager: TaskManager, router: ModelRouter, log_level: str = "INFO"):
|
||||
self.task_manager = task_manager
|
||||
self.router = router
|
||||
self.log_level = log_level
|
||||
self._runner: asyncio.Task | None = None
|
||||
self._stopped = asyncio.Event()
|
||||
self._stopped.clear()
|
||||
self.logger = build_logger("gpu_worker", log_level=log_level)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._runner and not self._runner.done():
|
||||
return
|
||||
self._runner = asyncio.create_task(self._run_loop(), name="gpu-worker-loop")
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._stopped.set()
|
||||
if self._runner:
|
||||
self._runner.cancel()
|
||||
try:
|
||||
await self._runner
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _run_loop(self) -> None:
|
||||
while not self._stopped.is_set():
|
||||
task_id = await self.task_manager.queue.get()
|
||||
try:
|
||||
await self._process(task_id)
|
||||
finally:
|
||||
self.task_manager.queue.task_done()
|
||||
|
||||
async def _process(self, task_id: str) -> None:
|
||||
task = self.task_manager.get_task_record(task_id)
|
||||
req = task.request_json
|
||||
backend = self.router.route(req["quality_mode"])
|
||||
|
||||
log_path = self.task_manager.build_log_path(task)
|
||||
task_logger = build_logger(f"task.{task_id}", log_level=self.log_level, log_file=log_path)
|
||||
|
||||
try:
|
||||
self.task_manager.mark_running(task_id, backend.backend_name, backend.model_name)
|
||||
task_logger.info("Task started with backend=%s model=%s", backend.backend_name, backend.model_name)
|
||||
|
||||
await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.3)
|
||||
result = await asyncio.to_thread(backend.generate, task_id, req, task.output_dir)
|
||||
await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.8)
|
||||
|
||||
metadata_path = self.task_manager.build_metadata_path(task)
|
||||
current = self.task_manager.get_task_record(task_id)
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
metadata = {
|
||||
"task_id": task.task_id,
|
||||
"backend": backend.backend_name,
|
||||
"model_name": backend.model_name,
|
||||
"prompt": req.get("prompt"),
|
||||
"negative_prompt": req.get("negative_prompt"),
|
||||
"seed": req.get("seed"),
|
||||
"width": req.get("width"),
|
||||
"height": req.get("height"),
|
||||
"fps": req.get("fps"),
|
||||
"steps": req.get("steps"),
|
||||
"duration_sec": req.get("duration_sec"),
|
||||
"status": "SUCCEEDED",
|
||||
"created_at": task.created_at,
|
||||
"started_at": current.started_at,
|
||||
"finished_at": finished_at,
|
||||
"video_path": result["video_path"],
|
||||
}
|
||||
await asyncio.to_thread(write_json, metadata_path, metadata)
|
||||
|
||||
self.task_manager.mark_succeeded(
|
||||
task_id=task_id,
|
||||
video_path=result["video_path"],
|
||||
first_frame_path=result["first_frame_path"],
|
||||
metadata_path=str(Path(metadata_path).resolve()),
|
||||
log_path=str(Path(log_path).resolve()),
|
||||
)
|
||||
task_logger.info("Task succeeded: %s", json.dumps(result, ensure_ascii=False))
|
||||
except Exception as exc:
|
||||
task_logger.exception("Task failed")
|
||||
self.task_manager.mark_failed(task_id, str(exc), log_path=str(Path(log_path).resolve()))
|
||||
Reference in New Issue
Block a user