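# Source listing metadata (copied from the repository file view):
#   Path: Files / AI_A4000/video_worker/app/task_store.py
#   Timestamp: 2026-04-07 00:37:39 +08:00
#   Size: 157 lines, 5.6 KiB
#   Language: Python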
import json
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterator, Optional
# Lifecycle states written to the ``tasks.status`` column.
STATUS_PENDING: str = "PENDING"  # assigned to every new task by TaskStore.create_task
STATUS_RUNNING: str = "RUNNING"
STATUS_SUCCEEDED: str = "SUCCEEDED"
STATUS_FAILED: str = "FAILED"

# Schema version number; TaskStore.migrate currently records versions 1 and 2.
SCHEMA_VERSION: int = 2
def utc_now_iso() -> str:
    """Return the current time as a timezone-aware ISO-8601 string in UTC.

    The result carries an explicit ``+00:00`` offset, e.g.
    ``2026-04-07T00:37:39.123456+00:00``.
    """
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
@dataclass
class TaskRecord:
    """In-memory view of one row of the ``tasks`` table.

    Fields are declared in the same order as the table's columns, so
    positional construction follows the schema.
    """

    # Identity and lifecycle.
    task_id: str
    status: str  # presumably one of the module's STATUS_* values
    backend: Optional[str]
    model_name: Optional[str]

    # Input and output locations.
    request_json: Dict[str, Any]  # request payload, decoded from the stored JSON text
    output_dir: str

    # Progress reporting and failure detail.
    progress: float
    error_message: Optional[str]

    # Artifact paths; NULL in the table until set via an update.
    video_path: Optional[str]
    first_frame_path: Optional[str]
    metadata_path: Optional[str]
    log_path: Optional[str]

    # Timestamps stored as text; created_at/updated_at come from utc_now_iso().
    created_at: str
    updated_at: str
    started_at: Optional[str]
    finished_at: Optional[str]
class TaskStore:
    """SQLite-backed store for video-generation task records.

    Each public method opens its own connection through :meth:`conn` and
    closes it before returning; no connection is kept on the instance.
    """

    # Columns that ``update_task`` may assign. SQLite cannot bind column names
    # as parameters, so they are interpolated into the statement text.
    # Restricting them to this fixed set rejects typos early and stops
    # caller-supplied keyword names from injecting SQL. ``task_id`` is the
    # primary key and is never updated here.
    _UPDATABLE_COLUMNS = frozenset(
        {
            "status",
            "backend",
            "model_name",
            "request_json",
            "output_dir",
            "progress",
            "error_message",
            "video_path",
            "first_frame_path",
            "metadata_path",
            "log_path",
            "created_at",
            "updated_at",
            "started_at",
            "finished_at",
        }
    )

    # Ordered schema migrations as (version, statements). A migration runs only
    # when its version is greater than the highest version already recorded.
    _MIGRATIONS = (
        (
            1,
            (
                """
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    status TEXT NOT NULL,
                    backend TEXT,
                    model_name TEXT,
                    request_json TEXT NOT NULL,
                    output_dir TEXT NOT NULL,
                    progress REAL NOT NULL DEFAULT 0,
                    error_message TEXT,
                    video_path TEXT,
                    first_frame_path TEXT,
                    metadata_path TEXT,
                    log_path TEXT,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    started_at TEXT,
                    finished_at TEXT
                )
                """,
            ),
        ),
        (
            2,
            (
                "CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)",
                "CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at)",
            ),
        ),
    )

    def __init__(self, sqlite_path: Path):
        """Create a store backed by the SQLite file at ``sqlite_path``.

        The parent directory is created if it is missing; SQLite creates the
        database file itself on first connection. A ``str`` path is also
        accepted and converted to :class:`Path`.
        """
        self.sqlite_path = Path(sqlite_path)
        self.sqlite_path.parent.mkdir(parents=True, exist_ok=True)

    @contextmanager
    def conn(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection whose rows are ``sqlite3.Row`` objects.

        Work is committed when the ``with`` body finishes normally and rolled
        back if it raises. The connection is always closed afterwards.
        """
        # check_same_thread=False lets the connection be used from a thread
        # other than the one that opened it.
        connection = sqlite3.connect(self.sqlite_path, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        try:
            yield connection
            connection.commit()
        except BaseException:
            # close() alone would also discard uncommitted work; the explicit
            # rollback just makes that intent visible.
            connection.rollback()
            raise
        finally:
            connection.close()

    def migrate(self) -> None:
        """Apply every schema migration newer than the recorded version.

        Each applied migration is recorded in ``schema_migrations`` together
        with its application time. All DDL uses ``IF NOT EXISTS`` and the
        bookkeeping insert uses ``INSERT OR IGNORE``, so re-running a migration
        is harmless. In particular, two callers that both read the same
        starting version do not fail on a duplicate ``version`` key.
        """
        with self.conn() as connection:
            connection.execute(
                """
                CREATE TABLE IF NOT EXISTS schema_migrations (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL
                )
                """
            )
            row = connection.execute(
                "SELECT MAX(version) as v FROM schema_migrations"
            ).fetchone()
            # MAX() over an empty table yields NULL, meaning nothing applied yet.
            current_version = int(row["v"] or 0)
            for version, statements in self._MIGRATIONS:
                if version <= current_version:
                    continue
                for statement in statements:
                    connection.execute(statement)
                connection.execute(
                    "INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)",
                    (version, utc_now_iso()),
                )

    def create_task(self, task_id: str, request_json: Dict[str, Any], output_dir: str) -> None:
        """Insert a new task in the PENDING state.

        ``request_json`` is stored as JSON text (non-ASCII kept as-is), and
        ``created_at``/``updated_at`` are both set to the current UTC time.

        Raises:
            sqlite3.IntegrityError: if a task with ``task_id`` already exists
                (``task_id`` is the table's primary key).
        """
        now = utc_now_iso()
        with self.conn() as connection:
            connection.execute(
                """
                INSERT INTO tasks (
                    task_id, status, request_json, output_dir, created_at, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (task_id, STATUS_PENDING, json.dumps(request_json, ensure_ascii=False), output_dir, now, now),
            )

    def get_task(self, task_id: str) -> Optional[TaskRecord]:
        """Return the task with ``task_id``, or ``None`` if there is none."""
        with self.conn() as connection:
            row = connection.execute(
                "SELECT * FROM tasks WHERE task_id = ?", (task_id,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_record(row)

    @staticmethod
    def _row_to_record(row: sqlite3.Row) -> TaskRecord:
        """Convert a ``tasks`` row into a TaskRecord, decoding ``request_json``."""
        return TaskRecord(
            task_id=row["task_id"],
            status=row["status"],
            backend=row["backend"],
            model_name=row["model_name"],
            request_json=json.loads(row["request_json"]),
            output_dir=row["output_dir"],
            progress=float(row["progress"]),
            error_message=row["error_message"],
            video_path=row["video_path"],
            first_frame_path=row["first_frame_path"],
            metadata_path=row["metadata_path"],
            log_path=row["log_path"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
            started_at=row["started_at"],
            finished_at=row["finished_at"],
        )

    def update_task(self, task_id: str, **fields: Any) -> None:
        """Assign ``fields`` to the task's columns and refresh ``updated_at``.

        Does nothing when no fields are given. If no row has ``task_id``, the
        UPDATE matches nothing and the call is a silent no-op. A ``dict``
        passed as ``request_json`` is serialized to JSON text, matching how
        :meth:`create_task` stores the column and :meth:`get_task` decodes it.
        Any ``updated_at`` supplied by the caller is overwritten with the
        current time.

        Raises:
            ValueError: if a keyword is not an updatable ``tasks`` column.
        """
        if not fields:
            return
        unknown = sorted(set(fields) - self._UPDATABLE_COLUMNS)
        if unknown:
            raise ValueError(f"Unknown task field(s): {', '.join(unknown)}")
        values_by_column = dict(fields)
        request = values_by_column.get("request_json")
        if isinstance(request, dict):
            values_by_column["request_json"] = json.dumps(request, ensure_ascii=False)
        values_by_column["updated_at"] = utc_now_iso()
        # Sorted for a deterministic statement text; every name in `columns`
        # has been checked against _UPDATABLE_COLUMNS, so interpolation is safe.
        columns = sorted(values_by_column)
        assignments = ", ".join(f"{column} = ?" for column in columns)
        params = [values_by_column[column] for column in columns]
        params.append(task_id)
        with self.conn() as connection:
            connection.execute(f"UPDATE tasks SET {assignments} WHERE task_id = ?", params)

    def list_migrations(self) -> list[dict[str, Any]]:
        """Return the recorded migrations as ``{"version", "applied_at"}`` dicts, oldest version first."""
        with self.conn() as connection:
            rows = connection.execute(
                "SELECT version, applied_at FROM schema_migrations ORDER BY version ASC"
            ).fetchall()
        return [{"version": row["version"], "applied_at": row["applied_at"]} for row in rows]