feat: 新增代码

This commit is contained in:
Daniel
2026-04-07 00:37:39 +08:00
commit 8d0b729f2f
29 changed files with 1768 additions and 0 deletions

View File

@@ -0,0 +1,156 @@
import json
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterator, Optional
STATUS_PENDING = "PENDING"
STATUS_RUNNING = "RUNNING"
STATUS_SUCCEEDED = "SUCCEEDED"
STATUS_FAILED = "FAILED"
SCHEMA_VERSION = 2
def utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
@dataclass
class TaskRecord:
task_id: str
status: str
backend: Optional[str]
model_name: Optional[str]
request_json: Dict[str, Any]
output_dir: str
progress: float
error_message: Optional[str]
video_path: Optional[str]
first_frame_path: Optional[str]
metadata_path: Optional[str]
log_path: Optional[str]
created_at: str
updated_at: str
started_at: Optional[str]
finished_at: Optional[str]
class TaskStore:
def __init__(self, sqlite_path: Path):
self.sqlite_path = sqlite_path
self.sqlite_path.parent.mkdir(parents=True, exist_ok=True)
@contextmanager
def conn(self) -> Iterator[sqlite3.Connection]:
connection = sqlite3.connect(self.sqlite_path, check_same_thread=False)
connection.row_factory = sqlite3.Row
try:
yield connection
connection.commit()
finally:
connection.close()
def migrate(self) -> None:
with self.conn() as connection:
connection.execute(
"""
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TEXT NOT NULL
)
"""
)
current = connection.execute("SELECT MAX(version) as v FROM schema_migrations").fetchone()["v"]
current_version = int(current or 0)
if current_version < 1:
connection.execute(
"""
CREATE TABLE IF NOT EXISTS tasks (
task_id TEXT PRIMARY KEY,
status TEXT NOT NULL,
backend TEXT,
model_name TEXT,
request_json TEXT NOT NULL,
output_dir TEXT NOT NULL,
progress REAL NOT NULL DEFAULT 0,
error_message TEXT,
video_path TEXT,
first_frame_path TEXT,
metadata_path TEXT,
log_path TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
started_at TEXT,
finished_at TEXT
)
"""
)
connection.execute(
"INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
(1, utc_now_iso()),
)
if current_version < 2:
connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)")
connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at)")
connection.execute(
"INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
(2, utc_now_iso()),
)
def create_task(self, task_id: str, request_json: Dict[str, Any], output_dir: str) -> None:
now = utc_now_iso()
with self.conn() as connection:
connection.execute(
"""
INSERT INTO tasks (
task_id, status, request_json, output_dir, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?)
""",
(task_id, STATUS_PENDING, json.dumps(request_json, ensure_ascii=False), output_dir, now, now),
)
def get_task(self, task_id: str) -> Optional[TaskRecord]:
with self.conn() as connection:
row = connection.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
if row is None:
return None
return TaskRecord(
task_id=row["task_id"],
status=row["status"],
backend=row["backend"],
model_name=row["model_name"],
request_json=json.loads(row["request_json"]),
output_dir=row["output_dir"],
progress=float(row["progress"]),
error_message=row["error_message"],
video_path=row["video_path"],
first_frame_path=row["first_frame_path"],
metadata_path=row["metadata_path"],
log_path=row["log_path"],
created_at=row["created_at"],
updated_at=row["updated_at"],
started_at=row["started_at"],
finished_at=row["finished_at"],
)
def update_task(self, task_id: str, **fields: Any) -> None:
if not fields:
return
fields["updated_at"] = utc_now_iso()
keys = sorted(fields.keys())
assignments = ", ".join([f"{key} = ?" for key in keys])
values = [fields[key] for key in keys]
values.append(task_id)
with self.conn() as connection:
connection.execute(f"UPDATE tasks SET {assignments} WHERE task_id = ?", values)
def list_migrations(self) -> list[dict[str, Any]]:
with self.conn() as connection:
rows = connection.execute("SELECT version, applied_at FROM schema_migrations ORDER BY version ASC").fetchall()
return [{"version": row["version"], "applied_at": row["applied_at"]} for row in rows]