Files
AI_A4000/video_worker/app/edge_dispatch_service.py

413 lines
14 KiB
Python

import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from app.oss_client import oss_uploader
from app.schemas import GenerateRequest
from app.settings import settings
def utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
class DispatchGenerateRequest(BaseModel):
    """Request body for dispatching a generate job to an edge device."""

    # Preferred device id; when omitted, any idle connected device is chosen.
    device_id: Optional[str] = None

    # Generation parameters forwarded to the device in the "generate" event.
    request: GenerateRequest
class DispatchResponse(BaseModel):
    """Summary of a generate dispatch record returned to the HTTP caller."""

    # Identifier of the dispatch record (hex UUID).
    dispatch_id: str

    # Device the job was sent to.
    device_id: str

    # Record status at response time (e.g. "DISPATCHED").
    status: str

    # ISO-8601 UTC timestamp of record creation.
    created_at: str
class DeviceCommandRequest(BaseModel):
    """Request body for sending a control command to one edge device."""

    # Command name; the manager strips/lower-cases it and checks it against
    # its allow-list before sending.
    command: str

    # Optional caller-chosen record id; a fresh hex UUID is used when omitted.
    dispatch_id: Optional[str] = None

    # Forwarded to the device as the "branch" key when set.
    branch: Optional[str] = None

    # Forwarded to the device as the "command" key when set.
    shell_command: Optional[str] = None

    # Extra key/value pairs merged into the message sent to the device.
    payload: Optional[dict[str, Any]] = None
class DeviceCommandResponse(BaseModel):
    """Summary of a device command record returned to the HTTP caller."""

    # Identifier of the command record.
    dispatch_id: str

    # Device the command was sent to.
    device_id: str

    # Normalised command name that was sent.
    command: str

    # Record status at response time (e.g. "DISPATCHED").
    status: str

    # ISO-8601 UTC timestamp of record creation.
    created_at: str
@dataclass
class EdgeConnection:
    """Bookkeeping for one edge device connected over a WebSocket."""

    # Identifier the device connected under (taken from the socket URL).
    device_id: str

    # Live socket used to push events to the device.
    websocket: WebSocket

    # ISO-8601 UTC timestamp of registration.
    connected_at: str

    # ISO-8601 UTC timestamp of the most recent activity seen for the device.
    last_seen: str

    # True while a generate job is outstanding on the device.
    busy: bool = False
class EdgeDispatchManager:
    """Track connected edge devices and the generate jobs / commands sent to them.

    All state lives in process memory and is guarded by one ``asyncio.Lock``:

    * ``connections`` maps device id -> :class:`EdgeConnection`.
    * ``dispatches`` maps dispatch id -> generate-job record.
    * ``commands`` maps dispatch id -> device-command record.
    """

    # Statuses that mark a record as finished, and therefore prunable.
    _TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})
    # Statuses that are failed when the owning device disconnects.
    _IN_FLIGHT_STATUSES = frozenset({"DISPATCHED", "RUNNING", "PENDING"})
    # Commands that may be forwarded to a device.
    _ALLOWED_COMMANDS = frozenset({"update_code", "restart_service", "ping"})
    # Message keys that caller-supplied ``payload`` must never override.
    _RESERVED_PAYLOAD_KEYS = frozenset({"event", "dispatch_id"})

    def __init__(self) -> None:
        self.connections: dict[str, EdgeConnection] = {}
        self.dispatches: dict[str, dict[str, Any]] = {}
        self.commands: dict[str, dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
        """Record a newly connected device, replacing any entry with the same id."""
        now = utc_now_iso()
        async with self.lock:
            conn = EdgeConnection(
                device_id=device_id,
                websocket=websocket,
                connected_at=now,
                last_seen=now,
                busy=False,
            )
            self.connections[device_id] = conn
            return conn

    async def unregister(self, device_id: str) -> None:
        """Forget the connection for ``device_id``; a no-op when it is unknown."""
        async with self.lock:
            self.connections.pop(device_id, None)

    async def list_devices(self) -> list[dict[str, Any]]:
        """Return a JSON-friendly snapshot of every registered connection."""
        async with self.lock:
            return [
                {
                    "device_id": conn.device_id,
                    "connected_at": conn.connected_at,
                    "last_seen": conn.last_seen,
                    "busy": conn.busy,
                }
                for conn in self.connections.values()
            ]

    async def select_device(self, preferred: Optional[str]) -> EdgeConnection:
        """Pick an idle device.

        Args:
            preferred: Device id that must be used, or ``None``/empty to take
                the first idle connection.

        Raises:
            HTTPException: 404 when ``preferred`` is not connected, 409 when it
                is busy or when no idle device exists.

        Note:
            The returned connection is not reserved; ``create_dispatch``
            re-checks and claims it atomically.
        """
        async with self.lock:
            if preferred:
                conn = self.connections.get(preferred)
                if conn is None:
                    raise HTTPException(status_code=404, detail=f"device not found: {preferred}")
                if conn.busy:
                    raise HTTPException(status_code=409, detail=f"device is busy: {preferred}")
                return conn
            for conn in self.connections.values():
                if not conn.busy:
                    return conn
            raise HTTPException(status_code=409, detail="no idle edge device available")

    async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
        """Create a generate-job record and send the job to the device.

        Args:
            conn: Connection previously returned by ``select_device``.
            req: Generation request forwarded to the device.

        Returns:
            The stored dispatch record (the same dict kept in ``dispatches``).

        Raises:
            HTTPException: 404 if the device is no longer registered, 409 if it
                was claimed by another dispatch in the meantime.
            Exception: Whatever the socket send raised; the record is marked
                ``FAILED`` and the device released before re-raising.
        """
        await self._prune_if_needed()
        dispatch_id = uuid4().hex
        now = utc_now_iso()
        record = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "status": "DISPATCHED",
            "request": req.model_dump(),
            "created_at": now,
            "updated_at": now,
            "result": None,
            "error": None,
            "artifact_urls": {},
        }
        async with self.lock:
            # Validate and claim the device BEFORE storing the record so a
            # rejected request cannot leave an orphan "DISPATCHED" entry, and
            # so two concurrent requests cannot both claim the same device.
            target = self.connections.get(conn.device_id)
            if target is None:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            if target.busy:
                raise HTTPException(status_code=409, detail=f"device is busy: {conn.device_id}")
            target.busy = True
            target.last_seen = now
            self.dispatches[dispatch_id] = record
        payload = {
            "event": "generate",
            "dispatch_id": dispatch_id,
            "request": req.model_dump(),
        }
        try:
            # Send through the currently registered connection; ``conn`` may be
            # a stale object if the device re-registered after selection.
            await target.websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"dispatch send failed: {exc}"
                record["updated_at"] = utc_now_iso()
                # Release exactly the connection object claimed above.
                target.busy = False
            raise
        return record

    async def create_command(self, conn: EdgeConnection, body: DeviceCommandRequest) -> dict[str, Any]:
        """Create a command record and send the command to the device.

        Args:
            conn: Connection previously returned by ``select_device``.
            body: Command description; ``command`` is stripped and lower-cased
                and must be one of the allowed commands.

        Returns:
            The stored command record (the same dict kept in ``commands``).

        Raises:
            HTTPException: 400 for an unsupported command, 404 if the device is
                no longer registered.
            Exception: Whatever the socket send raised; the record is marked
                ``FAILED`` before re-raising.
        """
        await self._prune_if_needed()
        command = body.command.strip().lower()
        if command not in self._ALLOWED_COMMANDS:
            raise HTTPException(status_code=400, detail=f"unsupported command: {command}")
        # NOTE(review): a caller-supplied dispatch_id that already exists will
        # overwrite the earlier command record - confirm whether that is intended.
        dispatch_id = body.dispatch_id or uuid4().hex
        now = utc_now_iso()
        record = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "command": command,
            "status": "DISPATCHED",
            "created_at": now,
            "updated_at": now,
            "request": body.model_dump(),
            "result": None,
            "error": None,
        }
        payload: dict[str, Any] = {
            "event": command,
            "dispatch_id": dispatch_id,
        }
        if body.branch:
            payload["branch"] = body.branch
        if body.shell_command:
            payload["command"] = body.shell_command
        if body.payload:
            # Extra keys may extend the message but must not replace the
            # validated event name or the record id (allow-list bypass).
            payload.update(
                {
                    key: value
                    for key, value in body.payload.items()
                    if key not in self._RESERVED_PAYLOAD_KEYS
                }
            )
        async with self.lock:
            # Check the device first so a rejected request leaves no orphan record.
            target = self.connections.get(conn.device_id)
            if target is None:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            target.last_seen = now
            self.commands[dispatch_id] = record
        try:
            await target.websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"command send failed: {exc}"
                record["updated_at"] = utc_now_iso()
            raise
        return record

    async def _prune_if_needed(self) -> None:
        """Drop the oldest finished records once the configured cap is reached.

        The cap is ``settings.edge_max_dispatch_records`` with a floor of 100.
        Only records whose status is ``SUCCEEDED`` or ``FAILED`` are removed,
        oldest ``updated_at`` first, making room for one new record.
        """
        max_records = max(100, int(settings.edge_max_dispatch_records))
        async with self.lock:
            total = len(self.dispatches) + len(self.commands)
            if total < max_records:
                return
            over = total - max_records + 1
            done = [
                rec
                for store in (self.dispatches, self.commands)
                for rec in store.values()
                if rec.get("status") in self._TERMINAL_STATUSES
            ]
            done.sort(key=lambda rec: rec.get("updated_at", ""))
            for rec in done[:over]:
                self.dispatches.pop(rec["dispatch_id"], None)
                self.commands.pop(rec["dispatch_id"], None)

    async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
        """Apply a message received from a device to the matching record.

        The device's ``last_seen`` is always refreshed. Messages that are not
        JSON objects, carry no usable ``dispatch_id``, refer to an unknown
        record, or refer to a record owned by a different device are ignored.

        Args:
            device_id: Id of the device whose socket delivered the message.
            event: Decoded JSON message; handled ``event`` values are
                ``accepted``, ``status``, ``result``, ``error``,
                ``command_status``, ``command_result`` and ``pong``.
        """
        now = utc_now_iso()
        async with self.lock:
            conn = self.connections.get(device_id)
            if conn:
                conn.last_seen = now
            # The payload comes straight off the wire: tolerate non-object
            # JSON and non-string ids instead of raising in the receive loop.
            if not isinstance(event, dict):
                return
            dispatch_id = event.get("dispatch_id")
            if not dispatch_id or not isinstance(dispatch_id, str):
                return
            record = self.dispatches.get(dispatch_id)
            # Only generate jobs mark a device busy, so only they release it.
            is_generate = record is not None
            if record is None:
                record = self.commands.get(dispatch_id)
            if record is None:
                return
            # A device may only update records that were sent to it.
            if record["device_id"] != device_id:
                return
            evt = event.get("event")
            if evt == "accepted":
                record["status"] = "RUNNING"
            elif evt == "status":
                status_val = event.get("status")
                if status_val:
                    record["status"] = status_val
            elif evt == "result":
                status_val = event.get("status") or "SUCCEEDED"
                record["status"] = status_val
                record["result"] = event
                if conn and is_generate:
                    conn.busy = False
            elif evt == "error":
                record["status"] = "FAILED"
                record["error"] = event.get("error")
                if conn and is_generate:
                    conn.busy = False
            elif evt == "command_status":
                status_val = event.get("status") or "RUNNING"
                record["status"] = status_val
            elif evt == "command_result":
                status_val = event.get("status") or "SUCCEEDED"
                record["status"] = status_val
                record["result"] = event
                if status_val == "FAILED":
                    record["error"] = event.get("error") or event.get("stderr")
            elif evt == "pong":
                record["status"] = "SUCCEEDED"
                record["result"] = event
            # Any message for a known record counts as activity on it.
            record["updated_at"] = now

    async def mark_disconnect_failed(self, device_id: str) -> None:
        """Fail every in-flight dispatch and command owned by ``device_id``."""
        now = utc_now_iso()
        async with self.lock:
            for store in (self.dispatches, self.commands):
                for record in store.values():
                    if record["device_id"] != device_id:
                        continue
                    if record["status"] not in self._IN_FLIGHT_STATUSES:
                        continue
                    record["status"] = "FAILED"
                    record["error"] = f"device disconnected: {device_id}"
                    record["updated_at"] = now
# Process-wide dispatch state shared by every route below (in-memory only).
manager = EdgeDispatchManager()
# ASGI application; the route decorators that follow register on this object.
app = FastAPI(title="Edge Dispatch Service", version="0.1.0")
@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe that also reports how many devices are connected."""
    connected = len(await manager.list_devices())
    return {"service": "edge_dispatch", "status": "ok", "connected_devices": connected}
@app.get("/devices")
async def list_devices() -> dict[str, Any]:
    """Return a snapshot of every connected edge device."""
    snapshot = await manager.list_devices()
    return {"devices": snapshot}
@app.post("/dispatch/generate", response_model=DispatchResponse)
async def dispatch_generate(body: DispatchGenerateRequest) -> DispatchResponse:
    """Send a generate job to the requested (or any idle) edge device.

    Raises:
        HTTPException: Propagated from device selection / dispatch creation,
            or 503 when the socket send fails with ``WebSocketDisconnect`` or
            ``RuntimeError``.
    """
    conn = await manager.select_device(body.device_id)
    try:
        record = await manager.create_dispatch(conn, body.request)
    except WebSocketDisconnect as exc:
        # The socket is gone: drop the connection so it is not selected again.
        await manager.unregister(conn.device_id)
        detail = f"device disconnected during dispatch: {conn.device_id}"
        raise HTTPException(status_code=503, detail=detail) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    summary_keys = ("dispatch_id", "device_id", "status", "created_at")
    return DispatchResponse(**{key: record[key] for key in summary_keys})
@app.get("/dispatch/{dispatch_id}")
async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
    """Return the full record of one generate dispatch, or 404 if unknown."""
    if (found := manager.dispatches.get(dispatch_id)) is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    return found
@app.post("/devices/{device_id}/command", response_model=DeviceCommandResponse)
async def device_command(device_id: str, body: DeviceCommandRequest) -> DeviceCommandResponse:
    """Send a control command to one specific edge device.

    Raises:
        HTTPException: Propagated from device selection / command creation,
            or 503 when the socket send fails with ``WebSocketDisconnect`` or
            ``RuntimeError``.
    """
    # Selection rejects busy devices, so commands only reach idle ones.
    conn = await manager.select_device(device_id)
    try:
        record = await manager.create_command(conn, body)
    except WebSocketDisconnect as exc:
        # The socket is gone: drop the connection so it is not selected again.
        await manager.unregister(conn.device_id)
        detail = f"device disconnected during command: {conn.device_id}"
        raise HTTPException(status_code=503, detail=detail) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    summary_keys = ("dispatch_id", "device_id", "command", "status", "created_at")
    return DeviceCommandResponse(**{key: record[key] for key in summary_keys})
@app.get("/commands/{dispatch_id}")
async def get_command(dispatch_id: str) -> dict[str, Any]:
    """Return the full record of one device command, or 404 if unknown."""
    if (found := manager.commands.get(dispatch_id)) is None:
        raise HTTPException(status_code=404, detail=f"command not found: {dispatch_id}")
    return found
@app.post("/dispatch/{dispatch_id}/artifacts")
async def upload_artifacts(
    dispatch_id: str,
    task_id: str = Form(default=""),
    status: str = Form(default="SUCCEEDED"),
    files: list[UploadFile] = File(default_factory=list),
) -> dict[str, Any]:
    """Upload result files for a dispatch to OSS and finalise its record.

    Each file is handed to ``oss_uploader.upload_fileobj`` on a worker thread.
    The record then receives the upload results, the reported ``status`` and
    a synthetic "result" payload, and the owning device is marked idle.

    Raises:
        HTTPException: 404 for an unknown dispatch, 400 when OSS upload is
            disabled or no files were sent.
    """
    if (record := manager.dispatches.get(dispatch_id)) is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    if not oss_uploader.enabled:
        raise HTTPException(status_code=400, detail="OSS upload is disabled, set OSS_ENABLED=true")
    if not files:
        raise HTTPException(status_code=400, detail="no files uploaded")

    uploaded: dict[str, dict[str, str]] = {}
    for upload in files:
        original_name = upload.filename or "artifact.bin"
        try:
            # NOTE(review): the client-supplied name reaches the uploader
            # unsanitised while the result map is keyed by its basename;
            # confirm the uploader guards against path components.
            uploaded[Path(original_name).name] = await asyncio.to_thread(
                oss_uploader.upload_fileobj, dispatch_id, original_name, upload.file
            )
        finally:
            await upload.close()

    # NOTE(review): ``status`` is free-form form input stored verbatim.
    result_payload = {
        "event": "result",
        "dispatch_id": dispatch_id,
        "task_id": task_id or None,
        "status": status,
        "artifact_urls": uploaded,
    }
    now = utc_now_iso()
    async with manager.lock:
        record["artifact_urls"] = uploaded
        record["result"] = result_payload
        record["status"] = status
        record["updated_at"] = now
        # Artifacts delivered means the job is over: free the device.
        owner = manager.connections.get(record["device_id"])
        if owner:
            owner.busy = False

    return {
        "dispatch_id": dispatch_id,
        "status": status,
        "artifact_urls": uploaded,
        "updated_at": now,
    }
@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
    """Serve one edge device's WebSocket for the lifetime of its connection.

    The device is registered, greeted with a "registered" event, and every
    JSON object it sends is applied to the matching dispatch/command record.
    When the loop ends for ANY reason, the device's in-flight records are
    failed and the connection is unregistered.

    Args:
        websocket: Socket for this device.
        device_id: Device identifier taken from the URL path.
    """
    await websocket.accept()
    await manager.register(device_id, websocket)
    try:
        await websocket.send_json({"event": "registered", "device_id": device_id, "ts": utc_now_iso()})
        while True:
            msg = await websocket.receive_json()
            # Only JSON objects are meaningful events; skip arrays/scalars
            # rather than letting one bad frame end the session.
            if not isinstance(msg, dict):
                continue
            await manager.mark_event(device_id, msg)
    except WebSocketDisconnect:
        # Normal termination: the peer closed the socket.
        pass
    finally:
        # Run cleanup for every exit path (disconnect, malformed frame, send
        # failure, handler error) so the device never stays registered with
        # records stuck in DISPATCHED/RUNNING. Other exceptions still
        # propagate after cleanup.
        await manager.mark_disconnect_failed(device_id)
        await manager.unregister(device_id)