# Source listing metadata carried over from the original paste: 103 lines, 3.6 KiB, Python.
import asyncio
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

from app.schemas import GenerateRequest, TaskResultResponse, TaskStatusResponse
from app.task_store import TaskRecord, TaskStore
from app.utils.files import TASK_LOG_NAME, TASK_METADATA_NAME, ensure_dir, task_output_dir


class TaskManager:
    """Coordinate generation tasks between a ``TaskStore`` and an asyncio queue.

    ``create_task`` persists a new task through ``store`` and pushes its id onto
    ``queue``; the ``mark_*`` methods write lifecycle transitions back into the
    store. Nothing in this class reads from ``queue`` — the consumer lives
    elsewhere.
    """

    def __init__(self, store: TaskStore, output_root: Path):
        """Create a manager.

        Args:
            store: Persistence backend used to create, read and update task records.
            output_root: Root directory passed to ``task_output_dir`` to derive
                each task's output directory.
        """
        self.store = store
        self.output_root = output_root
        # Ids of created tasks, in creation order; only ``create_task`` writes here.
        self.queue: asyncio.Queue[str] = asyncio.Queue()

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Yields exactly the same text as the deprecated (Python 3.12+)
        ``datetime.utcnow().isoformat()`` — no UTC offset suffix — so stored
        timestamps keep their existing format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    async def create_task(self, req: GenerateRequest) -> TaskStatusResponse:
        """Persist a new task, enqueue its id, and return its initial status.

        The task id is a fresh ``uuid4().hex``. ``ensure_dir`` is called on the
        task's output directory (presumably creating it) before the record is
        stored with the resolved absolute directory path.

        Args:
            req: The generation request; stored as ``req.model_dump()``.

        Returns:
            The status snapshot of the newly created task.
        """
        task_id = uuid4().hex
        output_dir = task_output_dir(self.output_root, task_id)
        ensure_dir(output_dir)
        self.store.create_task(
            task_id=task_id,
            request_json=req.model_dump(),
            output_dir=str(output_dir.resolve()),
        )
        # Enqueue only after the record is stored, presumably so whoever
        # dequeues the id can already look the task up.
        await self.queue.put(task_id)
        return self.get_status(task_id)

    def get_task_record(self, task_id: str) -> TaskRecord:
        """Return the stored record for ``task_id``.

        Raises:
            KeyError: If the store has no task with this id.
        """
        task = self.store.get_task(task_id)
        if task is None:
            raise KeyError(f"Task not found: {task_id}")
        return task

    def get_status(self, task_id: str) -> TaskStatusResponse:
        """Build a status snapshot for ``task_id``.

        ``created_at``/``updated_at`` are parsed with ``datetime.fromisoformat``,
        so the store is expected to hold them as ISO-8601 strings.

        Raises:
            KeyError: If the task does not exist.
        """
        task = self.get_task_record(task_id)
        return TaskStatusResponse(
            task_id=task.task_id,
            status=task.status,
            backend=task.backend,
            model_name=task.model_name,
            progress=task.progress,
            created_at=datetime.fromisoformat(task.created_at),
            updated_at=datetime.fromisoformat(task.updated_at),
        )

    def get_result(self, task_id: str) -> TaskResultResponse:
        """Build a result view (artifact paths and error) for ``task_id``.

        The record's ``error_message`` is exposed as the response's ``error``.

        Raises:
            KeyError: If the task does not exist.
        """
        task = self.get_task_record(task_id)
        return TaskResultResponse(
            task_id=task.task_id,
            status=task.status,
            video_path=task.video_path,
            first_frame_path=task.first_frame_path,
            metadata_path=task.metadata_path,
            log_path=task.log_path,
            error=task.error_message,
        )

    def mark_running(self, task_id: str, backend: str, model_name: str) -> None:
        """Record that ``task_id`` started on ``backend`` with ``model_name``.

        Sets status ``"RUNNING"``, progress ``0.1`` and ``started_at`` to the
        current UTC time.
        """
        self.store.update_task(
            task_id,
            status="RUNNING",
            backend=backend,
            model_name=model_name,
            # NOTE(review): 0.1 is presumably a "started" marker so clients see
            # non-zero progress once running — confirm intent.
            progress=0.1,
            started_at=self._utc_now_iso(),
        )

    def mark_progress(self, task_id: str, progress: float) -> None:
        """Store ``progress`` for ``task_id``, clamped to the range [0.0, 1.0]."""
        # NOTE(review): with this min/max form a NaN input ends up stored as 1.0.
        self.store.update_task(task_id, progress=max(0.0, min(1.0, progress)))

    def mark_succeeded(
        self,
        task_id: str,
        video_path: str,
        first_frame_path: str,
        metadata_path: str,
        log_path: str,
    ) -> None:
        """Record successful completion of ``task_id`` and its artifact paths.

        Sets status ``"SUCCEEDED"``, progress ``1.0``, the given paths and
        ``finished_at`` to the current UTC time.
        """
        self.store.update_task(
            task_id,
            status="SUCCEEDED",
            progress=1.0,
            video_path=video_path,
            first_frame_path=first_frame_path,
            metadata_path=metadata_path,
            log_path=log_path,
            finished_at=self._utc_now_iso(),
        )

    def mark_failed(self, task_id: str, error_message: str, log_path: str | None = None) -> None:
        """Record that ``task_id`` failed with ``error_message``.

        Sets status ``"FAILED"``, progress ``1.0`` and ``finished_at`` to the
        current UTC time. ``log_path`` is only sent to the store when given, so
        ``None`` is never written for it (presumably leaving any stored value
        untouched).
        """
        updates: dict[str, object] = {
            "status": "FAILED",
            "progress": 1.0,
            "error_message": error_message,
            "finished_at": self._utc_now_iso(),
        }
        if log_path is not None:
            updates["log_path"] = log_path
        self.store.update_task(task_id, **updates)

    def build_metadata_path(self, task: TaskRecord) -> Path:
        """Return ``TASK_METADATA_NAME`` joined onto the task's output directory.

        Pure path arithmetic; the filesystem is not touched.
        """
        return Path(task.output_dir) / TASK_METADATA_NAME

    def build_log_path(self, task: TaskRecord) -> Path:
        """Return ``TASK_LOG_NAME`` joined onto the task's output directory.

        Pure path arithmetic; the filesystem is not touched.
        """
        return Path(task.output_dir) / TASK_LOG_NAME
|