feat: 新增代码
This commit is contained in:
156
video_worker/app/task_store.py
Normal file
156
video_worker/app/task_store.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, Optional
|
||||
|
||||
|
||||
STATUS_PENDING = "PENDING"
|
||||
STATUS_RUNNING = "RUNNING"
|
||||
STATUS_SUCCEEDED = "SUCCEEDED"
|
||||
STATUS_FAILED = "FAILED"
|
||||
|
||||
SCHEMA_VERSION = 2
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRecord:
|
||||
task_id: str
|
||||
status: str
|
||||
backend: Optional[str]
|
||||
model_name: Optional[str]
|
||||
request_json: Dict[str, Any]
|
||||
output_dir: str
|
||||
progress: float
|
||||
error_message: Optional[str]
|
||||
video_path: Optional[str]
|
||||
first_frame_path: Optional[str]
|
||||
metadata_path: Optional[str]
|
||||
log_path: Optional[str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
started_at: Optional[str]
|
||||
finished_at: Optional[str]
|
||||
|
||||
|
||||
class TaskStore:
|
||||
def __init__(self, sqlite_path: Path):
|
||||
self.sqlite_path = sqlite_path
|
||||
self.sqlite_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@contextmanager
|
||||
def conn(self) -> Iterator[sqlite3.Connection]:
|
||||
connection = sqlite3.connect(self.sqlite_path, check_same_thread=False)
|
||||
connection.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield connection
|
||||
connection.commit()
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
def migrate(self) -> None:
|
||||
with self.conn() as connection:
|
||||
connection.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
current = connection.execute("SELECT MAX(version) as v FROM schema_migrations").fetchone()["v"]
|
||||
current_version = int(current or 0)
|
||||
|
||||
if current_version < 1:
|
||||
connection.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
status TEXT NOT NULL,
|
||||
backend TEXT,
|
||||
model_name TEXT,
|
||||
request_json TEXT NOT NULL,
|
||||
output_dir TEXT NOT NULL,
|
||||
progress REAL NOT NULL DEFAULT 0,
|
||||
error_message TEXT,
|
||||
video_path TEXT,
|
||||
first_frame_path TEXT,
|
||||
metadata_path TEXT,
|
||||
log_path TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
finished_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
connection.execute(
|
||||
"INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
|
||||
(1, utc_now_iso()),
|
||||
)
|
||||
|
||||
if current_version < 2:
|
||||
connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)")
|
||||
connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at)")
|
||||
connection.execute(
|
||||
"INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
|
||||
(2, utc_now_iso()),
|
||||
)
|
||||
|
||||
def create_task(self, task_id: str, request_json: Dict[str, Any], output_dir: str) -> None:
|
||||
now = utc_now_iso()
|
||||
with self.conn() as connection:
|
||||
connection.execute(
|
||||
"""
|
||||
INSERT INTO tasks (
|
||||
task_id, status, request_json, output_dir, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(task_id, STATUS_PENDING, json.dumps(request_json, ensure_ascii=False), output_dir, now, now),
|
||||
)
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[TaskRecord]:
|
||||
with self.conn() as connection:
|
||||
row = connection.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return TaskRecord(
|
||||
task_id=row["task_id"],
|
||||
status=row["status"],
|
||||
backend=row["backend"],
|
||||
model_name=row["model_name"],
|
||||
request_json=json.loads(row["request_json"]),
|
||||
output_dir=row["output_dir"],
|
||||
progress=float(row["progress"]),
|
||||
error_message=row["error_message"],
|
||||
video_path=row["video_path"],
|
||||
first_frame_path=row["first_frame_path"],
|
||||
metadata_path=row["metadata_path"],
|
||||
log_path=row["log_path"],
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
started_at=row["started_at"],
|
||||
finished_at=row["finished_at"],
|
||||
)
|
||||
|
||||
def update_task(self, task_id: str, **fields: Any) -> None:
|
||||
if not fields:
|
||||
return
|
||||
fields["updated_at"] = utc_now_iso()
|
||||
keys = sorted(fields.keys())
|
||||
assignments = ", ".join([f"{key} = ?" for key in keys])
|
||||
values = [fields[key] for key in keys]
|
||||
values.append(task_id)
|
||||
with self.conn() as connection:
|
||||
connection.execute(f"UPDATE tasks SET {assignments} WHERE task_id = ?", values)
|
||||
|
||||
def list_migrations(self) -> list[dict[str, Any]]:
|
||||
with self.conn() as connection:
|
||||
rows = connection.execute("SELECT version, applied_at FROM schema_migrations ORDER BY version ASC").fetchall()
|
||||
return [{"version": row["version"], "applied_at": row["applied_at"]} for row in rows]
|
||||
Reference in New Issue
Block a user