"""WebSocket gateway that relays generation tasks to the video worker HTTP API."""

import asyncio
import logging
from typing import Any

import requests
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

from app.schemas import GenerateRequest
from app.settings import settings

logger = logging.getLogger(__name__)

app = FastAPI(title="Video Worker WS Gateway", version="0.1.0")

# Task statuses at which polling stops and the task result is fetched.
_TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})


def _url(path: str) -> str:
    """Return *path* appended to the configured worker base URL.

    Trailing slashes on ``settings.worker_base_url`` are stripped first, so
    *path* is expected to start with ``/``.
    """
    return f"{settings.worker_base_url.rstrip('/')}{path}"


def _http_get(path: str) -> dict[str, Any]:
    """Blocking GET against the worker; return the decoded JSON body.

    Raises:
        requests.HTTPError: if the worker answers with an error status.
    """
    resp = requests.get(_url(path), timeout=20)
    resp.raise_for_status()
    return resp.json()


def _http_post(path: str, payload: dict[str, Any]) -> dict[str, Any]:
    """Blocking POST of *payload* as JSON to the worker; return the JSON body.

    Raises:
        requests.HTTPError: if the worker answers with an error status.
    """
    resp = requests.post(_url(path), json=payload, timeout=30)
    resp.raise_for_status()
    return resp.json()


@app.get("/health")
def health() -> dict[str, Any]:
    """Report gateway health together with the outcome of an upstream probe.

    The gateway's own ``status`` is always ``"ok"``; a failing worker
    ``/health`` call is reported through ``upstream_ok`` / ``upstream_error``
    rather than raised.
    """
    upstream_ok = False
    upstream_error = None
    try:
        _http_get("/health")
        upstream_ok = True
    except Exception as exc:
        upstream_error = str(exc)

    return {
        "service": "ws_gateway",
        "status": "ok",
        "worker_base_url": settings.worker_base_url,
        "upstream_ok": upstream_ok,
        "upstream_error": upstream_error,
    }


async def _send_error_and_close(websocket: WebSocket, error: str, code: int) -> None:
    """Best-effort: send an ``error`` event, then close with *code*.

    The peer may already be gone or the socket already closed, in which case
    sending raises again. That secondary failure is logged and dropped so it
    cannot escape the endpoint as an unhandled exception.
    """
    try:
        await websocket.send_json({"event": "error", "error": error})
        await websocket.close(code=code)
    except Exception:
        logger.debug("Could not deliver error to websocket client", exc_info=True)


@app.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket) -> None:
    """Create (or attach to) a worker task and stream its progress.

    The first client message must be a JSON object:

    * ``{"action": "watch", "task_id": ...}`` attaches to an existing task.
    * Anything else is treated as a generate request; the ``payload`` key is
      used when present, otherwise the whole message. It is validated as a
      ``GenerateRequest`` and POSTed to the worker's ``/generate``; the reply
      is forwarded as an ``accepted`` event.

    The worker's ``/tasks/{task_id}`` is then polled every
    ``settings.ws_poll_interval_sec`` seconds. A ``status`` event is sent
    whenever the status value changes; once it is terminal the task result is
    sent as a ``result`` event and the socket is closed with code 1000.

    Client errors close with 1008; any other failure is logged and reported
    to the client (best-effort) with close code 1011.
    """
    await websocket.accept()
    try:
        message = await websocket.receive_json()
        if not isinstance(message, dict):
            # Valid JSON that is not an object (e.g. a list) is a client error.
            await _send_error_and_close(
                websocket, "message must be a JSON object", 1008
            )
            return

        if message.get("action") == "watch":
            task_id = message.get("task_id")
            if not task_id:
                await _send_error_and_close(
                    websocket, "task_id is required for watch", 1008
                )
                return
        else:
            payload = message.get("payload", message)
            req = GenerateRequest.model_validate(payload)
            # requests is blocking; run it in a thread to keep the loop free.
            created = await asyncio.to_thread(
                _http_post, "/generate", req.model_dump()
            )
            task_id = created["task_id"]
            await websocket.send_json({"event": "accepted", **created})

        last_status = None
        while True:
            status = await asyncio.to_thread(_http_get, f"/tasks/{task_id}")
            current = status.get("status")

            # Only notify the client on transitions, not on every poll.
            if current != last_status:
                await websocket.send_json({"event": "status", **status})
                last_status = current

            if current in _TERMINAL_STATUSES:
                result = await asyncio.to_thread(
                    _http_get, f"/tasks/{task_id}/result"
                )
                await websocket.send_json({"event": "result", **result})
                await websocket.close(code=1000)
                return

            await asyncio.sleep(settings.ws_poll_interval_sec)
    except WebSocketDisconnect:
        # Client went away; nothing left to report to.
        return
    except Exception as exc:
        logger.exception("ws_generate failed")
        await _send_error_and_close(websocket, str(exc), 1011)