88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
import asyncio
|
|
from typing import Any
|
|
|
|
import requests
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
|
|
from app.schemas import GenerateRequest
|
|
from app.settings import settings
|
|
|
|
# ASGI application that serves the gateway's HTTP and WebSocket routes.
app = FastAPI(
    version="0.1.0",
    title="Video Worker WS Gateway",
)
|
|
|
|
|
|
def _url(path: str) -> str:
    """Return the absolute worker URL for ``path``.

    Trailing slashes are stripped from ``settings.worker_base_url`` so that a
    ``path`` starting with ``/`` produces exactly one separator. ``path``
    itself is appended verbatim.
    """
    base = settings.worker_base_url.rstrip("/")
    return f"{base}{path}"
|
|
|
|
|
|
def _http_get(path: str) -> dict[str, Any]:
    """GET ``path`` from the worker and return the decoded JSON body.

    Uses a 20-second timeout. A 4xx/5xx response raises
    ``requests.HTTPError`` via ``raise_for_status``. Connection failures,
    timeouts and non-JSON bodies propagate as raised by ``requests``.
    """
    reply = requests.get(_url(path), timeout=20)
    reply.raise_for_status()  # surface HTTP error statuses as exceptions
    body: dict[str, Any] = reply.json()
    return body
|
|
|
|
|
|
def _http_post(path: str, payload: dict[str, Any]) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``path`` on the worker and return the JSON reply.

    Uses a 30-second timeout. A 4xx/5xx response raises
    ``requests.HTTPError`` via ``raise_for_status``. Connection failures,
    timeouts and non-JSON bodies propagate as raised by ``requests``.
    """
    reply = requests.post(_url(path), json=payload, timeout=30)
    reply.raise_for_status()  # surface HTTP error statuses as exceptions
    body: dict[str, Any] = reply.json()
    return body
|
|
|
|
|
|
@app.get("/health")
def health() -> dict[str, Any]:
    """Report gateway liveness and whether the worker's ``/health`` is reachable.

    The gateway's own ``status`` is always ``"ok"``: if this handler runs,
    the gateway is up. Any exception raised while probing the worker is
    captured as text in ``upstream_error`` instead of failing this endpoint.
    ``upstream_ok`` is true exactly when no exception was raised.
    """
    upstream_error = None
    try:
        _http_get("/health")
    except Exception as exc:  # any probe failure counts as "upstream down"
        upstream_error = str(exc)
    upstream_ok = upstream_error is None

    return {
        "service": "ws_gateway",
        "status": "ok",
        "worker_base_url": settings.worker_base_url,
        "upstream_ok": upstream_ok,
        "upstream_error": upstream_error,
    }
|
|
|
|
|
|
async def _close_with_error(websocket: WebSocket, error: str, code: int) -> None:
    """Best-effort: send an ``error`` event to the client, then close with ``code``.

    Send/close failures are swallowed on purpose. They mean the connection is
    already unusable (client gone, or the socket was closed earlier), so there
    is nobody left to report to. Re-raising here would only turn an error we
    were handling into an unhandled one in the ASGI server.
    """
    try:
        await websocket.send_json({"event": "error", "error": error})
        await websocket.close(code=code)
    except (RuntimeError, OSError, WebSocketDisconnect):
        pass


def _task_path(task_id: Any, suffix: str = "") -> str:
    """Return the worker path ``/tasks/<task_id><suffix>`` with the id percent-encoded.

    ``task_id`` can come straight from the client (``watch`` action). It is
    quoted with ``safe=""`` so characters such as ``/`` or ``?`` cannot add
    path segments or a query string to the upstream URL.
    """
    from urllib.parse import quote

    segment = quote(str(task_id), safe="")
    if segment in {".", ".."}:
        # Bare dot segments could be resolved as relative path steps; encode them.
        segment = segment.replace(".", "%2E")
    return f"/tasks/{segment}{suffix}"


@app.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket) -> None:
    """Create or watch a worker task and stream its progress over a WebSocket.

    The client sends one JSON object:

    * ``{"action": "watch", "task_id": ...}`` follows an existing task.
    * Anything else creates a task. ``message["payload"]`` is validated as a
      ``GenerateRequest``; if that key is absent, the whole message is
      validated. The result is POSTed to the worker's ``/generate``, and the
      worker's reply is sent as an ``accepted`` event.

    The task is then polled every ``settings.ws_poll_interval_sec`` seconds.
    A ``status`` event is sent whenever the status value changes. On
    ``SUCCEEDED`` or ``FAILED``, the task result is sent as a ``result`` event
    and the socket is closed with code 1000.

    Error handling:

    * Client errors (non-object/invalid JSON, missing ``task_id``, invalid
      payload) produce an ``error`` event and close code 1008.
    * Any other failure (e.g. the worker is unreachable or returns an error
      status) produces an ``error`` event and close code 1011.
    * A client disconnect ends the handler silently.
    """
    await websocket.accept()

    try:
        try:
            message = await websocket.receive_json()
        except ValueError:
            # The frame was not valid JSON: a client error, not an internal one.
            await _close_with_error(websocket, "message must be a JSON object", 1008)
            return
        if not isinstance(message, dict):
            # e.g. a JSON list or string; the protocol requires an object.
            await _close_with_error(websocket, "message must be a JSON object", 1008)
            return

        if message.get("action") == "watch":
            task_id = message.get("task_id")
            if not task_id:
                await _close_with_error(websocket, "task_id is required for watch", 1008)
                return
        else:
            # Any action other than "watch" (including none) creates a new task.
            payload = message.get("payload", message)
            try:
                req = GenerateRequest.model_validate(payload)
            except ValueError as exc:
                # pydantic's ValidationError subclasses ValueError: bad client input.
                await _close_with_error(websocket, str(exc), 1008)
                return
            # mode="json" yields only JSON-native values (datetime -> str, etc.),
            # which requests' ``json=`` encoder requires.
            created = await asyncio.to_thread(
                _http_post, "/generate", req.model_dump(mode="json")
            )
            task_id = created["task_id"]
            await websocket.send_json({"event": "accepted", **created})

        # NOTE(review): there is no overall deadline; polling continues until the
        # worker reports a terminal status. The socket is not read during polling
        # and is written only on status changes, so a client disconnect may go
        # unnoticed until the next send.
        last_status = None
        while True:
            status = await asyncio.to_thread(_http_get, _task_path(task_id))
            current = status.get("status")
            if current != last_status:
                await websocket.send_json({"event": "status", **status})
                last_status = current

            if current in {"SUCCEEDED", "FAILED"}:
                result = await asyncio.to_thread(
                    _http_get, _task_path(task_id, "/result")
                )
                await websocket.send_json({"event": "result", **result})
                await websocket.close(code=1000)
                return

            await asyncio.sleep(settings.ws_poll_interval_sec)

    except WebSocketDisconnect:
        return
    except Exception as exc:
        # Reporting is best-effort: the failure may itself be a dead socket.
        await _close_with_error(websocket, str(exc), 1011)
|