feat: add code
video_worker/app/__init__.py (new file, empty)
video_worker/app/api.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from fastapi import APIRouter, HTTPException

from app.schemas import GenerateRequest, HealthResponse, TaskResultResponse, TaskStatusResponse


router = APIRouter()


@router.post("/generate", response_model=TaskStatusResponse)
async def create_generate_task(req: GenerateRequest):
    # task_manager (and the backends/torch used below) are attached to this
    # router by build_app() in app/main.py.
    task_manager = router.task_manager
    return await task_manager.create_task(req)


@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
def get_task_status(task_id: str):
    task_manager = router.task_manager
    try:
        return task_manager.get_status(task_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc


@router.get("/tasks/{task_id}/result", response_model=TaskResultResponse)
def get_task_result(task_id: str):
    task_manager = router.task_manager
    try:
        return task_manager.get_result(task_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc


@router.get("/health", response_model=HealthResponse)
def health_check():
    torch = router.torch
    ltx_backend = router.ltx_backend
    hunyuan_backend = router.hunyuan_backend

    cuda_ok = bool(torch.cuda.is_available())
    gpu_name = torch.cuda.get_device_name(0) if cuda_ok else None

    return HealthResponse(
        service_status="ok",
        cuda_available=cuda_ok,
        gpu_name=gpu_name,
        ltx_loaded=ltx_backend.is_loaded(),
        hunyuan_loaded=hunyuan_backend.is_loaded(),
    )
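A quick way to exercise these endpoints once the service is running. The paths and response fields come from api.py and schemas.py; the base URL and the use of httpx are assumptions, not part of this commit:

import httpx

BASE = "http://127.0.0.1:8000"  # assumed APP_HOST/APP_PORT defaults from settings.py

payload = {"prompt": "a red fox running through snow", "quality_mode": "preview"}
task = httpx.post(f"{BASE}/generate", json=payload, timeout=30).json()
task_id = task["task_id"]

status = httpx.get(f"{BASE}/tasks/{task_id}").json()         # TaskStatusResponse fields
result = httpx.get(f"{BASE}/tasks/{task_id}/result").json()  # TaskResultResponse fields
print(status["status"], result.get("video_path"))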
video_worker/app/backends/__init__.py (new file, empty)
video_worker/app/backends/base.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any, Dict


class BaseVideoBackend(ABC):
    backend_name: str
    model_name: str

    @abstractmethod
    def load(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def is_loaded(self) -> bool:
        raise NotImplementedError

    @abstractmethod
    def generate(self, task_id: str, request_data: Dict[str, Any], output_dir: str) -> Dict[str, str]:
        raise NotImplementedError
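A minimal hypothetical subclass, showing what the interface demands of a new backend. NullBackend is an illustration (useful as a template or test double), not part of this commit:

from typing import Any, Dict

from app.backends.base import BaseVideoBackend


class NullBackend(BaseVideoBackend):
    backend_name = "null_backend"
    model_name = "null"

    def __init__(self) -> None:
        self._loaded = False

    def load(self) -> None:
        self._loaded = True  # nothing to load

    def is_loaded(self) -> bool:
        return self._loaded

    def generate(self, task_id: str, request_data: Dict[str, Any], output_dir: str) -> Dict[str, str]:
        # Real backends must return resolved paths to the produced artifacts.
        return {"video_path": f"{output_dir}/video.mp4", "first_frame_path": f"{output_dir}/first_frame.jpg"}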
video_worker/app/backends/hunyuan_backend.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from pathlib import Path
from typing import Any, Dict

from app.backends.base import BaseVideoBackend
from app.utils.ffmpeg_utils import extract_first_frame, frames_to_video
from app.utils.files import TASK_FIRST_FRAME_NAME, TASK_VIDEO_NAME
from app.utils.image_utils import make_dummy_frame


class HunyuanBackend(BaseVideoBackend):
    backend_name = "hunyuan_backend"
    model_name = "HunyuanVideo-1.5"

    def __init__(self, model_dir: Path, enable_cpu_offload: bool = True, enable_vae_tiling: bool = True):
        self.model_dir = model_dir
        self.enable_cpu_offload = enable_cpu_offload
        self.enable_vae_tiling = enable_vae_tiling
        self._loaded = False
        self._pipeline = None

    def load(self) -> None:
        if self._loaded:
            return
        # TODO: Replace with real HunyuanVideo loading and memory optimization hooks.
        # Example hooks: self._pipeline.enable_model_cpu_offload(), self._pipeline.vae.enable_tiling()
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self._pipeline = "hunyuan_pipeline_placeholder"
        self._loaded = True

    def is_loaded(self) -> bool:
        return self._loaded

    def generate(self, task_id: str, request_data: Dict[str, Any], output_dir: str) -> Dict[str, str]:
        self.load()
        output = Path(output_dir)
        frames_dir = output / "frames"
        frames_dir.mkdir(parents=True, exist_ok=True)

        duration = int(request_data["duration_sec"])
        fps = int(request_data["fps"])
        width = int(request_data["width"])
        height = int(request_data["height"])
        prompt = request_data["prompt"]

        total_frames = duration * fps
        for i in range(total_frames):
            frame_path = frames_dir / f"frame_{i:04d}.jpg"
            make_dummy_frame(frame_path, width, height, f"Hunyuan refine | {prompt[:60]}", i)

        video_path = output / TASK_VIDEO_NAME
        frames_to_video(str(frames_dir / "frame_%04d.jpg"), fps, video_path)

        first_frame_path = output / TASK_FIRST_FRAME_NAME
        extract_first_frame(video_path, first_frame_path)

        return {
            "video_path": str(video_path.resolve()),
            "first_frame_path": str(first_frame_path.resolve()),
        }
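A sketch of what load() might become once the TODO above is resolved, mirroring the two hooks the comment names. HunyuanVideoPipeline and the dtype choice are assumptions about the diffusers API and should be verified against the installed version:

import torch
from diffusers import HunyuanVideoPipeline  # assumed API; check your diffusers version

def load(self) -> None:  # hypothetical drop-in for HunyuanBackend.load
    if self._loaded:
        return
    pipe = HunyuanVideoPipeline.from_pretrained(str(self.model_dir), torch_dtype=torch.float16)
    if self.enable_cpu_offload:
        pipe.enable_model_cpu_offload()  # keep weights on CPU, stream to GPU per step
    if self.enable_vae_tiling:
        pipe.vae.enable_tiling()  # decode in tiles to bound VRAM
    self._pipeline = pipe
    self._loaded = True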
video_worker/app/backends/ltx_backend.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from pathlib import Path
from typing import Any, Dict

from app.backends.base import BaseVideoBackend
from app.utils.ffmpeg_utils import extract_first_frame, frames_to_video
from app.utils.files import TASK_FIRST_FRAME_NAME, TASK_VIDEO_NAME
from app.utils.image_utils import make_dummy_frame


class LTXBackend(BaseVideoBackend):
    backend_name = "ltx_backend"
    model_name = "LTX-Video"

    def __init__(self, model_dir: Path):
        self.model_dir = model_dir
        self._loaded = False
        self._pipeline = None

    def load(self) -> None:
        if self._loaded:
            return
        # TODO: Replace with real LTX loading, e.g. DiffusionPipeline.from_pretrained(...)
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self._pipeline = "ltx_pipeline_placeholder"
        self._loaded = True

    def is_loaded(self) -> bool:
        return self._loaded

    def generate(self, task_id: str, request_data: Dict[str, Any], output_dir: str) -> Dict[str, str]:
        self.load()
        output = Path(output_dir)
        frames_dir = output / "frames"
        frames_dir.mkdir(parents=True, exist_ok=True)

        duration = int(request_data["duration_sec"])
        fps = int(request_data["fps"])
        width = int(request_data["width"])
        height = int(request_data["height"])
        prompt = request_data["prompt"]

        total_frames = duration * fps
        for i in range(total_frames):
            frame_path = frames_dir / f"frame_{i:04d}.jpg"
            make_dummy_frame(frame_path, width, height, f"LTX preview | {prompt[:60]}", i)

        video_path = output / TASK_VIDEO_NAME
        frames_to_video(str(frames_dir / "frame_%04d.jpg"), fps, video_path)

        first_frame_path = output / TASK_FIRST_FRAME_NAME
        extract_first_frame(video_path, first_frame_path)

        return {
            "video_path": str(video_path.resolve()),
            "first_frame_path": str(first_frame_path.resolve()),
        }
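The corresponding sketch for the LTX TODO, using the generic diffusers entry point the comment names. The dtype and device placement are assumptions:

import torch
from diffusers import DiffusionPipeline

def load(self) -> None:  # hypothetical drop-in for LTXBackend.load
    if self._loaded:
        return
    pipe = DiffusionPipeline.from_pretrained(str(self.model_dir), torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    self._pipeline = pipe
    self._loaded = True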
video_worker/app/gpu_worker.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path

from app.model_router import ModelRouter
from app.task_manager import TaskManager
from app.utils.files import write_json
from app.utils.logger import build_logger


class GPUWorker:
    def __init__(self, task_manager: TaskManager, router: ModelRouter, log_level: str = "INFO"):
        self.task_manager = task_manager
        self.router = router
        self.log_level = log_level
        self._runner: asyncio.Task | None = None
        self._stopped = asyncio.Event()
        self._stopped.clear()
        self.logger = build_logger("gpu_worker", log_level=log_level)

    async def start(self) -> None:
        if self._runner and not self._runner.done():
            return
        self._runner = asyncio.create_task(self._run_loop(), name="gpu-worker-loop")

    async def stop(self) -> None:
        self._stopped.set()
        if self._runner:
            self._runner.cancel()
            try:
                await self._runner
            except asyncio.CancelledError:
                pass

    async def _run_loop(self) -> None:
        while not self._stopped.is_set():
            task_id = await self.task_manager.queue.get()
            try:
                await self._process(task_id)
            finally:
                self.task_manager.queue.task_done()

    async def _process(self, task_id: str) -> None:
        task = self.task_manager.get_task_record(task_id)
        req = task.request_json
        backend = self.router.route(req["quality_mode"])

        log_path = self.task_manager.build_log_path(task)
        task_logger = build_logger(f"task.{task_id}", log_level=self.log_level, log_file=log_path)

        try:
            self.task_manager.mark_running(task_id, backend.backend_name, backend.model_name)
            task_logger.info("Task started with backend=%s model=%s", backend.backend_name, backend.model_name)

            await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.3)
            result = await asyncio.to_thread(backend.generate, task_id, req, task.output_dir)
            await asyncio.to_thread(self.task_manager.mark_progress, task_id, 0.8)

            metadata_path = self.task_manager.build_metadata_path(task)
            current = self.task_manager.get_task_record(task_id)
            finished_at = datetime.now(timezone.utc).isoformat()
            metadata = {
                "task_id": task.task_id,
                "backend": backend.backend_name,
                "model_name": backend.model_name,
                "prompt": req.get("prompt"),
                "negative_prompt": req.get("negative_prompt"),
                "seed": req.get("seed"),
                "width": req.get("width"),
                "height": req.get("height"),
                "fps": req.get("fps"),
                "steps": req.get("steps"),
                "duration_sec": req.get("duration_sec"),
                "status": "SUCCEEDED",
                "created_at": task.created_at,
                "started_at": current.started_at,
                "finished_at": finished_at,
                "video_path": result["video_path"],
            }
            await asyncio.to_thread(write_json, metadata_path, metadata)

            self.task_manager.mark_succeeded(
                task_id=task_id,
                video_path=result["video_path"],
                first_frame_path=result["first_frame_path"],
                metadata_path=str(Path(metadata_path).resolve()),
                log_path=str(Path(log_path).resolve()),
            )
            task_logger.info("Task succeeded: %s", json.dumps(result, ensure_ascii=False))
        except Exception as exc:
            task_logger.exception("Task failed")
            self.task_manager.mark_failed(task_id, str(exc), log_path=str(Path(log_path).resolve()))
video_worker/app/main.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI

from app.api import router
from app.backends.hunyuan_backend import HunyuanBackend
from app.backends.ltx_backend import LTXBackend
from app.gpu_worker import GPUWorker
from app.model_router import ModelRouter
from app.settings import settings
from app.task_manager import TaskManager
from app.task_store import TaskStore
from app.utils.files import ensure_dir
from app.utils.logger import build_logger


def build_app() -> FastAPI:
    logger = build_logger("video_worker", settings.log_level)

    ensure_dir(settings.output_dir)
    ensure_dir(settings.runtime_dir)
    ensure_dir(settings.runtime_dir / "logs")

    store = TaskStore(settings.sqlite_path)
    store.migrate()

    ltx_backend = LTXBackend(settings.ltx_model_dir)
    hunyuan_backend = HunyuanBackend(settings.hunyuan_model_dir)
    model_router = ModelRouter(ltx_backend, hunyuan_backend)
    task_manager = TaskManager(store=store, output_root=settings.output_dir)
    gpu_worker = GPUWorker(task_manager, model_router, log_level=settings.log_level)

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        logger.info("Starting GPU worker")
        await gpu_worker.start()
        yield
        logger.info("Stopping GPU worker")
        await gpu_worker.stop()

    app = FastAPI(title="Local Video Worker", version="0.1.0", lifespan=lifespan)

    router.task_manager = task_manager
    router.ltx_backend = ltx_backend
    router.hunyuan_backend = hunyuan_backend
    router.torch = torch

    app.include_router(router)
    return app


app = build_app()
video_worker/app/model_router.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from app.backends.hunyuan_backend import HunyuanBackend
from app.backends.ltx_backend import LTXBackend


class ModelRouter:
    def __init__(self, ltx_backend: LTXBackend, hunyuan_backend: HunyuanBackend):
        self._ltx = ltx_backend
        self._hunyuan = hunyuan_backend

    def route(self, quality_mode: str):
        if quality_mode == "preview":
            return self._ltx
        if quality_mode == "refine":
            return self._hunyuan
        raise ValueError(f"Unsupported quality_mode: {quality_mode}")
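Routing is a pure lookup over quality_mode, so it is easy to unit test. The model dirs below are the defaults from settings.py; the snippet itself is illustrative:

from pathlib import Path

from app.backends.hunyuan_backend import HunyuanBackend
from app.backends.ltx_backend import LTXBackend
from app.model_router import ModelRouter

ltx_backend = LTXBackend(Path("./models/ltx"))
hunyuan_backend = HunyuanBackend(Path("./models/hunyuan"))

router = ModelRouter(ltx_backend, hunyuan_backend)
assert router.route("preview") is ltx_backend
assert router.route("refine") is hunyuan_backend
# Any other mode raises ValueError("Unsupported quality_mode: ...")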
video_worker/app/schemas.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from datetime import datetime
from typing import Literal, Optional

from pydantic import BaseModel, Field


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=1000)
    negative_prompt: str = Field(default="", max_length=1000)
    quality_mode: Literal["preview", "refine"]
    duration_sec: int = Field(default=5, ge=1, le=5)
    width: int = Field(default=832, ge=64, le=832)
    height: int = Field(default=480, ge=64, le=480)
    fps: int = Field(default=16, ge=1, le=24)
    steps: int = Field(default=8, ge=1, le=100)
    seed: Optional[int] = Field(default=None, ge=0, le=2**31 - 1)


class TaskStatusResponse(BaseModel):
    task_id: str
    status: Literal["PENDING", "RUNNING", "SUCCEEDED", "FAILED"]
    backend: Optional[str] = None
    model_name: Optional[str] = None
    progress: float = 0.0
    created_at: datetime
    updated_at: datetime


class TaskResultResponse(BaseModel):
    task_id: str
    status: Literal["PENDING", "RUNNING", "SUCCEEDED", "FAILED"]
    video_path: Optional[str] = None
    first_frame_path: Optional[str] = None
    metadata_path: Optional[str] = None
    log_path: Optional[str] = None
    error: Optional[str] = None


class HealthResponse(BaseModel):
    service_status: str
    cuda_available: bool
    gpu_name: Optional[str]
    ltx_loaded: bool
    hunyuan_loaded: bool
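Validation follows directly from the Field bounds above; a sketch of the happy path and a rejected request (example values are illustrative):

from pydantic import ValidationError

from app.schemas import GenerateRequest

req = GenerateRequest(prompt="a foggy harbor at dawn", quality_mode="preview")
print(req.duration_sec, req.fps, req.steps)  # defaults: 5 16 8

try:
    GenerateRequest(prompt="x", quality_mode="preview", fps=60)  # fps is capped at 24
except ValidationError as exc:
    print(exc.error_count(), "validation error")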
video_worker/app/settings.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from pathlib import Path

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    app_host: str = Field(default="0.0.0.0", alias="APP_HOST")
    app_port: int = Field(default=8000, alias="APP_PORT")

    output_dir: Path = Field(default=Path("./outputs"), alias="OUTPUT_DIR")
    runtime_dir: Path = Field(default=Path("./runtime"), alias="RUNTIME_DIR")
    sqlite_path: Path = Field(default=Path("./runtime/tasks.db"), alias="SQLITE_PATH")

    ltx_model_dir: Path = Field(default=Path("./models/ltx"), alias="LTX_MODEL_DIR")
    hunyuan_model_dir: Path = Field(default=Path("./models/hunyuan"), alias="HUNYUAN_MODEL_DIR")

    default_width: int = Field(default=832, alias="DEFAULT_WIDTH")
    default_height: int = Field(default=480, alias="DEFAULT_HEIGHT")
    default_fps: int = Field(default=16, alias="DEFAULT_FPS")
    default_duration: int = Field(default=5, alias="DEFAULT_DURATION")
    default_steps_preview: int = Field(default=8, alias="DEFAULT_STEPS_PREVIEW")
    default_steps_refine: int = Field(default=12, alias="DEFAULT_STEPS_REFINE")

    log_level: str = Field(default="INFO", alias="LOG_LEVEL")

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")


settings = Settings()
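Since the settings read from .env via the aliases above, a sample file looks like this (values illustrative; any omitted key falls back to its default):

APP_HOST=0.0.0.0
APP_PORT=8000
OUTPUT_DIR=./outputs
RUNTIME_DIR=./runtime
SQLITE_PATH=./runtime/tasks.db
LTX_MODEL_DIR=./models/ltx
HUNYUAN_MODEL_DIR=./models/hunyuan
LOG_LEVEL=DEBUG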
video_worker/app/task_manager.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import asyncio
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

from app.schemas import GenerateRequest, TaskResultResponse, TaskStatusResponse
from app.task_store import TaskRecord, TaskStore
from app.utils.files import TASK_LOG_NAME, TASK_METADATA_NAME, ensure_dir, task_output_dir


class TaskManager:
    def __init__(self, store: TaskStore, output_root: Path):
        self.store = store
        self.output_root = output_root
        self.queue: asyncio.Queue[str] = asyncio.Queue()

    async def create_task(self, req: GenerateRequest) -> TaskStatusResponse:
        task_id = uuid4().hex
        output_dir = task_output_dir(self.output_root, task_id)
        ensure_dir(output_dir)
        self.store.create_task(task_id=task_id, request_json=req.model_dump(), output_dir=str(output_dir.resolve()))
        await self.queue.put(task_id)
        return self.get_status(task_id)

    def get_task_record(self, task_id: str) -> TaskRecord:
        task = self.store.get_task(task_id)
        if task is None:
            raise KeyError(f"Task not found: {task_id}")
        return task

    def get_status(self, task_id: str) -> TaskStatusResponse:
        task = self.get_task_record(task_id)
        return TaskStatusResponse(
            task_id=task.task_id,
            status=task.status,
            backend=task.backend,
            model_name=task.model_name,
            progress=task.progress,
            created_at=datetime.fromisoformat(task.created_at),
            updated_at=datetime.fromisoformat(task.updated_at),
        )

    def get_result(self, task_id: str) -> TaskResultResponse:
        task = self.get_task_record(task_id)
        return TaskResultResponse(
            task_id=task.task_id,
            status=task.status,
            video_path=task.video_path,
            first_frame_path=task.first_frame_path,
            metadata_path=task.metadata_path,
            log_path=task.log_path,
            error=task.error_message,
        )

    def mark_running(self, task_id: str, backend: str, model_name: str) -> None:
        self.store.update_task(
            task_id,
            status="RUNNING",
            backend=backend,
            model_name=model_name,
            progress=0.1,
            # Timezone-aware, consistent with task_store.utc_now_iso();
            # datetime.utcnow() is naive and deprecated in Python 3.12+.
            started_at=datetime.now(timezone.utc).isoformat(),
        )

    def mark_progress(self, task_id: str, progress: float) -> None:
        self.store.update_task(task_id, progress=max(0.0, min(1.0, progress)))

    def mark_succeeded(
        self,
        task_id: str,
        video_path: str,
        first_frame_path: str,
        metadata_path: str,
        log_path: str,
    ) -> None:
        self.store.update_task(
            task_id,
            status="SUCCEEDED",
            progress=1.0,
            video_path=video_path,
            first_frame_path=first_frame_path,
            metadata_path=metadata_path,
            log_path=log_path,
            finished_at=datetime.now(timezone.utc).isoformat(),
        )

    def mark_failed(self, task_id: str, error_message: str, log_path: str | None = None) -> None:
        updates = {
            "status": "FAILED",
            "progress": 1.0,
            "error_message": error_message,
            "finished_at": datetime.now(timezone.utc).isoformat(),
        }
        if log_path is not None:
            updates["log_path"] = log_path
        self.store.update_task(task_id, **updates)

    def build_metadata_path(self, task: TaskRecord) -> Path:
        return Path(task.output_dir) / TASK_METADATA_NAME

    def build_log_path(self, task: TaskRecord) -> Path:
        return Path(task.output_dir) / TASK_LOG_NAME
video_worker/app/task_store.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import json
import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterator, Optional


STATUS_PENDING = "PENDING"
STATUS_RUNNING = "RUNNING"
STATUS_SUCCEEDED = "SUCCEEDED"
STATUS_FAILED = "FAILED"

SCHEMA_VERSION = 2


def utc_now_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


@dataclass
class TaskRecord:
    task_id: str
    status: str
    backend: Optional[str]
    model_name: Optional[str]
    request_json: Dict[str, Any]
    output_dir: str
    progress: float
    error_message: Optional[str]
    video_path: Optional[str]
    first_frame_path: Optional[str]
    metadata_path: Optional[str]
    log_path: Optional[str]
    created_at: str
    updated_at: str
    started_at: Optional[str]
    finished_at: Optional[str]


class TaskStore:
    def __init__(self, sqlite_path: Path):
        self.sqlite_path = sqlite_path
        self.sqlite_path.parent.mkdir(parents=True, exist_ok=True)

    @contextmanager
    def conn(self) -> Iterator[sqlite3.Connection]:
        connection = sqlite3.connect(self.sqlite_path, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        try:
            yield connection
            connection.commit()
        finally:
            connection.close()

    def migrate(self) -> None:
        with self.conn() as connection:
            connection.execute(
                """
                CREATE TABLE IF NOT EXISTS schema_migrations (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL
                )
                """
            )
            current = connection.execute("SELECT MAX(version) as v FROM schema_migrations").fetchone()["v"]
            current_version = int(current or 0)

            if current_version < 1:
                connection.execute(
                    """
                    CREATE TABLE IF NOT EXISTS tasks (
                        task_id TEXT PRIMARY KEY,
                        status TEXT NOT NULL,
                        backend TEXT,
                        model_name TEXT,
                        request_json TEXT NOT NULL,
                        output_dir TEXT NOT NULL,
                        progress REAL NOT NULL DEFAULT 0,
                        error_message TEXT,
                        video_path TEXT,
                        first_frame_path TEXT,
                        metadata_path TEXT,
                        log_path TEXT,
                        created_at TEXT NOT NULL,
                        updated_at TEXT NOT NULL,
                        started_at TEXT,
                        finished_at TEXT
                    )
                    """
                )
                connection.execute(
                    "INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
                    (1, utc_now_iso()),
                )

            if current_version < 2:
                connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)")
                connection.execute("CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at)")
                connection.execute(
                    "INSERT INTO schema_migrations(version, applied_at) VALUES (?, ?)",
                    (2, utc_now_iso()),
                )

    def create_task(self, task_id: str, request_json: Dict[str, Any], output_dir: str) -> None:
        now = utc_now_iso()
        with self.conn() as connection:
            connection.execute(
                """
                INSERT INTO tasks (
                    task_id, status, request_json, output_dir, created_at, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (task_id, STATUS_PENDING, json.dumps(request_json, ensure_ascii=False), output_dir, now, now),
            )

    def get_task(self, task_id: str) -> Optional[TaskRecord]:
        with self.conn() as connection:
            row = connection.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
            if row is None:
                return None
            return TaskRecord(
                task_id=row["task_id"],
                status=row["status"],
                backend=row["backend"],
                model_name=row["model_name"],
                request_json=json.loads(row["request_json"]),
                output_dir=row["output_dir"],
                progress=float(row["progress"]),
                error_message=row["error_message"],
                video_path=row["video_path"],
                first_frame_path=row["first_frame_path"],
                metadata_path=row["metadata_path"],
                log_path=row["log_path"],
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                started_at=row["started_at"],
                finished_at=row["finished_at"],
            )

    def update_task(self, task_id: str, **fields: Any) -> None:
        if not fields:
            return
        fields["updated_at"] = utc_now_iso()
        keys = sorted(fields.keys())
        assignments = ", ".join([f"{key} = ?" for key in keys])
        values = [fields[key] for key in keys]
        values.append(task_id)
        with self.conn() as connection:
            connection.execute(f"UPDATE tasks SET {assignments} WHERE task_id = ?", values)

    def list_migrations(self) -> list[dict[str, Any]]:
        with self.conn() as connection:
            rows = connection.execute("SELECT version, applied_at FROM schema_migrations ORDER BY version ASC").fetchall()
            return [{"version": row["version"], "applied_at": row["applied_at"]} for row in rows]
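A usage sketch for the store on its own. migrate() is guarded by CREATE ... IF NOT EXISTS plus the version table, so calling it on every start is safe; the task id and paths here are illustrative:

from pathlib import Path

from app.task_store import TaskStore

store = TaskStore(Path("./runtime/tasks.db"))
store.migrate()  # idempotent: only applies versions newer than MAX(version)

store.create_task("demo", {"prompt": "test", "quality_mode": "preview"}, "/tmp/demo")
print(store.get_task("demo").status)  # PENDING
store.update_task("demo", status="RUNNING", progress=0.1)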
video_worker/app/utils/__init__.py (new file, empty)
video_worker/app/utils/ffmpeg_utils.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import subprocess
from pathlib import Path


def run_cmd(cmd: list[str]) -> None:
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Command failed: {' '.join(cmd)}\nSTDOUT: {proc.stdout}\nSTDERR: {proc.stderr}")


def frames_to_video(frames_pattern: str, fps: int, output_video_path: Path) -> None:
    output_video_path.parent.mkdir(parents=True, exist_ok=True)
    cmd = [
        "ffmpeg",
        "-y",
        "-framerate",
        str(fps),
        "-i",
        frames_pattern,
        "-pix_fmt",
        "yuv420p",
        str(output_video_path),
    ]
    run_cmd(cmd)


def extract_first_frame(video_path: Path, first_frame_path: Path) -> None:
    first_frame_path.parent.mkdir(parents=True, exist_ok=True)
    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        str(video_path),
        "-vf",
        "select=eq(n\\,0)",
        "-vframes",
        "1",
        str(first_frame_path),
    ]
    run_cmd(cmd)
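For reference, frames_to_video("frames/frame_%04d.jpg", 16, Path("video.mp4")) spawns the equivalent of the following command (extract_first_frame works the same way with the select filter):

ffmpeg -y -framerate 16 -i frames/frame_%04d.jpg -pix_fmt yuv420p video.mp4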
video_worker/app/utils/files.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import json
from pathlib import Path
from typing import Any, Dict


TASK_VIDEO_NAME = "video.mp4"
TASK_FIRST_FRAME_NAME = "first_frame.jpg"
TASK_METADATA_NAME = "metadata.json"
TASK_LOG_NAME = "run.log"


def ensure_dir(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path


def task_output_dir(base_output_dir: Path, task_id: str) -> Path:
    return ensure_dir(base_output_dir / task_id)


def write_json(path: Path, data: Dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
video_worker/app/utils/image_utils.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from pathlib import Path

from PIL import Image, ImageDraw, ImageFont


def make_dummy_frame(path: Path, width: int, height: int, text: str, step: int) -> None:
    image = Image.new("RGB", (width, height), color=(25 + (step * 5) % 200, 40, 60))
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    draw.text((16, 16), text, fill=(240, 240, 240), font=font)
    draw.text((16, 38), f"frame={step}", fill=(220, 220, 220), font=font)
    image.save(path, format="JPEG", quality=90)
video_worker/app/utils/logger.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import logging
from pathlib import Path


def build_logger(name: str, log_level: str = "INFO", log_file: Path | None = None) -> logging.Logger:
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))

    if logger.handlers:
        return logger

    formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if log_file is not None:
        log_file.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger