# Source listing metadata: 298 lines, 9.9 KiB, Python
import asyncio
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
from uuid import uuid4
|
|
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
|
from pydantic import BaseModel
|
|
|
|
from app.oss_client import oss_uploader
|
|
from app.schemas import GenerateRequest
|
|
from app.settings import settings
|
|
|
|
|
|
def utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
class DispatchGenerateRequest(BaseModel):
    # Request body accepted by POST /dispatch/generate.
    # Documented with plain comments on purpose: a class docstring on a
    # pydantic model is exported as the schema description.

    # Edge device to target; None lets the service pick any idle device.
    device_id: Optional[str] = None

    # Generation parameters that are forwarded to the selected device.
    request: GenerateRequest
|
|
|
|
|
|
class DispatchResponse(BaseModel):
    # Response body returned by POST /dispatch/generate.
    # Documented with plain comments on purpose: a class docstring on a
    # pydantic model is exported as the schema description.

    # Identifier of the newly created dispatch record.
    dispatch_id: str

    # Edge device the request was sent to.
    device_id: str

    # Status of the dispatch record at the time the response is built.
    status: str

    # Creation timestamp of the dispatch record (ISO-8601, UTC).
    created_at: str
|
|
|
|
|
|
@dataclass
class EdgeConnection:
    """Bookkeeping entry for one edge device connected over a websocket."""

    # Identifier the device registered under.
    device_id: str
    # Socket used to push events to the device.
    websocket: WebSocket
    # ISO-8601 UTC timestamp taken when the device registered.
    connected_at: str
    # ISO-8601 UTC timestamp refreshed whenever the device shows activity.
    last_seen: str
    # True while a dispatch is outstanding on this device.
    busy: bool = False
|
|
|
|
|
|
class EdgeDispatchManager:
    """In-memory registry of connected edge devices and their dispatch records.

    ``connections`` maps a device id to its :class:`EdgeConnection`;
    ``dispatches`` maps a dispatch id to a plain ``dict`` record with the keys
    ``dispatch_id``, ``device_id``, ``status``, ``request``, ``created_at``,
    ``updated_at``, ``result``, ``error`` and ``artifact_urls``.

    Every method of this class mutates that state only while holding ``lock``.
    ``lock`` is an ``asyncio.Lock`` and therefore not reentrant: a method must
    never call another locking method while it is holding the lock.
    """

    # Statuses of finished dispatches; only these records may be pruned.
    _TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})
    # Statuses of dispatches that are still outstanding on a device.
    _ACTIVE_STATUSES = frozenset({"DISPATCHED", "RUNNING", "PENDING"})

    def __init__(self) -> None:
        self.connections: dict[str, EdgeConnection] = {}
        self.dispatches: dict[str, dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
        """Record ``websocket`` as the connection for ``device_id`` and return it.

        An existing entry for the same ``device_id`` is replaced.
        """
        async with self.lock:
            now = utc_now_iso()
            conn = EdgeConnection(
                device_id=device_id,
                websocket=websocket,
                connected_at=now,
                last_seen=now,
                busy=False,
            )
            self.connections[device_id] = conn
            return conn

    async def unregister(self, device_id: str) -> None:
        """Forget the connection for ``device_id``; unknown ids are ignored."""
        async with self.lock:
            self.connections.pop(device_id, None)

    async def list_devices(self) -> list[dict[str, Any]]:
        """Return a JSON-friendly snapshot of every registered device."""
        async with self.lock:
            return [
                {
                    "device_id": conn.device_id,
                    "connected_at": conn.connected_at,
                    "last_seen": conn.last_seen,
                    "busy": conn.busy,
                }
                for conn in self.connections.values()
            ]

    async def select_device(self, preferred: Optional[str]) -> EdgeConnection:
        """Pick the device that should receive the next dispatch.

        With ``preferred`` set, that exact device is returned; otherwise the
        first idle device (in registration order) is returned.  The device is
        NOT reserved here: ``create_dispatch`` re-validates it and marks it
        busy.

        Raises:
            HTTPException: 404 when ``preferred`` is not registered, 409 when
                it is busy or when no idle device exists.
        """
        async with self.lock:
            if preferred:
                conn = self.connections.get(preferred)
                if conn is None:
                    raise HTTPException(status_code=404, detail=f"device not found: {preferred}")
                if conn.busy:
                    raise HTTPException(status_code=409, detail=f"device is busy: {preferred}")
                return conn

            for conn in self.connections.values():
                if not conn.busy:
                    return conn

        raise HTTPException(status_code=409, detail="no idle edge device available")

    async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
        """Create a dispatch record, reserve the device and send it the request.

        Args:
            conn: Connection previously returned by ``select_device``.
            req: Generation request; its ``model_dump()`` is stored in the
                record and sent to the device.

        Returns:
            The stored dispatch record (the live ``dict``, not a copy).

        Raises:
            HTTPException: 404 when the device is no longer registered, 409
                when it became busy after it was selected.
            Exception: whatever ``send_json`` raises is re-raised after the
                record has been marked ``FAILED`` and the device released.
        """
        await self._prune_if_needed()
        dispatch_id = uuid4().hex
        now = utc_now_iso()
        request_payload = req.model_dump()
        record = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "status": "DISPATCHED",
            "request": request_payload,
            "created_at": now,
            "updated_at": now,
            "result": None,
            "error": None,
            "artifact_urls": {},
        }

        async with self.lock:
            # Validate the device BEFORE storing the record, so a rejected
            # dispatch never leaves an orphan "DISPATCHED" entry behind.
            target = self.connections.get(conn.device_id)
            if target is None:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            # The device was only observed idle in select_device(); another
            # request may have reserved it since, so re-check under the lock.
            if target.busy:
                raise HTTPException(status_code=409, detail=f"device is busy: {conn.device_id}")
            self.dispatches[dispatch_id] = record
            target.busy = True
            target.last_seen = now

        payload = {
            "event": "generate",
            "dispatch_id": dispatch_id,
            "request": request_payload,
        }
        try:
            # Send on the connection that was just reserved; ``conn`` may be
            # stale if the device re-registered after it was selected.
            await target.websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"dispatch send failed: {exc}"
                record["updated_at"] = utc_now_iso()
                # Release exactly the reservation taken above.
                target.busy = False
            raise
        return record

    async def _prune_if_needed(self) -> None:
        """Drop the oldest finished records once the record cap is reached.

        The cap is ``settings.edge_max_dispatch_records`` but never less than
        100.  Only records whose status is ``SUCCEEDED`` or ``FAILED`` are
        removed, oldest ``updated_at`` first, so the table can stay above the
        cap while too few records are finished.
        """
        max_records = max(100, int(settings.edge_max_dispatch_records))
        async with self.lock:
            total = len(self.dispatches)
            if total < max_records:
                return
            # "+ 1" makes room for the record the caller is about to add.
            over = total - max_records + 1
            done = [
                rec
                for rec in self.dispatches.values()
                if rec.get("status") in self._TERMINAL_STATUSES
            ]
            done.sort(key=lambda rec: rec.get("updated_at", ""))
            for rec in done[:over]:
                self.dispatches.pop(rec["dispatch_id"], None)

    async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
        """Apply one message received from ``device_id`` to its dispatch record.

        The device's ``last_seen`` is refreshed for every message.  The message
        is otherwise ignored when it is not a JSON object, carries no usable
        ``dispatch_id``, refers to an unknown dispatch, or refers to a dispatch
        that belongs to a different device.

        Recognised ``event`` values: ``accepted`` (status becomes ``RUNNING``),
        ``status`` (status copied from the message), ``result`` and ``error``
        (record finished, device released).
        """
        now = utc_now_iso()
        async with self.lock:
            conn = self.connections.get(device_id)
            if conn:
                conn.last_seen = now

            # The payload comes straight off the socket and may be any JSON
            # value; anything but an object cannot describe a dispatch event.
            if not isinstance(event, dict):
                return

            dispatch_id = event.get("dispatch_id")
            # Dispatch ids are uuid4 hex strings; non-string values cannot
            # match and might not even be hashable.
            if not dispatch_id or not isinstance(dispatch_id, str):
                return

            record = self.dispatches.get(dispatch_id)
            if record is None:
                return
            # A device may only report on dispatches that were sent to it;
            # otherwise it could finish another device's work and release
            # its own busy flag.
            if record.get("device_id") != device_id:
                return

            evt = event.get("event")
            if evt == "accepted":
                record["status"] = "RUNNING"
            elif evt == "status":
                status_val = event.get("status")
                if status_val:
                    record["status"] = status_val
            elif evt == "result":
                record["status"] = event.get("status") or "SUCCEEDED"
                record["result"] = event
                if conn:
                    conn.busy = False
            elif evt == "error":
                record["status"] = "FAILED"
                record["error"] = event.get("error")
                if conn:
                    conn.busy = False

            record["updated_at"] = now

    async def mark_disconnect_failed(self, device_id: str) -> None:
        """Mark every still-active dispatch of ``device_id`` as ``FAILED``."""
        now = utc_now_iso()
        async with self.lock:
            for record in self.dispatches.values():
                if record["device_id"] == device_id and record["status"] in self._ACTIVE_STATUSES:
                    record["status"] = "FAILED"
                    record["error"] = f"device disconnected: {device_id}"
                    record["updated_at"] = now
|
|
|
|
|
|
# Module-level singletons: the in-memory device/dispatch registry and the
# FastAPI application whose routes operate on it.
manager = EdgeDispatchManager()
app = FastAPI(
    title="Edge Dispatch Service",
    version="0.1.0",
)
|
|
|
|
|
|
# Liveness probe: reports the service name, a fixed "ok" status and how many
# edge devices are currently registered.
@app.get("/health")
async def health() -> dict[str, Any]:
    connected = len(await manager.list_devices())
    return {
        "service": "edge_dispatch",
        "status": "ok",
        "connected_devices": connected,
    }
|
|
|
|
|
|
# Returns the manager's snapshot of all registered edge devices under the
# "devices" key.
@app.get("/devices")
async def list_devices() -> dict[str, Any]:
    snapshot = await manager.list_devices()
    return {"devices": snapshot}
|
|
|
|
|
|
# Selects an edge device (the requested one, or any idle one), sends it the
# generate request and returns the identifiers of the new dispatch record.
# HTTPExceptions raised by the manager propagate unchanged; a socket that
# drops during the send is unregistered and reported as 503, as is any
# RuntimeError raised while dispatching.
@app.post("/dispatch/generate", response_model=DispatchResponse)
async def dispatch_generate(body: DispatchGenerateRequest) -> DispatchResponse:
    target = await manager.select_device(body.device_id)

    try:
        dispatched = await manager.create_dispatch(target, body.request)
    except WebSocketDisconnect as exc:
        await manager.unregister(target.device_id)
        message = f"device disconnected during dispatch: {target.device_id}"
        raise HTTPException(status_code=503, detail=message) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    # Only these record keys are part of the public response model.
    response_fields = ("dispatch_id", "device_id", "status", "created_at")
    return DispatchResponse(**{field: dispatched[field] for field in response_fields})
|
|
|
|
|
|
# Returns the stored record for one dispatch, or 404 when the id is unknown.
# The record is read straight from the manager's table without taking its lock.
@app.get("/dispatch/{dispatch_id}")
async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
    found = manager.dispatches.get(dispatch_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
|
|
|
|
|
|
# Uploads the artifacts of a dispatch to OSS and records the outcome.
#
# Rejects with 404 for an unknown dispatch, and with 400 when OSS uploading is
# disabled or no files were sent. Each file is uploaded in a worker thread and
# closed afterwards, even when its upload fails. On success the record's
# artifact URLs, result, status and updated_at are replaced and the owning
# device (if still registered) is marked idle.
#
# NOTE(review): the uploader receives the raw client filename while the
# result map is keyed by its basename - confirm the uploader sanitises names.
# NOTE(review): a second call replaces, rather than merges, earlier URLs.
@app.post("/dispatch/{dispatch_id}/artifacts")
async def upload_artifacts(
    dispatch_id: str,
    task_id: str = Form(default=""),
    status: str = Form(default="SUCCEEDED"),
    files: list[UploadFile] = File(default_factory=list),
) -> dict[str, Any]:
    dispatch = manager.dispatches.get(dispatch_id)
    if dispatch is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    if not oss_uploader.enabled:
        raise HTTPException(status_code=400, detail="OSS upload is disabled, set OSS_ENABLED=true")
    if not files:
        raise HTTPException(status_code=400, detail="no files uploaded")

    artifact_urls: dict[str, dict[str, str]] = {}
    for upload in files:
        filename = upload.filename or "artifact.bin"
        try:
            # The uploader is a blocking call, so keep it off the event loop.
            artifact_urls[Path(filename).name] = await asyncio.to_thread(
                oss_uploader.upload_fileobj,
                dispatch_id,
                filename,
                upload.file,
            )
        finally:
            await upload.close()

    finished_at = utc_now_iso()
    result_event = {
        "event": "result",
        "dispatch_id": dispatch_id,
        "task_id": task_id or None,
        "status": status,
        "artifact_urls": artifact_urls,
    }
    async with manager.lock:
        dispatch["artifact_urls"] = artifact_urls
        dispatch["result"] = result_event
        dispatch["status"] = status
        dispatch["updated_at"] = finished_at
        owner = manager.connections.get(dispatch["device_id"])
        if owner:
            owner.busy = False

    return {
        "dispatch_id": dispatch_id,
        "status": status,
        "artifact_urls": artifact_urls,
        "updated_at": finished_at,
    }
|
|
|
|
|
|
@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
    """Serve one edge device for the lifetime of its websocket.

    Accepts the socket, registers it under ``device_id``, acknowledges with a
    ``registered`` event and then feeds every received JSON message to
    ``manager.mark_event``.

    When the loop ends for ANY reason, the device's active dispatches are
    marked failed and the device is unregistered.  A ``WebSocketDisconnect``
    is treated as normal termination and swallowed; any other exception (for
    example a message that cannot be decoded as JSON) still propagates, but
    only after the cleanup has run, so a broken socket can no longer stay
    listed as a connected device.
    """
    await websocket.accept()
    await manager.register(device_id, websocket)

    try:
        await websocket.send_json({"event": "registered", "device_id": device_id, "ts": utc_now_iso()})
        while True:
            msg = await websocket.receive_json()
            await manager.mark_event(device_id, msg)
    except WebSocketDisconnect:
        # The peer closed the connection; cleanup happens in ``finally``.
        pass
    finally:
        # NOTE(review): cleanup is keyed by device_id only, so if the same id
        # re-registered on a newer socket, this removes the newer entry too.
        await manager.mark_disconnect_failed(device_id)
        await manager.unregister(device_id)
|