fix:优化服务
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from app.schemas import GenerateRequest, HealthResponse, TaskResultResponse, TaskStatusResponse
|
||||
|
||||
@@ -46,3 +49,88 @@ def health_check():
|
||||
ltx_loaded=ltx_backend.is_loaded(),
|
||||
hunyuan_loaded=hunyuan_backend.is_loaded(),
|
||||
)
|
||||
|
||||
|
||||
def _task_snapshot(event: str, task: Any) -> dict[str, Any]:
    """Build the JSON payload shared by the "accepted" and "status" events.

    Args:
        event: Value for the payload's "event" key.
        task: Object exposing task_id, status, backend, model_name, progress,
            created_at and updated_at (the latter two must support isoformat()).
    """
    return {
        "event": event,
        "task_id": task.task_id,
        "status": task.status,
        "backend": task.backend,
        "model_name": task.model_name,
        "progress": task.progress,
        "created_at": task.created_at.isoformat(),
        "updated_at": task.updated_at.isoformat(),
    }


async def _fail_and_close(websocket: WebSocket, error: str, code: int) -> None:
    """Send an "error" event and close the socket, tolerating a dead peer."""
    try:
        await websocket.send_json({"event": "error", "error": error})
        await websocket.close(code=code)
    except (WebSocketDisconnect, RuntimeError):
        # The peer is already gone or the socket is already closed; there is
        # nobody left to notify, and raising here would mask the original error.
        pass


@router.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket):
    """Create (or watch) a generation task and stream its progress over a WebSocket.

    The first client message selects the mode:
      * {"action": "watch", "task_id": ...} follows an existing task.
      * anything else is validated as a GenerateRequest (taken from the
        "payload" key when present, otherwise from the whole message), a task
        is created and an "accepted" event is sent.

    Afterwards the task is polled once per second; a "status" event is sent
    whenever the status value changes, and once it reaches SUCCEEDED or FAILED
    a final "result" event is sent and the socket is closed with code 1000.

    Errors are reported as an "error" event followed by a close with code 1008
    (bad message / KeyError from the task manager) or 1011 (anything else).
    """
    await websocket.accept()
    task_manager = router.task_manager

    try:
        message: dict[str, Any] = await websocket.receive_json()
        if not isinstance(message, dict):
            # Valid JSON that is not an object (list, string, ...) has no .get().
            await _fail_and_close(websocket, "message must be a JSON object", 1008)
            return

        if message.get("action") == "watch":
            task_id = message.get("task_id")
            if not task_id:
                await _fail_and_close(websocket, "task_id is required for watch", 1008)
                return
        else:
            payload = message.get("payload", message)
            try:
                req = GenerateRequest.model_validate(payload)
            except Exception as exc:
                await _fail_and_close(websocket, f"invalid request: {exc}", 1008)
                return

            created = await task_manager.create_task(req)
            task_id = created.task_id
            await websocket.send_json(_task_snapshot("accepted", created))

        last_status = None
        while True:
            status = task_manager.get_status(task_id)
            # Only status transitions are pushed; progress-only changes are not.
            if status.status != last_status:
                await websocket.send_json(_task_snapshot("status", status))
                last_status = status.status

            if status.status in {"SUCCEEDED", "FAILED"}:
                result = task_manager.get_result(task_id)
                await websocket.send_json(
                    {
                        "event": "result",
                        "task_id": result.task_id,
                        "status": result.status,
                        "video_path": result.video_path,
                        "first_frame_path": result.first_frame_path,
                        "metadata_path": result.metadata_path,
                        "log_path": result.log_path,
                        "error": result.error,
                    }
                )
                await websocket.close(code=1000)
                break

            await asyncio.sleep(1)
    except WebSocketDisconnect:
        return
    except KeyError as exc:
        # NOTE(review): presumably raised by the task manager for an unknown
        # task_id -- confirm against its implementation.
        await _fail_and_close(websocket, str(exc), 1008)
    except Exception as exc:
        await _fail_and_close(websocket, f"unexpected error: {exc}", 1011)
|
||||
|
||||
232
video_worker/app/edge_dispatch_service.py
Normal file
232
video_worker/app/edge_dispatch_service.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schemas import GenerateRequest
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
    """Return the current time as a timezone-aware UTC ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
||||
|
||||
|
||||
# Request body of POST /dispatch/generate: a generation request plus an
# optional target device.
class DispatchGenerateRequest(BaseModel):
|
||||
# Device to dispatch to; when omitted, select_device() picks the first
# registered device that is not busy.
device_id: Optional[str] = None
|
||||
# The generation request that is forwarded to the chosen edge device.
request: GenerateRequest
|
||||
|
||||
|
||||
# Response body of POST /dispatch/generate, filled from the stored dispatch
# record.
class DispatchResponse(BaseModel):
|
||||
# Hex UUID generated in EdgeDispatchManager.create_dispatch().
dispatch_id: str
|
||||
# Id of the edge device the request was sent to.
device_id: str
|
||||
# Record status at response time ("DISPATCHED" for a freshly created record).
status: str
|
||||
# Creation timestamp as produced by utc_now_iso().
created_at: str
|
||||
|
||||
|
||||
# Bookkeeping for one registered edge device and its WebSocket.
@dataclass
|
||||
class EdgeConnection:
|
||||
# Identifier taken from the /ws/edge/{device_id} path.
device_id: str
|
||||
# Socket used to push "generate" events to the device.
websocket: WebSocket
|
||||
# Registration timestamp (utc_now_iso() string).
connected_at: str
|
||||
# Refreshed whenever the device sends an event or is handed a dispatch.
last_seen: str
|
||||
# True from dispatch until a "result" or "error" event arrives for it.
busy: bool = False
|
||||
|
||||
|
||||
class EdgeDispatchManager:
    """Track connected edge devices and the generation jobs dispatched to them.

    State is kept in two in-memory dicts: ``connections`` maps device_id to
    its EdgeConnection and ``dispatches`` maps dispatch_id to a record dict.
    Methods hold ``lock`` while reading or mutating either of them.

    NOTE(review): ``dispatches`` is never pruned, so it grows for the lifetime
    of the process.
    """

    # Statuses that mean a dispatch is still waiting on its device.
    _IN_FLIGHT_STATUSES = frozenset({"DISPATCHED", "RUNNING", "PENDING"})

    def __init__(self) -> None:
        self.connections: dict[str, EdgeConnection] = {}
        self.dispatches: dict[str, dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
        """Register (or replace) the connection for ``device_id`` and return it."""
        # One timestamp so connected_at and last_seen start out identical.
        now = utc_now_iso()
        conn = EdgeConnection(
            device_id=device_id,
            websocket=websocket,
            connected_at=now,
            last_seen=now,
            busy=False,
        )
        async with self.lock:
            self.connections[device_id] = conn
        return conn

    async def unregister(self, device_id: str) -> None:
        """Forget the connection for ``device_id``; unknown ids are ignored."""
        async with self.lock:
            self.connections.pop(device_id, None)

    async def list_devices(self) -> list[dict[str, Any]]:
        """Return a JSON-friendly snapshot of every registered device."""
        async with self.lock:
            return [
                {
                    "device_id": conn.device_id,
                    "connected_at": conn.connected_at,
                    "last_seen": conn.last_seen,
                    "busy": conn.busy,
                }
                for conn in self.connections.values()
            ]

    async def select_device(self, preferred: Optional[str]) -> EdgeConnection:
        """Pick the device a new dispatch should go to.

        Args:
            preferred: Explicit device id, or a falsy value to take the first
                registered device that is not busy.

        Raises:
            HTTPException: 404 if ``preferred`` is not registered, 409 if it is
                busy or if no idle device exists.

        The device is not reserved here; ``create_dispatch`` re-checks the busy
        flag under the lock before committing.
        """
        async with self.lock:
            if preferred:
                conn = self.connections.get(preferred)
                if conn is None:
                    raise HTTPException(status_code=404, detail=f"device not found: {preferred}")
                if conn.busy:
                    raise HTTPException(status_code=409, detail=f"device is busy: {preferred}")
                return conn

            for conn in self.connections.values():
                if not conn.busy:
                    return conn

            raise HTTPException(status_code=409, detail="no idle edge device available")

    async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
        """Reserve the device, store a dispatch record and push the request to it.

        Args:
            conn: Connection previously returned by ``select_device``.
            req: Generation request to forward.

        Returns:
            The stored dispatch record (status "DISPATCHED").

        Raises:
            HTTPException: 404 if the device is no longer registered, 409 if
                another dispatch claimed it since it was selected.
            Exception: Whatever ``send_json`` raised; the record is marked
                FAILED and the device released before re-raising.
        """
        dispatch_id = uuid4().hex
        now = utc_now_iso()
        request_data = req.model_dump()
        record: dict[str, Any] = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "status": "DISPATCHED",
            "request": request_data,
            "created_at": now,
            "updated_at": now,
            "result": None,
            "error": None,
        }

        async with self.lock:
            # Validate before storing anything so a rejected dispatch leaves no
            # orphan record behind.
            target = self.connections.get(conn.device_id)
            if target is None:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            if target.busy:
                # Selection and reservation are separate lock acquisitions, so
                # a concurrent request may have taken the device in between.
                raise HTTPException(status_code=409, detail=f"device is busy: {conn.device_id}")
            target.busy = True
            target.last_seen = now
            self.dispatches[dispatch_id] = record

        payload = {
            "event": "generate",
            "dispatch_id": dispatch_id,
            "request": request_data,
        }
        try:
            # Send on the currently registered socket; ``conn`` may be stale if
            # the device re-registered after it was selected.
            await target.websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"dispatch send failed: {exc}"
                record["updated_at"] = utc_now_iso()
                # Release only the reservation taken above.
                if self.connections.get(conn.device_id) is target:
                    target.busy = False
            raise
        return record

    async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
        """Apply an event received from ``device_id`` to its dispatch record.

        Always refreshes the device's ``last_seen``. Events that are not JSON
        objects, carry no usable ``dispatch_id``, reference an unknown
        dispatch, or reference a dispatch owned by another device are ignored.
        "accepted" sets RUNNING, "status" copies the reported status, "result"
        and "error" finish the dispatch and release the device.
        """
        now = utc_now_iso()
        async with self.lock:
            conn = self.connections.get(device_id)
            if conn is not None:
                conn.last_seen = now

            if not isinstance(event, dict):
                return
            dispatch_id = event.get("dispatch_id")
            # Non-string ids can never match a record (and may be unhashable).
            if not isinstance(dispatch_id, str) or not dispatch_id:
                return

            record = self.dispatches.get(dispatch_id)
            if record is None or record["device_id"] != device_id:
                # A device may only report on dispatches that were sent to it.
                return

            evt = event.get("event")
            if evt == "accepted":
                record["status"] = "RUNNING"
            elif evt == "status":
                status_val = event.get("status")
                if status_val:
                    record["status"] = status_val
            elif evt == "result":
                record["status"] = event.get("status") or "SUCCEEDED"
                record["result"] = event
                if conn is not None:
                    conn.busy = False
            elif evt == "error":
                record["status"] = "FAILED"
                record["error"] = event.get("error")
                if conn is not None:
                    conn.busy = False

            record["updated_at"] = now

    async def mark_disconnect_failed(self, device_id: str) -> None:
        """Mark every in-flight dispatch of ``device_id`` as FAILED."""
        now = utc_now_iso()
        async with self.lock:
            for record in self.dispatches.values():
                if record["device_id"] != device_id:
                    continue
                if record["status"] not in self._IN_FLIGHT_STATUSES:
                    continue
                record["status"] = "FAILED"
                record["error"] = f"device disconnected: {device_id}"
                record["updated_at"] = now
|
||||
|
||||
|
||||
# Single process-wide registry shared by every route below.
manager = EdgeDispatchManager()
|
||||
# ASGI application exposing the dispatch HTTP routes and the edge WebSocket.
app = FastAPI(title="Edge Dispatch Service", version="0.1.0")
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> dict[str, Any]:
    """Report service liveness together with the number of connected devices."""
    connected = await manager.list_devices()
    return dict(
        service="edge_dispatch",
        status="ok",
        connected_devices=len(connected),
    )
|
||||
|
||||
|
||||
@app.get("/devices")
async def list_devices() -> dict[str, Any]:
    """Return the manager's device snapshot under the "devices" key."""
    snapshot = await manager.list_devices()
    return {"devices": snapshot}
|
||||
|
||||
|
||||
@app.post("/dispatch/generate", response_model=DispatchResponse)
async def dispatch_generate(body: DispatchGenerateRequest) -> DispatchResponse:
    """Select an edge device and forward the generation request to it.

    HTTPExceptions raised by the manager (404 / 409) propagate unchanged.
    A WebSocketDisconnect while sending unregisters the device and becomes a
    503; a RuntimeError while sending also becomes a 503.
    """
    target = await manager.select_device(body.device_id)

    try:
        dispatched = await manager.create_dispatch(target, body.request)
    except WebSocketDisconnect as exc:
        await manager.unregister(target.device_id)
        detail = f"device disconnected during dispatch: {target.device_id}"
        raise HTTPException(status_code=503, detail=detail) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    # Only this subset of the stored record is exposed to the caller.
    response_fields = ("dispatch_id", "device_id", "status", "created_at")
    return DispatchResponse(**{name: dispatched[name] for name in response_fields})
|
||||
|
||||
|
||||
@app.get("/dispatch/{dispatch_id}")
async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
    """Return the stored record for ``dispatch_id`` or respond with 404."""
    try:
        return manager.dispatches[dispatch_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}") from None
|
||||
|
||||
|
||||
@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
    """Hold the WebSocket session of one edge device.

    Registers the device, acknowledges with a "registered" event, then feeds
    every JSON message it sends into ``manager.mark_event`` until the socket
    ends. On exit, for any reason, the device's in-flight dispatches are
    marked FAILED and the device is unregistered.
    """
    await websocket.accept()
    await manager.register(device_id, websocket)

    try:
        await websocket.send_json({"event": "registered", "device_id": device_id, "ts": utc_now_iso()})
        while True:
            msg = await websocket.receive_json()
            await manager.mark_event(device_id, msg)
    except WebSocketDisconnect:
        # Normal end of session; cleanup happens in ``finally``.
        pass
    finally:
        # Run cleanup for every exit path (malformed frame, send failure, ...),
        # otherwise a dead connection stays registered and its dispatches
        # remain in flight forever.
        # NOTE(review): if the same device_id re-registered on a newer socket,
        # this also removes that newer registration -- confirm whether
        # duplicate device ids can occur.
        await manager.mark_disconnect_failed(device_id)
        await manager.unregister(device_id)
|
||||
@@ -7,6 +7,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
class Settings(BaseSettings):
|
||||
app_host: str = Field(default="0.0.0.0", alias="APP_HOST")
|
||||
app_port: int = Field(default=8000, alias="APP_PORT")
|
||||
ws_gateway_host: str = Field(default="0.0.0.0", alias="WS_GATEWAY_HOST")
|
||||
ws_gateway_port: int = Field(default=8010, alias="WS_GATEWAY_PORT")
|
||||
worker_base_url: str = Field(default="http://127.0.0.1:8000", alias="WORKER_BASE_URL")
|
||||
ws_poll_interval_sec: float = Field(default=1.0, alias="WS_POLL_INTERVAL_SEC")
|
||||
edge_dispatch_host: str = Field(default="0.0.0.0", alias="EDGE_DISPATCH_HOST")
|
||||
edge_dispatch_port: int = Field(default=8020, alias="EDGE_DISPATCH_PORT")
|
||||
|
||||
output_dir: Path = Field(default=Path("./outputs"), alias="OUTPUT_DIR")
|
||||
runtime_dir: Path = Field(default=Path("./runtime"), alias="RUNTIME_DIR")
|
||||
|
||||
87
video_worker/app/ws_service.py
Normal file
87
video_worker/app/ws_service.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
|
||||
from app.schemas import GenerateRequest
|
||||
from app.settings import settings
|
||||
|
||||
# ASGI application for the WebSocket gateway that fronts the worker's HTTP API.
app = FastAPI(title="Video Worker WS Gateway", version="0.1.0")
|
||||
|
||||
|
||||
def _url(path: str) -> str:
    """Join ``path`` onto the configured worker base URL (trailing "/" removed)."""
    base = settings.worker_base_url.rstrip("/")
    return f"{base}{path}"
|
||||
|
||||
|
||||
def _http_get(path: str) -> dict[str, Any]:
    """GET ``path`` on the worker and return the decoded JSON body.

    Blocking call with a 20 second timeout; raises for non-2xx responses.
    """
    target = _url(path)
    reply = requests.get(target, timeout=20)
    reply.raise_for_status()
    decoded = reply.json()
    return decoded
|
||||
|
||||
|
||||
def _http_post(path: str, payload: dict[str, Any]) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``path`` on the worker and return the JSON reply.

    Blocking call with a 30 second timeout; raises for non-2xx responses.
    """
    target = _url(path)
    reply = requests.post(target, json=payload, timeout=30)
    reply.raise_for_status()
    decoded = reply.json()
    return decoded
|
||||
|
||||
|
||||
@app.get("/health")
def health() -> dict[str, Any]:
    """Report gateway liveness and whether the worker's /health is reachable.

    The gateway's own status is always "ok"; a failing upstream probe is
    reported through ``upstream_ok`` / ``upstream_error`` instead.
    """
    try:
        _http_get("/health")
    except Exception as exc:
        upstream_ok, upstream_error = False, str(exc)
    else:
        upstream_ok, upstream_error = True, None

    return {
        "service": "ws_gateway",
        "status": "ok",
        "worker_base_url": settings.worker_base_url,
        "upstream_ok": upstream_ok,
        "upstream_error": upstream_error,
    }
|
||||
|
||||
|
||||
async def _close_with_error(websocket: WebSocket, error: str, code: int) -> None:
    """Send an "error" event and close the socket, tolerating a dead peer."""
    try:
        await websocket.send_json({"event": "error", "error": error})
        await websocket.close(code=code)
    except (WebSocketDisconnect, RuntimeError):
        # The peer is already gone or the socket is already closed; raising
        # from an error handler would only mask the original failure.
        pass


@app.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket) -> None:
    """Bridge a client WebSocket to the worker's HTTP task API.

    The first client message selects the mode:
      * {"action": "watch", "task_id": ...} follows an existing task.
      * anything else is validated as a GenerateRequest (taken from the
        "payload" key when present, otherwise from the whole message), posted
        to the worker's /generate, and answered with an "accepted" event.

    The task is then polled every ``settings.ws_poll_interval_sec`` seconds;
    a "status" event is sent whenever the status value changes, and once it is
    SUCCEEDED or FAILED the result is fetched, sent as a "result" event and
    the socket is closed with code 1000.

    Bad client input closes with 1008; any other failure closes with 1011.
    """
    await websocket.accept()

    try:
        message = await websocket.receive_json()
        if not isinstance(message, dict):
            # Valid JSON that is not an object (list, string, ...) has no .get().
            await _close_with_error(websocket, "message must be a JSON object", 1008)
            return

        if message.get("action") == "watch":
            task_id = message.get("task_id")
            if not task_id:
                await _close_with_error(websocket, "task_id is required for watch", 1008)
                return
        else:
            payload = message.get("payload", message)
            try:
                req = GenerateRequest.model_validate(payload)
            except ValueError as exc:
                # pydantic v2's ValidationError derives from ValueError. A bad
                # request is a client error (1008), not an internal one (1011).
                await _close_with_error(websocket, f"invalid request: {exc}", 1008)
                return
            # The HTTP helpers block, so run them off the event loop.
            created = await asyncio.to_thread(_http_post, "/generate", req.model_dump())
            task_id = created["task_id"]
            await websocket.send_json({"event": "accepted", **created})

        last_status = None
        while True:
            status = await asyncio.to_thread(_http_get, f"/tasks/{task_id}")
            current = status.get("status")
            # Only status transitions are pushed; progress-only changes are not.
            if current != last_status:
                await websocket.send_json({"event": "status", **status})
                last_status = current

            if current in {"SUCCEEDED", "FAILED"}:
                result = await asyncio.to_thread(_http_get, f"/tasks/{task_id}/result")
                await websocket.send_json({"event": "result", **result})
                await websocket.close(code=1000)
                return

            await asyncio.sleep(settings.ws_poll_interval_sec)
    except WebSocketDisconnect:
        return
    except Exception as exc:
        await _close_with_error(websocket, str(exc), 1011)
|
||||
Reference in New Issue
Block a user