137 lines
4.8 KiB
Python
137 lines
4.8 KiB
Python
import asyncio
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
|
|
|
from app.schemas import GenerateRequest, HealthResponse, TaskResultResponse, TaskStatusResponse
|
|
|
|
|
|
# Router for the task, health and WebSocket endpoints defined in this module.
# The handlers read shared dependencies from attributes on this object
# (`task_manager`, `torch`, `ltx_backend`, `hunyuan_backend`). Nothing visible
# in this module assigns them, so presumably they are attached at application
# startup -- TODO confirm against the app factory.
router = APIRouter()
|
|
|
|
|
|
@router.post("/generate", response_model=TaskStatusResponse)
async def create_generate_task(req: GenerateRequest):
    """Hand the generation request to the task manager and return its reply.

    The task manager is read from ``router.task_manager``; the awaited result
    of its ``create_task(req)`` call is returned unchanged.
    """
    created = await router.task_manager.create_task(req)
    return created
|
|
|
|
|
|
@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
def get_task_status(task_id: str):
    """Look up the status of ``task_id`` via the task manager.

    Raises:
        HTTPException: 404 when ``get_status`` raises ``KeyError``; the
            exception text is used as the response detail.
    """
    manager = router.task_manager
    try:
        current = manager.get_status(task_id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail=str(missing)) from missing
    return current
|
|
|
|
|
|
@router.get("/tasks/{task_id}/result", response_model=TaskResultResponse)
def get_task_result(task_id: str):
    """Look up the result of ``task_id`` via the task manager.

    Raises:
        HTTPException: 404 when ``get_result`` raises ``KeyError``; the
            exception text is used as the response detail.
    """
    manager = router.task_manager
    try:
        outcome = manager.get_result(task_id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail=str(missing)) from missing
    return outcome
|
|
|
|
|
|
@router.get("/health", response_model=HealthResponse)
def health_check():
    """Report service health: CUDA availability, GPU name and backend load state.

    The torch module and both backends are read from attributes on ``router``.
    The GPU name is queried for device 0 only when CUDA is available;
    otherwise it is ``None``.
    """
    torch_mod = router.torch
    ltx = router.ltx_backend
    hunyuan = router.hunyuan_backend

    has_cuda = bool(torch_mod.cuda.is_available())
    if has_cuda:
        device_name = torch_mod.cuda.get_device_name(0)
    else:
        device_name = None

    report = {
        "service_status": "ok",
        "cuda_available": has_cuda,
        "gpu_name": device_name,
        "ltx_loaded": ltx.is_loaded(),
        "hunyuan_loaded": hunyuan.is_loaded(),
    }
    return HealthResponse(**report)
|
|
|
|
|
|
def _ws_status_event(event: str, status: Any) -> dict[str, Any]:
    """Build the JSON payload shared by the ``accepted`` and ``status`` events."""
    return {
        "event": event,
        "task_id": status.task_id,
        "status": status.status,
        "backend": status.backend,
        "model_name": status.model_name,
        "progress": status.progress,
        "created_at": status.created_at.isoformat(),
        "updated_at": status.updated_at.isoformat(),
    }


async def _ws_send_error(websocket: WebSocket, error: str, code: int) -> None:
    """Best-effort: send an ``error`` event, then close the socket with ``code``."""
    try:
        await websocket.send_json({"event": "error", "error": error})
        await websocket.close(code=code)
    except (RuntimeError, WebSocketDisconnect):
        # The socket was already closed or the peer is gone, so there is
        # nobody left to notify. Swallow the failure rather than let it
        # escape from an error handler.
        pass


@router.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket):
    """Create (or watch) a generation task and stream its progress over a WebSocket.

    Protocol, as implemented here:

    * The client sends one JSON object.
      - ``{"action": "watch", "task_id": ...}`` follows an existing task.
      - Anything else is treated as a generation request. The request body is
        ``message["payload"]`` when present, otherwise the whole message. It is
        validated with ``GenerateRequest.model_validate`` and submitted via
        ``task_manager.create_task``; an ``accepted`` event is sent back.
    * The task is then polled once per second. A ``status`` event is sent each
      time the status value changes.
    * When the status is ``SUCCEEDED`` or ``FAILED`` a ``result`` event is sent
      and the socket is closed with code 1000.

    Errors are reported as ``{"event": "error", "error": ...}`` followed by a
    close: 1008 for client mistakes (non-object message, missing ``task_id``,
    invalid request, ``KeyError`` from the task manager) and 1011 for anything
    unexpected. Error reporting is best-effort and never raises out of the
    handler.
    """
    await websocket.accept()
    task_manager = router.task_manager

    try:
        message: Any = await websocket.receive_json()
        if not isinstance(message, dict):
            # Valid JSON that is not an object (list, string, number, ...) is a
            # client error, not a server fault.
            await _ws_send_error(websocket, "message must be a JSON object", 1008)
            return

        action = message.get("action")

        if action == "watch":
            task_id = message.get("task_id")
            if not task_id:
                await _ws_send_error(websocket, "task_id is required for watch", 1008)
                return
        else:
            payload = message.get("payload", message)
            try:
                req = GenerateRequest.model_validate(payload)
            except Exception as exc:
                # Kept broad on purpose: any validation failure is reported to
                # the client as an invalid request.
                await _ws_send_error(websocket, f"invalid request: {exc}", 1008)
                return

            created = await task_manager.create_task(req)
            task_id = created.task_id
            await websocket.send_json(_ws_status_event("accepted", created))

        last_status = None
        while True:
            status = task_manager.get_status(task_id)
            if status.status != last_status:
                await websocket.send_json(_ws_status_event("status", status))
                last_status = status.status

            if status.status in {"SUCCEEDED", "FAILED"}:
                result = task_manager.get_result(task_id)
                await websocket.send_json(
                    {
                        "event": "result",
                        "task_id": result.task_id,
                        "status": result.status,
                        "video_path": result.video_path,
                        "first_frame_path": result.first_frame_path,
                        "metadata_path": result.metadata_path,
                        "log_path": result.log_path,
                        "error": result.error,
                    }
                )
                await websocket.close(code=1000)
                break

            await asyncio.sleep(1)
    except WebSocketDisconnect:
        # Client went away; nothing to report.
        return
    except KeyError as exc:
        await _ws_send_error(websocket, str(exc), 1008)
    except Exception as exc:
        await _ws_send_error(websocket, f"unexpected error: {exc}", 1011)
|