Files
AI_A4000/video_worker/app/ws_service.py
2026-04-07 01:00:56 +08:00

88 lines
2.8 KiB
Python

import asyncio
import logging
from typing import Any
from urllib.parse import quote

import requests
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

from app.schemas import GenerateRequest
from app.settings import settings
# ASGI application for the gateway; the HTTP /health route and the
# /ws/generate WebSocket route in this module are registered on it.
app = FastAPI(version="0.1.0", title="Video Worker WS Gateway")
def _url(path: str) -> str:
    """Join *path* onto the configured worker base URL.

    Trailing slashes on ``settings.worker_base_url`` are stripped first and
    *path* is appended verbatim, so callers pass it with a leading ``/``.
    """
    base = settings.worker_base_url.rstrip("/")
    return f"{base}{path}"
def _http_get(path: str) -> dict[str, Any]:
    """GET *path* on the upstream worker and return the decoded JSON body.

    This is a blocking ``requests`` call (``timeout=20``); the async endpoint
    runs it via ``asyncio.to_thread``. A 4xx/5xx reply raises
    ``requests.HTTPError`` through ``raise_for_status``.
    """
    target = _url(path)
    reply = requests.get(target, timeout=20)
    reply.raise_for_status()
    return reply.json()
def _http_post(path: str, payload: dict[str, Any]) -> dict[str, Any]:
    """POST *payload* as JSON to *path* on the upstream worker.

    Returns the decoded JSON reply. This is a blocking ``requests`` call
    (``timeout=30``); a 4xx/5xx reply raises ``requests.HTTPError`` through
    ``raise_for_status``.
    """
    endpoint = _url(path)
    reply = requests.post(endpoint, json=payload, timeout=30)
    reply.raise_for_status()
    return reply.json()
@app.get("/health")
def health() -> dict[str, Any]:
    """Report gateway liveness together with upstream worker reachability.

    The gateway itself always reports ``"ok"``. The outcome of a GET on the
    worker's ``/health`` is reported separately in ``upstream_ok`` and
    ``upstream_error`` (the exception's string form), so a worker outage never
    makes this endpoint fail.
    """
    try:
        _http_get("/health")
    except Exception as exc:
        # Deliberately broad: any probe failure is reported, never raised.
        reachable, problem = False, str(exc)
    else:
        reachable, problem = True, None
    return {
        "service": "ws_gateway",
        "status": "ok",
        "worker_base_url": settings.worker_base_url,
        "upstream_ok": reachable,
        "upstream_error": problem,
    }
# Errors in the WebSocket handler were previously only reported to the client;
# log them server-side as well.
logger = logging.getLogger(__name__)

# Upstream task states after which polling stops and the final result is fetched.
# NOTE(review): any other terminal state the worker may report (e.g. a cancelled
# state) must be added here, otherwise the watch loop runs until the client leaves.
_TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})


async def _wait_for_client_disconnect(websocket: WebSocket) -> None:
    """Return once the client disconnects, discarding any frames sent meanwhile.

    The protocol is one client message followed by server pushes, so nothing
    else reads the socket while a task is polled. Without this reader the
    polling loop would never notice that the client has gone away.
    """
    while True:
        frame = await websocket.receive()
        if frame["type"] == "websocket.disconnect":
            return


async def _send_error_and_close(websocket: WebSocket, error: str, code: int) -> None:
    """Best-effort: send an ``error`` event, then close the socket with *code*.

    The connection may already be gone (often the very reason an error is being
    reported), so a failure here is logged rather than escaping the endpoint.
    """
    try:
        await websocket.send_json({"event": "error", "error": error})
        await websocket.close(code=code)
    except Exception:
        logger.debug("could not report error to websocket client", exc_info=True)


def _discard_task(task: asyncio.Task) -> None:
    """Cancel *task* without awaiting it, marking any exception as retrieved."""
    task.cancel()
    # Retrieve the exception (if any) so asyncio does not log
    # "Task exception was never retrieved".
    task.add_done_callback(lambda t: t.cancelled() or t.exception())


@app.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket) -> None:
    """Create or watch a video-generation task and stream its progress.

    The client sends exactly one JSON object:

    * ``{"action": "watch", "task_id": ...}``: follow an existing task.
    * anything else: the ``payload`` field (or, if absent, the whole message)
      is validated as a ``GenerateRequest`` and POSTed to the worker's
      ``/generate``. An ``accepted`` event carrying the worker's reply is sent.

    The worker's ``/tasks/{task_id}`` is then polled every
    ``settings.ws_poll_interval_sec`` seconds. A ``status`` event is sent
    whenever the status value changes. Once the status is terminal, the
    ``/result`` payload is sent as a ``result`` event and the socket is closed
    with code 1000.

    Malformed client input closes with 1008; upstream or internal failures send
    an ``error`` event and close with 1011. Polling stops as soon as the client
    disconnects.
    """
    await websocket.accept()
    watcher = None
    try:
        try:
            message = await websocket.receive_json()
        except ValueError as exc:  # json.JSONDecodeError: frame is not valid JSON
            await _send_error_and_close(websocket, str(exc), 1008)
            return
        if not isinstance(message, dict):
            await _send_error_and_close(websocket, "message must be a JSON object", 1008)
            return

        if message.get("action") == "watch":
            raw_task_id = message.get("task_id")
            if not raw_task_id:
                await _send_error_and_close(websocket, "task_id is required for watch", 1008)
                return
        else:
            # NOTE(review): any action other than "watch" (missing or misspelled
            # included) starts a new generation job.
            payload = message.get("payload", message)
            try:
                req = GenerateRequest.model_validate(payload)
            except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
                await _send_error_and_close(websocket, str(exc), 1008)
                return
            created = await asyncio.to_thread(_http_post, "/generate", req.model_dump())
            raw_task_id = created["task_id"]
            await websocket.send_json({"event": "accepted", **created})

        # Percent-encode the id (including "/") so a client-supplied value cannot
        # inject extra path segments, a query string or a fragment upstream.
        task_path = f"/tasks/{quote(str(raw_task_id), safe='')}"

        watcher = asyncio.create_task(_wait_for_client_disconnect(websocket))
        last_status = None
        while True:
            status = await asyncio.to_thread(_http_get, task_path)
            current = status.get("status")
            if current != last_status:
                await websocket.send_json({"event": "status", **status})
                last_status = current
            if current in _TERMINAL_STATUSES:
                result = await asyncio.to_thread(_http_get, f"{task_path}/result")
                await websocket.send_json({"event": "result", **result})
                await websocket.close(code=1000)
                return
            # Wait one poll interval, waking early if the client disconnects.
            done, _ = await asyncio.wait({watcher}, timeout=settings.ws_poll_interval_sec)
            if done:
                return
    except WebSocketDisconnect:
        return
    except Exception as exc:
        logger.exception("ws_generate failed")
        await _send_error_and_close(websocket, str(exc), 1011)
    finally:
        if watcher is not None:
            _discard_task(watcher)