fix:优化服务

This commit is contained in:
Daniel
2026-04-07 01:00:56 +08:00
parent 8d0b729f2f
commit e606b3dcd6
19 changed files with 899 additions and 7 deletions

View File

@@ -1,4 +1,7 @@
from fastapi import APIRouter, HTTPException
import asyncio
from typing import Any
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from app.schemas import GenerateRequest, HealthResponse, TaskResultResponse, TaskStatusResponse
@@ -46,3 +49,88 @@ def health_check():
ltx_loaded=ltx_backend.is_loaded(),
hunyuan_loaded=hunyuan_backend.is_loaded(),
)
@router.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket):
await websocket.accept()
task_manager = router.task_manager
try:
message: dict[str, Any] = await websocket.receive_json()
action = message.get("action")
if action == "watch":
task_id = message.get("task_id")
if not task_id:
await websocket.send_json({"event": "error", "error": "task_id is required for watch"})
await websocket.close(code=1008)
return
else:
payload = message.get("payload", message)
try:
req = GenerateRequest.model_validate(payload)
except Exception as exc:
await websocket.send_json({"event": "error", "error": f"invalid request: {exc}"})
await websocket.close(code=1008)
return
created = await task_manager.create_task(req)
task_id = created.task_id
await websocket.send_json(
{
"event": "accepted",
"task_id": task_id,
"status": created.status,
"backend": created.backend,
"model_name": created.model_name,
"progress": created.progress,
"created_at": created.created_at.isoformat(),
"updated_at": created.updated_at.isoformat(),
}
)
last_status = None
while True:
status = task_manager.get_status(task_id)
if status.status != last_status:
await websocket.send_json(
{
"event": "status",
"task_id": status.task_id,
"status": status.status,
"backend": status.backend,
"model_name": status.model_name,
"progress": status.progress,
"created_at": status.created_at.isoformat(),
"updated_at": status.updated_at.isoformat(),
}
)
last_status = status.status
if status.status in {"SUCCEEDED", "FAILED"}:
result = task_manager.get_result(task_id)
await websocket.send_json(
{
"event": "result",
"task_id": result.task_id,
"status": result.status,
"video_path": result.video_path,
"first_frame_path": result.first_frame_path,
"metadata_path": result.metadata_path,
"log_path": result.log_path,
"error": result.error,
}
)
await websocket.close(code=1000)
break
await asyncio.sleep(1)
except WebSocketDisconnect:
return
except KeyError as exc:
await websocket.send_json({"event": "error", "error": str(exc)})
await websocket.close(code=1008)
except Exception as exc:
await websocket.send_json({"event": "error", "error": f"unexpected error: {exc}"})
await websocket.close(code=1011)

View File

@@ -0,0 +1,232 @@
import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import uuid4
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from app.schemas import GenerateRequest
def utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
class DispatchGenerateRequest(BaseModel):
device_id: Optional[str] = None
request: GenerateRequest
class DispatchResponse(BaseModel):
dispatch_id: str
device_id: str
status: str
created_at: str
@dataclass
class EdgeConnection:
device_id: str
websocket: WebSocket
connected_at: str
last_seen: str
busy: bool = False
class EdgeDispatchManager:
def __init__(self) -> None:
self.connections: dict[str, EdgeConnection] = {}
self.dispatches: dict[str, dict[str, Any]] = {}
self.lock = asyncio.Lock()
async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
async with self.lock:
conn = EdgeConnection(
device_id=device_id,
websocket=websocket,
connected_at=utc_now_iso(),
last_seen=utc_now_iso(),
busy=False,
)
self.connections[device_id] = conn
return conn
async def unregister(self, device_id: str) -> None:
async with self.lock:
self.connections.pop(device_id, None)
async def list_devices(self) -> list[dict[str, Any]]:
async with self.lock:
return [
{
"device_id": conn.device_id,
"connected_at": conn.connected_at,
"last_seen": conn.last_seen,
"busy": conn.busy,
}
for conn in self.connections.values()
]
async def select_device(self, preferred: Optional[str]) -> EdgeConnection:
async with self.lock:
if preferred:
conn = self.connections.get(preferred)
if conn is None:
raise HTTPException(status_code=404, detail=f"device not found: {preferred}")
if conn.busy:
raise HTTPException(status_code=409, detail=f"device is busy: {preferred}")
return conn
for conn in self.connections.values():
if not conn.busy:
return conn
raise HTTPException(status_code=409, detail="no idle edge device available")
async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
dispatch_id = uuid4().hex
now = utc_now_iso()
record = {
"dispatch_id": dispatch_id,
"device_id": conn.device_id,
"status": "DISPATCHED",
"request": req.model_dump(),
"created_at": now,
"updated_at": now,
"result": None,
"error": None,
}
async with self.lock:
self.dispatches[dispatch_id] = record
target = self.connections.get(conn.device_id)
if target is None:
raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
target.busy = True
target.last_seen = now
payload = {
"event": "generate",
"dispatch_id": dispatch_id,
"request": req.model_dump(),
}
try:
await conn.websocket.send_json(payload)
except Exception as exc:
async with self.lock:
record["status"] = "FAILED"
record["error"] = f"dispatch send failed: {exc}"
record["updated_at"] = utc_now_iso()
target = self.connections.get(conn.device_id)
if target:
target.busy = False
raise
return record
async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
now = utc_now_iso()
async with self.lock:
conn = self.connections.get(device_id)
if conn:
conn.last_seen = now
dispatch_id = event.get("dispatch_id")
if not dispatch_id:
return
record = self.dispatches.get(dispatch_id)
if record is None:
return
evt = event.get("event")
if evt == "accepted":
record["status"] = "RUNNING"
elif evt == "status":
status_val = event.get("status")
if status_val:
record["status"] = status_val
elif evt == "result":
status_val = event.get("status") or "SUCCEEDED"
record["status"] = status_val
record["result"] = event
if conn:
conn.busy = False
elif evt == "error":
record["status"] = "FAILED"
record["error"] = event.get("error")
if conn:
conn.busy = False
record["updated_at"] = now
async def mark_disconnect_failed(self, device_id: str) -> None:
now = utc_now_iso()
async with self.lock:
for record in self.dispatches.values():
if record["device_id"] == device_id and record["status"] in {"DISPATCHED", "RUNNING", "PENDING"}:
record["status"] = "FAILED"
record["error"] = f"device disconnected: {device_id}"
record["updated_at"] = now
manager = EdgeDispatchManager()
app = FastAPI(title="Edge Dispatch Service", version="0.1.0")
@app.get("/health")
async def health() -> dict[str, Any]:
devices = await manager.list_devices()
return {
"service": "edge_dispatch",
"status": "ok",
"connected_devices": len(devices),
}
@app.get("/devices")
async def list_devices() -> dict[str, Any]:
return {"devices": await manager.list_devices()}
@app.post("/dispatch/generate", response_model=DispatchResponse)
async def dispatch_generate(body: DispatchGenerateRequest) -> DispatchResponse:
conn = await manager.select_device(body.device_id)
try:
record = await manager.create_dispatch(conn, body.request)
except WebSocketDisconnect as exc:
await manager.unregister(conn.device_id)
raise HTTPException(status_code=503, detail=f"device disconnected during dispatch: {conn.device_id}") from exc
except RuntimeError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
return DispatchResponse(
dispatch_id=record["dispatch_id"],
device_id=record["device_id"],
status=record["status"],
created_at=record["created_at"],
)
@app.get("/dispatch/{dispatch_id}")
async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
record = manager.dispatches.get(dispatch_id)
if record is None:
raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
return record
@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
await websocket.accept()
await manager.register(device_id, websocket)
try:
await websocket.send_json({"event": "registered", "device_id": device_id, "ts": utc_now_iso()})
while True:
msg = await websocket.receive_json()
await manager.mark_event(device_id, msg)
except WebSocketDisconnect:
await manager.mark_disconnect_failed(device_id)
await manager.unregister(device_id)

View File

@@ -7,6 +7,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
app_host: str = Field(default="0.0.0.0", alias="APP_HOST")
app_port: int = Field(default=8000, alias="APP_PORT")
ws_gateway_host: str = Field(default="0.0.0.0", alias="WS_GATEWAY_HOST")
ws_gateway_port: int = Field(default=8010, alias="WS_GATEWAY_PORT")
worker_base_url: str = Field(default="http://127.0.0.1:8000", alias="WORKER_BASE_URL")
ws_poll_interval_sec: float = Field(default=1.0, alias="WS_POLL_INTERVAL_SEC")
edge_dispatch_host: str = Field(default="0.0.0.0", alias="EDGE_DISPATCH_HOST")
edge_dispatch_port: int = Field(default=8020, alias="EDGE_DISPATCH_PORT")
output_dir: Path = Field(default=Path("./outputs"), alias="OUTPUT_DIR")
runtime_dir: Path = Field(default=Path("./runtime"), alias="RUNTIME_DIR")

View File

@@ -0,0 +1,87 @@
import asyncio
from typing import Any
import requests
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from app.schemas import GenerateRequest
from app.settings import settings
app = FastAPI(title="Video Worker WS Gateway", version="0.1.0")
def _url(path: str) -> str:
return f"{settings.worker_base_url.rstrip('/')}{path}"
def _http_get(path: str) -> dict[str, Any]:
resp = requests.get(_url(path), timeout=20)
resp.raise_for_status()
return resp.json()
def _http_post(path: str, payload: dict[str, Any]) -> dict[str, Any]:
resp = requests.post(_url(path), json=payload, timeout=30)
resp.raise_for_status()
return resp.json()
@app.get("/health")
def health() -> dict[str, Any]:
gateway_status = "ok"
upstream_ok = False
upstream_error = None
try:
_http_get("/health")
upstream_ok = True
except Exception as exc:
upstream_error = str(exc)
return {
"service": "ws_gateway",
"status": gateway_status,
"worker_base_url": settings.worker_base_url,
"upstream_ok": upstream_ok,
"upstream_error": upstream_error,
}
@app.websocket("/ws/generate")
async def ws_generate(websocket: WebSocket) -> None:
await websocket.accept()
try:
message = await websocket.receive_json()
action = message.get("action")
if action == "watch":
task_id = message.get("task_id")
if not task_id:
await websocket.send_json({"event": "error", "error": "task_id is required for watch"})
await websocket.close(code=1008)
return
else:
payload = message.get("payload", message)
req = GenerateRequest.model_validate(payload)
created = await asyncio.to_thread(_http_post, "/generate", req.model_dump())
task_id = created["task_id"]
await websocket.send_json({"event": "accepted", **created})
last_status = None
while True:
status = await asyncio.to_thread(_http_get, f"/tasks/{task_id}")
if status.get("status") != last_status:
await websocket.send_json({"event": "status", **status})
last_status = status.get("status")
if status.get("status") in {"SUCCEEDED", "FAILED"}:
result = await asyncio.to_thread(_http_get, f"/tasks/{task_id}/result")
await websocket.send_json({"event": "result", **result})
await websocket.close(code=1000)
return
await asyncio.sleep(settings.ws_poll_interval_sec)
except WebSocketDisconnect:
return
except Exception as exc:
await websocket.send_json({"event": "error", "error": str(exc)})
await websocket.close(code=1011)