feat:优化架构
This commit is contained in:
68
engine/task_store.py
Normal file
68
engine/task_store.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _task_path(task_id: str, base_dir: str | Path = "./outputs") -> Path:
|
||||
return Path(base_dir) / str(task_id) / "task.json"
|
||||
|
||||
|
||||
def create_task(task_id: str, shots: list[dict[str, Any]], base_dir: str | Path = "./outputs") -> dict[str, Any]:
|
||||
p = _task_path(task_id, base_dir=base_dir)
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = {
|
||||
"task_id": str(task_id),
|
||||
"status": "queued",
|
||||
"shots": [
|
||||
{
|
||||
"shot_id": str(s.get("shot_id", "")),
|
||||
"status": str(s.get("status", "pending") or "pending"),
|
||||
}
|
||||
for s in shots
|
||||
],
|
||||
}
|
||||
p.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
return data
|
||||
|
||||
|
||||
def load_task(task_id: str, base_dir: str | Path = "./outputs") -> dict[str, Any]:
|
||||
p = _task_path(task_id, base_dir=base_dir)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"task file not found: {p}")
|
||||
raw = json.loads(p.read_text(encoding="utf-8"))
|
||||
if not isinstance(raw, dict):
|
||||
raise ValueError("task.json must be an object")
|
||||
return raw
|
||||
|
||||
|
||||
def _save_task(task_id: str, data: dict[str, Any], base_dir: str | Path = "./outputs") -> None:
|
||||
p = _task_path(task_id, base_dir=base_dir)
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def update_shot_status(task_id: str, shot_id: str, status: str, base_dir: str | Path = "./outputs") -> dict[str, Any]:
|
||||
data = load_task(task_id, base_dir=base_dir)
|
||||
shots = data.get("shots")
|
||||
if not isinstance(shots, list):
|
||||
raise ValueError("task.json shots must be an array")
|
||||
found = False
|
||||
for s in shots:
|
||||
if isinstance(s, dict) and str(s.get("shot_id", "")) == str(shot_id):
|
||||
s["status"] = str(status)
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
shots.append({"shot_id": str(shot_id), "status": str(status)})
|
||||
_save_task(task_id, data, base_dir=base_dir)
|
||||
return data
|
||||
|
||||
|
||||
def update_task_status(task_id: str, status: str, base_dir: str | Path = "./outputs") -> dict[str, Any]:
|
||||
data = load_task(task_id, base_dir=base_dir)
|
||||
data["status"] = str(status)
|
||||
_save_task(task_id, data, base_dir=base_dir)
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user