"""SQLite-backed persistence for task records and their schema migrations."""

import json
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterator, Optional

# Values written to the ``tasks.status`` column.
STATUS_PENDING = "PENDING"
STATUS_RUNNING = "RUNNING"
STATUS_SUCCEEDED = "SUCCEEDED"
STATUS_FAILED = "FAILED"

# Matches the last migration step applied by ``TaskStore.migrate``.
SCHEMA_VERSION = 2

# Columns ``TaskStore.update_task`` may write. Column names cannot be bound as
# SQL parameters, so they are checked against this fixed set before being
# interpolated into the statement. ``task_id`` is the row key and is excluded.
_UPDATABLE_COLUMNS = frozenset(
    {
        "status",
        "backend",
        "model_name",
        "request_json",
        "output_dir",
        "progress",
        "error_message",
        "video_path",
        "first_frame_path",
        "metadata_path",
        "log_path",
        "created_at",
        "updated_at",
        "started_at",
        "finished_at",
    }
)


def utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass
class TaskRecord:
    """In-memory view of one row of the ``tasks`` table."""

    task_id: str
    status: str
    backend: Optional[str]
    model_name: Optional[str]
    # Decoded from the JSON text stored in the ``request_json`` column.
    request_json: Dict[str, Any]
    output_dir: str
    progress: float
    error_message: Optional[str]
    video_path: Optional[str]
    first_frame_path: Optional[str]
    metadata_path: Optional[str]
    log_path: Optional[str]
    # Timestamps are stored as text; this module writes them via utc_now_iso().
    created_at: str
    updated_at: str
    started_at: Optional[str]
    finished_at: Optional[str]


class TaskStore:
    """Read and write task rows in a SQLite database file."""

    def __init__(self, sqlite_path: Path):
        """Remember the database path and create its parent directory.

        Args:
            sqlite_path: Location of the SQLite database file.
        """
        self.sqlite_path = sqlite_path
        self.sqlite_path.parent.mkdir(parents=True, exist_ok=True)

    @contextmanager
    def conn(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection; commit on success and always close it.

        Rows are returned as ``sqlite3.Row`` so columns can be read by name.
        If the ``with`` body raises, ``commit()`` is skipped and the
        connection is closed without committing.
        """
        connection = sqlite3.connect(self.sqlite_path, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        try:
            yield connection
            connection.commit()
        finally:
            connection.close()

    def migrate(self) -> None:
        """Bring the database schema up to the latest version.

        Applied versions are recorded in ``schema_migrations``; a step runs
        only when its version is greater than the highest recorded one, so
        calling this repeatedly is safe.
        """
        with self.conn() as connection:
            connection.execute(
                """
                CREATE TABLE IF NOT EXISTS schema_migrations (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL
                )
                """
            )
            current = connection.execute(
                "SELECT MAX(version) as v FROM schema_migrations"
            ).fetchone()["v"]
            # MAX() over an empty table yields NULL, i.e. nothing applied yet.
            current_version = int(current or 0)

            if current_version < 1:
                connection.execute(
                    """
                    CREATE TABLE IF NOT EXISTS tasks (
                        task_id TEXT PRIMARY KEY,
                        status TEXT NOT NULL,
                        backend TEXT,
                        model_name TEXT,
                        request_json TEXT NOT NULL,
                        output_dir TEXT NOT NULL,
                        progress REAL NOT NULL DEFAULT 0,
                        error_message TEXT,
                        video_path TEXT,
                        first_frame_path TEXT,
                        metadata_path TEXT,
                        log_path TEXT,
                        created_at TEXT NOT NULL,
                        updated_at TEXT NOT NULL,
                        started_at TEXT,
                        finished_at TEXT
                    )
                    """
                )
                connection.execute(
                    "INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
                    (1, utc_now_iso()),
                )

            if current_version < 2:
                connection.execute(
                    "CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)"
                )
                connection.execute(
                    "CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at)"
                )
                connection.execute(
                    "INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
                    (2, utc_now_iso()),
                )

    def create_task(self, task_id: str, request_json: Dict[str, Any], output_dir: str) -> None:
        """Insert a new task row in the ``PENDING`` state.

        Args:
            task_id: Primary key of the new row.
            request_json: Request payload; stored as JSON text.
            output_dir: Value for the ``output_dir`` column.

        Raises:
            sqlite3.IntegrityError: If a row with ``task_id`` already exists.
        """
        now = utc_now_iso()
        with self.conn() as connection:
            connection.execute(
                """
                INSERT INTO tasks (
                    task_id, status, request_json, output_dir, created_at, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    task_id,
                    STATUS_PENDING,
                    json.dumps(request_json, ensure_ascii=False),
                    output_dir,
                    now,
                    now,
                ),
            )

    @staticmethod
    def _row_to_record(row: sqlite3.Row) -> TaskRecord:
        """Convert one ``tasks`` row into a ``TaskRecord``."""
        return TaskRecord(
            task_id=row["task_id"],
            status=row["status"],
            backend=row["backend"],
            model_name=row["model_name"],
            request_json=json.loads(row["request_json"]),
            output_dir=row["output_dir"],
            progress=float(row["progress"]),
            error_message=row["error_message"],
            video_path=row["video_path"],
            first_frame_path=row["first_frame_path"],
            metadata_path=row["metadata_path"],
            log_path=row["log_path"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
            started_at=row["started_at"],
            finished_at=row["finished_at"],
        )

    def get_task(self, task_id: str) -> Optional[TaskRecord]:
        """Return the task with ``task_id``, or ``None`` if no such row exists."""
        with self.conn() as connection:
            row = connection.execute(
                "SELECT * FROM tasks WHERE task_id = ?", (task_id,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_record(row)

    def update_task(self, task_id: str, **fields: Any) -> None:
        """Update the given columns of one task and refresh ``updated_at``.

        Does nothing when no fields are passed. ``updated_at`` is always set
        to the current time, replacing any value supplied by the caller.

        Args:
            task_id: Key of the row to update.
            **fields: Column name to new value. A ``dict`` or ``list`` given
                for ``request_json`` is serialized to JSON text, matching what
                ``create_task`` stores and ``get_task`` decodes.

        Raises:
            sqlite3.OperationalError: If a keyword is not a writable column of
                ``tasks``. Raised before any SQL is built, so a keyword can
                never inject SQL into the statement.
        """
        if not fields:
            return

        unknown = sorted(set(fields) - _UPDATABLE_COLUMNS)
        if unknown:
            # Same exception type SQLite itself raises for an unknown column.
            raise sqlite3.OperationalError(f"no such column: {unknown[0]}")

        if isinstance(fields.get("request_json"), (dict, list)):
            fields["request_json"] = json.dumps(fields["request_json"], ensure_ascii=False)

        fields["updated_at"] = utc_now_iso()
        keys = sorted(fields)
        assignments = ", ".join(f"{key} = ?" for key in keys)
        values = [fields[key] for key in keys]
        values.append(task_id)
        with self.conn() as connection:
            connection.execute(f"UPDATE tasks SET {assignments} WHERE task_id = ?", values)

    def list_migrations(self) -> list[dict[str, Any]]:
        """Return the applied migrations, ordered by ascending version."""
        with self.conn() as connection:
            rows = connection.execute(
                "SELECT version, applied_at FROM schema_migrations ORDER BY version ASC"
            ).fetchall()
        return [{"version": row["version"], "applied_at": row["applied_at"]} for row in rows]