"""HTTP and WebSocket routes for submitting and tracking generation tasks.

The routes read their collaborators (``task_manager``, ``torch``,
``ltx_backend``, ``hunyuan_backend``) from attributes on ``router``;
presumably the application attaches them at startup — TODO confirm.
"""

import asyncio
import contextlib
from typing import Any

from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect

from app.schemas import (
    GenerateRequest,
    HealthResponse,
    TaskResultResponse,
    TaskStatusResponse,
)

router = APIRouter()

# Task statuses after which the WebSocket watcher sends the final result
# and stops polling.
_TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})

# Seconds to wait between task-status polls in the WebSocket handler.
_POLL_INTERVAL_S = 1


@router.post("/generate", response_model=TaskStatusResponse)
async def create_generate_task(req: GenerateRequest):
    """Create a generation task from ``req`` and return its initial status."""
    task_manager = router.task_manager
    return await task_manager.create_task(req)


@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
def get_task_status(task_id: str):
    """Return the current status of task ``task_id``.

    Raises:
        HTTPException: 404 when the task manager raises ``KeyError``.
    """
    task_manager = router.task_manager
    try:
        return task_manager.get_status(task_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc


@router.get("/tasks/{task_id}/result", response_model=TaskResultResponse)
def get_task_result(task_id: str):
    """Return the result record of task ``task_id``.

    Raises:
        HTTPException: 404 when the task manager raises ``KeyError``.
    """
    task_manager = router.task_manager
    try:
        return task_manager.get_result(task_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc


@router.get("/health", response_model=HealthResponse)
def health_check():
    """Report service status, CUDA availability/GPU name, and backend load state."""
    torch = router.torch
    ltx_backend = router.ltx_backend
    hunyuan_backend = router.hunyuan_backend
    cuda_ok = bool(torch.cuda.is_available())
    # Only ask for a device name when CUDA reports itself available.
    gpu_name = torch.cuda.get_device_name(0) if cuda_ok else None
    return HealthResponse(
        service_status="ok",
        cuda_available=cuda_ok,
        gpu_name=gpu_name,
        ltx_loaded=ltx_backend.is_loaded(),
        hunyuan_loaded=hunyuan_backend.is_loaded(),
    )


def _status_event(event: str, status: Any) -> dict[str, Any]:
    """Build a status-style WebSocket event from a task-status object.

    ``status`` must expose ``task_id``, ``status``, ``backend``,
    ``model_name``, ``progress`` and datetime-like ``created_at`` /
    ``updated_at`` attributes.
    """
    return {
        "event": event,
        "task_id": status.task_id,
        "status": status.status,
        "backend": status.backend,
        "model_name": status.model_name,
        "progress": status.progress,
        "created_at": status.created_at.isoformat(),
        "updated_at": status.updated_at.isoformat(),
    }


def _result_event(result: Any) -> dict[str, Any]:
    """Build the final ``result`` WebSocket event from a task-result object."""
    return {
        "event": "result",
        "task_id": result.task_id,
        "status": result.status,
        "video_path": result.video_path,
        "first_frame_path": result.first_frame_path,
        "metadata_path": result.metadata_path,
        "log_path": result.log_path,
        "error": result.error,
    }


async def _send_error_and_close(websocket: WebSocket, error: str, code: int) -> None:
    """Best-effort: send an ``error`` event, then close with ``code``.

    The socket may already be unusable (the failure being reported can itself
    be a failed send to a departed client), so exceptions raised while
    reporting are suppressed instead of escaping the handler.
    """
    with contextlib.suppress(Exception):
        await websocket.send_json({"event": "error", "error": error})
    with contextlib.suppress(Exception):
        await websocket.close(code=code)


@router.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket):
    """Create or watch a generation task and stream its progress.

    The first client message must be a JSON object:

    * ``{"action": "watch", "task_id": ...}`` watches an existing task.
    * Anything else is treated as a generation request: ``message["payload"]``
      (or the whole message when there is no ``payload`` key) is validated as
      a ``GenerateRequest``, a task is created, and an ``accepted`` event is
      sent.

    The handler then polls the task manager every ``_POLL_INTERVAL_S``
    seconds. It sends a ``status`` event whenever the task's status or
    progress changes. Once the status is terminal it sends a ``result``
    event and closes with code 1000.

    Errors are sent as ``{"event": "error", "error": ...}`` followed by a
    close: 1008 for a malformed request or a ``KeyError`` from the task
    manager, 1011 for any other exception. A client disconnect ends the
    handler silently.
    """
    await websocket.accept()
    task_manager = router.task_manager
    try:
        try:
            message: Any = await websocket.receive_json()
        except ValueError as exc:
            # Undecodable JSON (json.JSONDecodeError is a ValueError) is a
            # client error, not a server fault.
            await _send_error_and_close(websocket, f"invalid request: {exc}", 1008)
            return

        if not isinstance(message, dict):
            # Without this check a JSON list/string would fail on .get() and
            # be misreported as an unexpected server error (1011).
            await _send_error_and_close(
                websocket, "invalid request: expected a JSON object", 1008
            )
            return

        if message.get("action") == "watch":
            task_id = message.get("task_id")
            if not task_id:
                await _send_error_and_close(
                    websocket, "task_id is required for watch", 1008
                )
                return
        else:
            payload = message.get("payload", message)
            try:
                req = GenerateRequest.model_validate(payload)
            except Exception as exc:  # any validation failure is a bad request
                await _send_error_and_close(websocket, f"invalid request: {exc}", 1008)
                return
            created = await task_manager.create_task(req)
            task_id = created.task_id
            await websocket.send_json(_status_event("accepted", created))

        # Track status *and* progress so intermediate progress updates reach
        # the client, not only status transitions.
        last_seen = None
        while True:
            status = task_manager.get_status(task_id)
            snapshot = (status.status, status.progress)
            if snapshot != last_seen:
                await websocket.send_json(_status_event("status", status))
                last_seen = snapshot
            if status.status in _TERMINAL_STATUSES:
                result = task_manager.get_result(task_id)
                await websocket.send_json(_result_event(result))
                await websocket.close(code=1000)
                return
            await asyncio.sleep(_POLL_INTERVAL_S)
    except WebSocketDisconnect:
        return
    except KeyError as exc:
        await _send_error_and_close(websocket, str(exc), 1008)
    except Exception as exc:
        await _send_error_and_close(websocket, f"unexpected error: {exc}", 1011)