Files
AI_A4000/video_worker/app/task_manager.py
2026-04-07 00:37:39 +08:00

103 lines
3.6 KiB
Python

import asyncio
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

from app.schemas import GenerateRequest, TaskResultResponse, TaskStatusResponse
from app.task_store import TaskRecord, TaskStore
from app.utils.files import TASK_LOG_NAME, TASK_METADATA_NAME, ensure_dir, task_output_dir
class TaskManager:
    """Coordinate generation tasks between the API layer and ``TaskStore``.

    Task state is persisted through ``self.store``; newly created task ids are
    pushed onto an in-memory ``asyncio.Queue``. The consumer of that queue is
    not part of this class (presumably a worker loop elsewhere — verify).
    """

    def __init__(self, store: TaskStore, output_root: Path):
        self.store = store
        # Root directory under which each task gets its own output folder.
        self.output_root = output_root
        # Ids of tasks created via ``create_task``, in submission order.
        self.queue: asyncio.Queue[str] = asyncio.Queue()

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Replaces the deprecated ``datetime.utcnow().isoformat()`` while
        producing exactly the same text (no ``+00:00`` offset suffix), so the
        format of stored timestamps is unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    async def create_task(self, req: GenerateRequest) -> TaskStatusResponse:
        """Register a new task, enqueue its id, and return its initial status.

        A random hex id is generated and its output directory is created. The
        record (request payload plus absolute output directory) is written to
        the store *before* the id is put on ``self.queue``, so the record
        already exists when a consumer dequeues the id.
        """
        task_id = uuid4().hex
        output_dir = task_output_dir(self.output_root, task_id)
        ensure_dir(output_dir)
        self.store.create_task(
            task_id=task_id,
            request_json=req.model_dump(),
            output_dir=str(output_dir.resolve()),
        )
        await self.queue.put(task_id)
        return self.get_status(task_id)

    def get_task_record(self, task_id: str) -> TaskRecord:
        """Return the stored record for ``task_id``.

        Raises:
            KeyError: if the store has no task with that id.
        """
        task = self.store.get_task(task_id)
        if task is None:
            raise KeyError(f"Task not found: {task_id}")
        return task

    def get_status(self, task_id: str) -> TaskStatusResponse:
        """Build the status view of a task.

        ``created_at`` / ``updated_at`` are parsed with
        ``datetime.fromisoformat``; this assumes the store keeps them as
        ISO-8601 strings (not visible here — confirm in ``TaskStore``).

        Raises:
            KeyError: if the task does not exist.
        """
        task = self.get_task_record(task_id)
        return TaskStatusResponse(
            task_id=task.task_id,
            status=task.status,
            backend=task.backend,
            model_name=task.model_name,
            progress=task.progress,
            created_at=datetime.fromisoformat(task.created_at),
            updated_at=datetime.fromisoformat(task.updated_at),
        )

    def get_result(self, task_id: str) -> TaskResultResponse:
        """Build the result view (artifact paths and error) of a task.

        Raises:
            KeyError: if the task does not exist.
        """
        task = self.get_task_record(task_id)
        return TaskResultResponse(
            task_id=task.task_id,
            status=task.status,
            video_path=task.video_path,
            first_frame_path=task.first_frame_path,
            metadata_path=task.metadata_path,
            log_path=task.log_path,
            error=task.error_message,
        )

    def mark_running(self, task_id: str, backend: str, model_name: str) -> None:
        """Set the task to ``RUNNING`` with the chosen backend and model.

        Progress is set to 0.1 and ``started_at`` to the current UTC time.
        """
        self.store.update_task(
            task_id,
            status="RUNNING",
            backend=backend,
            model_name=model_name,
            progress=0.1,
            started_at=self._utc_now_iso(),
        )

    def mark_progress(self, task_id: str, progress: float) -> None:
        """Store ``progress`` for the task, clamped to the range [0.0, 1.0]."""
        self.store.update_task(task_id, progress=max(0.0, min(1.0, progress)))

    def mark_succeeded(
        self,
        task_id: str,
        video_path: str,
        first_frame_path: str,
        metadata_path: str,
        log_path: str,
    ) -> None:
        """Set the task to ``SUCCEEDED`` and record its artifact paths.

        Progress is set to 1.0 and ``finished_at`` to the current UTC time.
        """
        self.store.update_task(
            task_id,
            status="SUCCEEDED",
            progress=1.0,
            video_path=video_path,
            first_frame_path=first_frame_path,
            metadata_path=metadata_path,
            log_path=log_path,
            finished_at=self._utc_now_iso(),
        )

    def mark_failed(self, task_id: str, error_message: str, log_path: str | None = None) -> None:
        """Set the task to ``FAILED`` with an error message.

        Progress is set to 1.0 (the task is finished, not successful) and
        ``finished_at`` to the current UTC time. ``log_path`` is written only
        when given, so an existing stored value is not overwritten with None.
        """
        updates: dict[str, object] = {
            "status": "FAILED",
            "progress": 1.0,
            "error_message": error_message,
            "finished_at": self._utc_now_iso(),
        }
        if log_path is not None:
            updates["log_path"] = log_path
        self.store.update_task(task_id, **updates)

    def build_metadata_path(self, task: TaskRecord) -> Path:
        """Return the metadata file path inside the task's output directory.

        Only builds the path; the file is not created or checked.
        """
        return Path(task.output_dir) / TASK_METADATA_NAME

    def build_log_path(self, task: TaskRecord) -> Path:
        """Return the log file path inside the task's output directory.

        Only builds the path; the file is not created or checked.
        """
        return Path(task.output_dir) / TASK_LOG_NAME