157 lines
5.6 KiB
Python
157 lines
5.6 KiB
Python
import json
|
|
import sqlite3
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterator, Optional
|
|
|
|
|
|
# Lifecycle states stored in the ``tasks.status`` column. New tasks are
# inserted as STATUS_PENDING; the remaining values are written by callers
# through ``TaskStore.update_task``.
STATUS_PENDING = "PENDING"
STATUS_RUNNING = "RUNNING"
STATUS_SUCCEEDED = "SUCCEEDED"
STATUS_FAILED = "FAILED"

# Highest migration version known to this module; it should match the last
# migration that ``TaskStore.migrate`` applies.
SCHEMA_VERSION = 2
|
|
|
|
|
|
def utc_now_iso() -> str:
    """Return the current time as a timezone-aware ISO-8601 string in UTC.

    The result carries an explicit ``+00:00`` offset because the datetime is
    created with ``timezone.utc``.
    """
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
|
|
|
|
|
|
@dataclass()
class TaskRecord:
    """In-memory view of one row of the ``tasks`` table.

    Field order mirrors the table's column order and must be preserved, since
    instances may be constructed positionally.
    """

    # Identity and lifecycle.
    task_id: str
    status: str

    # Execution configuration; optional until assigned via an update.
    backend: Optional[str]
    model_name: Optional[str]

    # Original request payload, already decoded from its JSON text column.
    request_json: Dict[str, Any]
    output_dir: str

    # Progress / failure reporting.
    progress: float
    error_message: Optional[str]

    # Paths to produced artifacts, when available.
    video_path: Optional[str]
    first_frame_path: Optional[str]
    metadata_path: Optional[str]
    log_path: Optional[str]

    # Timestamps, stored as ISO-8601 strings.
    created_at: str
    updated_at: str
    started_at: Optional[str]
    finished_at: Optional[str]
|
|
|
|
|
|
class TaskStore:
|
|
def __init__(self, sqlite_path: Path):
|
|
self.sqlite_path = sqlite_path
|
|
self.sqlite_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
@contextmanager
|
|
def conn(self) -> Iterator[sqlite3.Connection]:
|
|
connection = sqlite3.connect(self.sqlite_path, check_same_thread=False)
|
|
connection.row_factory = sqlite3.Row
|
|
try:
|
|
yield connection
|
|
connection.commit()
|
|
finally:
|
|
connection.close()
|
|
|
|
def migrate(self) -> None:
|
|
with self.conn() as connection:
|
|
connection.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version INTEGER PRIMARY KEY,
|
|
applied_at TEXT NOT NULL
|
|
)
|
|
"""
|
|
)
|
|
current = connection.execute("SELECT MAX(version) as v FROM schema_migrations").fetchone()["v"]
|
|
current_version = int(current or 0)
|
|
|
|
if current_version < 1:
|
|
connection.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS tasks (
|
|
task_id TEXT PRIMARY KEY,
|
|
status TEXT NOT NULL,
|
|
backend TEXT,
|
|
model_name TEXT,
|
|
request_json TEXT NOT NULL,
|
|
output_dir TEXT NOT NULL,
|
|
progress REAL NOT NULL DEFAULT 0,
|
|
error_message TEXT,
|
|
video_path TEXT,
|
|
first_frame_path TEXT,
|
|
metadata_path TEXT,
|
|
log_path TEXT,
|
|
created_at TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL,
|
|
started_at TEXT,
|
|
finished_at TEXT
|
|
)
|
|
"""
|
|
)
|
|
connection.execute(
|
|
"INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
|
|
(1, utc_now_iso()),
|
|
)
|
|
|
|
if current_version < 2:
|
|
connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)")
|
|
connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at)")
|
|
connection.execute(
|
|
"INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
|
|
(2, utc_now_iso()),
|
|
)
|
|
|
|
def create_task(self, task_id: str, request_json: Dict[str, Any], output_dir: str) -> None:
|
|
now = utc_now_iso()
|
|
with self.conn() as connection:
|
|
connection.execute(
|
|
"""
|
|
INSERT INTO tasks (
|
|
task_id, status, request_json, output_dir, created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?)
|
|
""",
|
|
(task_id, STATUS_PENDING, json.dumps(request_json, ensure_ascii=False), output_dir, now, now),
|
|
)
|
|
|
|
def get_task(self, task_id: str) -> Optional[TaskRecord]:
|
|
with self.conn() as connection:
|
|
row = connection.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
|
|
if row is None:
|
|
return None
|
|
return TaskRecord(
|
|
task_id=row["task_id"],
|
|
status=row["status"],
|
|
backend=row["backend"],
|
|
model_name=row["model_name"],
|
|
request_json=json.loads(row["request_json"]),
|
|
output_dir=row["output_dir"],
|
|
progress=float(row["progress"]),
|
|
error_message=row["error_message"],
|
|
video_path=row["video_path"],
|
|
first_frame_path=row["first_frame_path"],
|
|
metadata_path=row["metadata_path"],
|
|
log_path=row["log_path"],
|
|
created_at=row["created_at"],
|
|
updated_at=row["updated_at"],
|
|
started_at=row["started_at"],
|
|
finished_at=row["finished_at"],
|
|
)
|
|
|
|
def update_task(self, task_id: str, **fields: Any) -> None:
|
|
if not fields:
|
|
return
|
|
fields["updated_at"] = utc_now_iso()
|
|
keys = sorted(fields.keys())
|
|
assignments = ", ".join([f"{key} = ?" for key in keys])
|
|
values = [fields[key] for key in keys]
|
|
values.append(task_id)
|
|
with self.conn() as connection:
|
|
connection.execute(f"UPDATE tasks SET {assignments} WHERE task_id = ?", values)
|
|
|
|
def list_migrations(self) -> list[dict[str, Any]]:
|
|
with self.conn() as connection:
|
|
rows = connection.execute("SELECT version, applied_at FROM schema_migrations ORDER BY version ASC").fetchall()
|
|
return [{"version": row["version"], "applied_at": row["applied_at"]} for row in rows]
|