fix:优化服务
This commit is contained in:
87
video_worker/app/ws_service.py
Normal file
87
video_worker/app/ws_service.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
|
||||
from app.schemas import GenerateRequest
|
||||
from app.settings import settings
|
||||
|
||||
app = FastAPI(title="Video Worker WS Gateway", version="0.1.0")
|
||||
|
||||
|
||||
def _url(path: str) -> str:
|
||||
return f"{settings.worker_base_url.rstrip('/')}{path}"
|
||||
|
||||
|
||||
def _http_get(path: str) -> dict[str, Any]:
|
||||
resp = requests.get(_url(path), timeout=20)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _http_post(path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
resp = requests.post(_url(path), json=payload, timeout=30)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict[str, Any]:
|
||||
gateway_status = "ok"
|
||||
upstream_ok = False
|
||||
upstream_error = None
|
||||
try:
|
||||
_http_get("/health")
|
||||
upstream_ok = True
|
||||
except Exception as exc:
|
||||
upstream_error = str(exc)
|
||||
return {
|
||||
"service": "ws_gateway",
|
||||
"status": gateway_status,
|
||||
"worker_base_url": settings.worker_base_url,
|
||||
"upstream_ok": upstream_ok,
|
||||
"upstream_error": upstream_error,
|
||||
}
|
||||
|
||||
|
||||
@app.websocket("/ws/generate")
|
||||
async def ws_generate(websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
message = await websocket.receive_json()
|
||||
action = message.get("action")
|
||||
|
||||
if action == "watch":
|
||||
task_id = message.get("task_id")
|
||||
if not task_id:
|
||||
await websocket.send_json({"event": "error", "error": "task_id is required for watch"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
else:
|
||||
payload = message.get("payload", message)
|
||||
req = GenerateRequest.model_validate(payload)
|
||||
created = await asyncio.to_thread(_http_post, "/generate", req.model_dump())
|
||||
task_id = created["task_id"]
|
||||
await websocket.send_json({"event": "accepted", **created})
|
||||
|
||||
last_status = None
|
||||
while True:
|
||||
status = await asyncio.to_thread(_http_get, f"/tasks/{task_id}")
|
||||
if status.get("status") != last_status:
|
||||
await websocket.send_json({"event": "status", **status})
|
||||
last_status = status.get("status")
|
||||
|
||||
if status.get("status") in {"SUCCEEDED", "FAILED"}:
|
||||
result = await asyncio.to_thread(_http_get, f"/tasks/{task_id}/result")
|
||||
await websocket.send_json({"event": "result", **result})
|
||||
await websocket.close(code=1000)
|
||||
return
|
||||
|
||||
await asyncio.sleep(settings.ws_poll_interval_sec)
|
||||
except WebSocketDisconnect:
|
||||
return
|
||||
except Exception as exc:
|
||||
await websocket.send_json({"event": "error", "error": str(exc)})
|
||||
await websocket.close(code=1011)
|
||||
Reference in New Issue
Block a user