feat: 新增代码
This commit is contained in:
102
video_worker/app/task_manager.py
Normal file
102
video_worker/app/task_manager.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from app.schemas import GenerateRequest, TaskResultResponse, TaskStatusResponse
|
||||
from app.task_store import TaskRecord, TaskStore
|
||||
from app.utils.files import TASK_LOG_NAME, TASK_METADATA_NAME, ensure_dir, task_output_dir
|
||||
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self, store: TaskStore, output_root: Path):
|
||||
self.store = store
|
||||
self.output_root = output_root
|
||||
self.queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
async def create_task(self, req: GenerateRequest) -> TaskStatusResponse:
|
||||
task_id = uuid4().hex
|
||||
output_dir = task_output_dir(self.output_root, task_id)
|
||||
ensure_dir(output_dir)
|
||||
self.store.create_task(task_id=task_id, request_json=req.model_dump(), output_dir=str(output_dir.resolve()))
|
||||
await self.queue.put(task_id)
|
||||
return self.get_status(task_id)
|
||||
|
||||
def get_task_record(self, task_id: str) -> TaskRecord:
|
||||
task = self.store.get_task(task_id)
|
||||
if task is None:
|
||||
raise KeyError(f"Task not found: {task_id}")
|
||||
return task
|
||||
|
||||
def get_status(self, task_id: str) -> TaskStatusResponse:
|
||||
task = self.get_task_record(task_id)
|
||||
return TaskStatusResponse(
|
||||
task_id=task.task_id,
|
||||
status=task.status,
|
||||
backend=task.backend,
|
||||
model_name=task.model_name,
|
||||
progress=task.progress,
|
||||
created_at=datetime.fromisoformat(task.created_at),
|
||||
updated_at=datetime.fromisoformat(task.updated_at),
|
||||
)
|
||||
|
||||
def get_result(self, task_id: str) -> TaskResultResponse:
|
||||
task = self.get_task_record(task_id)
|
||||
return TaskResultResponse(
|
||||
task_id=task.task_id,
|
||||
status=task.status,
|
||||
video_path=task.video_path,
|
||||
first_frame_path=task.first_frame_path,
|
||||
metadata_path=task.metadata_path,
|
||||
log_path=task.log_path,
|
||||
error=task.error_message,
|
||||
)
|
||||
|
||||
def mark_running(self, task_id: str, backend: str, model_name: str) -> None:
|
||||
self.store.update_task(
|
||||
task_id,
|
||||
status="RUNNING",
|
||||
backend=backend,
|
||||
model_name=model_name,
|
||||
progress=0.1,
|
||||
started_at=datetime.utcnow().isoformat(),
|
||||
)
|
||||
|
||||
def mark_progress(self, task_id: str, progress: float) -> None:
|
||||
self.store.update_task(task_id, progress=max(0.0, min(1.0, progress)))
|
||||
|
||||
def mark_succeeded(
|
||||
self,
|
||||
task_id: str,
|
||||
video_path: str,
|
||||
first_frame_path: str,
|
||||
metadata_path: str,
|
||||
log_path: str,
|
||||
) -> None:
|
||||
self.store.update_task(
|
||||
task_id,
|
||||
status="SUCCEEDED",
|
||||
progress=1.0,
|
||||
video_path=video_path,
|
||||
first_frame_path=first_frame_path,
|
||||
metadata_path=metadata_path,
|
||||
log_path=log_path,
|
||||
finished_at=datetime.utcnow().isoformat(),
|
||||
)
|
||||
|
||||
def mark_failed(self, task_id: str, error_message: str, log_path: str | None = None) -> None:
|
||||
updates = {
|
||||
"status": "FAILED",
|
||||
"progress": 1.0,
|
||||
"error_message": error_message,
|
||||
"finished_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
if log_path is not None:
|
||||
updates["log_path"] = log_path
|
||||
self.store.update_task(task_id, **updates)
|
||||
|
||||
def build_metadata_path(self, task: TaskRecord) -> Path:
|
||||
return Path(task.output_dir) / TASK_METADATA_NAME
|
||||
|
||||
def build_log_path(self, task: TaskRecord) -> Path:
|
||||
return Path(task.output_dir) / TASK_LOG_NAME
|
||||
Reference in New Issue
Block a user