"""Edge dispatch service.

Exposes HTTP endpoints that forward generate requests to edge devices
connected over WebSocket, and keeps an in-memory record of every dispatch.
"""

import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4

from fastapi import FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from app.oss_client import oss_uploader
from app.schemas import GenerateRequest
from app.settings import settings

# Statuses that still expect further events from the device.
_IN_FLIGHT_STATUSES = frozenset({"DISPATCHED", "RUNNING", "PENDING"})
# Statuses whose records are eligible for pruning.
_TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})
# Name used when an uploaded part carries no usable filename.
_DEFAULT_ARTIFACT_NAME = "artifact.bin"


def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _safe_artifact_name(filename: Optional[str]) -> str:
    """Reduce a client-supplied filename to a bare, non-traversing basename."""
    # Normalise Windows separators so the basename is extracted on any host.
    candidate = (filename or _DEFAULT_ARTIFACT_NAME).replace("\\", "/")
    base = Path(candidate).name
    if base in {"", ".", ".."}:
        return _DEFAULT_ARTIFACT_NAME
    return base


class DispatchGenerateRequest(BaseModel):
    """Body of ``POST /dispatch/generate``."""

    # Target device; when omitted, the first idle device is used.
    device_id: Optional[str] = None
    request: GenerateRequest


class DispatchResponse(BaseModel):
    """Summary of a freshly created dispatch."""

    dispatch_id: str
    device_id: str
    status: str
    created_at: str


@dataclass
class EdgeConnection:
    """Bookkeeping for one connected edge device."""

    device_id: str
    websocket: WebSocket
    connected_at: str
    last_seen: str
    busy: bool = False


class EdgeDispatchManager:
    """In-memory registry of edge connections and dispatch records.

    All mutations of ``connections`` and ``dispatches`` made by this class
    happen while holding ``lock``.
    """

    def __init__(self) -> None:
        self.connections: dict[str, EdgeConnection] = {}
        self.dispatches: dict[str, dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    def _fail_in_flight_locked(self, device_id: str, now: str) -> None:
        """Mark every in-flight dispatch of ``device_id`` as failed.

        The caller must already hold ``self.lock``.
        """
        for record in self.dispatches.values():
            if record["device_id"] == device_id and record["status"] in _IN_FLIGHT_STATUSES:
                record["status"] = "FAILED"
                record["error"] = f"device disconnected: {device_id}"
                record["updated_at"] = now

    async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
        """Register ``websocket`` as the connection for ``device_id``.

        If another socket is already registered under the same id it is
        superseded, and the dispatches that were in flight on it are failed
        here, because the old socket's handler will no longer do so.

        Returns:
            The newly created connection entry.
        """
        now = utc_now_iso()
        async with self.lock:
            previous = self.connections.get(device_id)
            if previous is not None and previous.websocket is not websocket:
                self._fail_in_flight_locked(device_id, now)
            conn = EdgeConnection(
                device_id=device_id,
                websocket=websocket,
                connected_at=now,
                last_seen=now,
                busy=False,
            )
            self.connections[device_id] = conn
            return conn

    async def unregister(self, device_id: str) -> None:
        """Remove the connection for ``device_id`` if one is registered."""
        async with self.lock:
            self.connections.pop(device_id, None)

    async def handle_disconnect(self, device_id: str, websocket: WebSocket) -> bool:
        """Clean up after ``websocket`` for ``device_id`` has gone away.

        Unregisters the device and fails its in-flight dispatches in one
        locked step. Nothing is changed when a different socket has since
        been registered under the same id.

        Returns:
            True if cleanup was performed, False if the socket was superseded.
        """
        now = utc_now_iso()
        async with self.lock:
            current = self.connections.get(device_id)
            if current is not None and current.websocket is not websocket:
                # A newer connection owns this device id; leave it untouched.
                return False
            self.connections.pop(device_id, None)
            self._fail_in_flight_locked(device_id, now)
            return True

    async def list_devices(self) -> list[dict[str, Any]]:
        """Return a snapshot of all registered devices as plain dicts."""
        async with self.lock:
            return [
                {
                    "device_id": conn.device_id,
                    "connected_at": conn.connected_at,
                    "last_seen": conn.last_seen,
                    "busy": conn.busy,
                }
                for conn in self.connections.values()
            ]

    async def select_device(self, preferred: Optional[str]) -> EdgeConnection:
        """Pick a device to dispatch to.

        The device is not reserved here; ``create_dispatch`` re-checks and
        sets the busy flag.

        Args:
            preferred: Device id to use, or a falsy value to take the first
                idle device.

        Raises:
            HTTPException: 404 if ``preferred`` is not connected; 409 if it
                is busy or no idle device exists.
        """
        async with self.lock:
            if preferred:
                conn = self.connections.get(preferred)
                if conn is None:
                    raise HTTPException(status_code=404, detail=f"device not found: {preferred}")
                if conn.busy:
                    raise HTTPException(status_code=409, detail=f"device is busy: {preferred}")
                return conn
            for conn in self.connections.values():
                if not conn.busy:
                    return conn
            raise HTTPException(status_code=409, detail="no idle edge device available")

    async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
        """Create a dispatch record and send the request to the device.

        Args:
            conn: Connection previously returned by ``select_device``.
            req: Generate request to forward.

        Returns:
            The stored dispatch record.

        Raises:
            HTTPException: 404 if the device is no longer connected; 409 if
                it became busy since it was selected.
            Exception: Whatever ``send_json`` raised; the record is marked
                FAILED and the device released before re-raising.
        """
        await self._prune_if_needed()
        dispatch_id = uuid4().hex
        now = utc_now_iso()
        request_data = req.model_dump()
        record: dict[str, Any] = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "status": "DISPATCHED",
            "request": request_data,
            "created_at": now,
            "updated_at": now,
            "result": None,
            "error": None,
            "artifact_urls": {},
        }
        async with self.lock:
            target = self.connections.get(conn.device_id)
            if target is None:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            if target.busy:
                # Another request claimed the device after select_device ran.
                raise HTTPException(status_code=409, detail=f"device is busy: {conn.device_id}")
            target.busy = True
            target.last_seen = now
            # Store the record only once the device is reserved, so a
            # rejected request leaves no orphan DISPATCHED entry behind.
            self.dispatches[dispatch_id] = record
        payload = {
            "event": "generate",
            "dispatch_id": dispatch_id,
            "request": request_data,
        }
        try:
            # Send on the currently registered socket, which may be newer
            # than the one captured in ``conn``.
            await target.websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"dispatch send failed: {exc}"
                record["updated_at"] = utc_now_iso()
                target.busy = False
            raise
        return record

    async def _prune_if_needed(self) -> None:
        """Drop the oldest finished records once the store is full.

        The cap is ``settings.edge_max_dispatch_records`` with a floor of
        100. Only SUCCEEDED/FAILED records are removed, oldest
        ``updated_at`` first, making room for one new record.
        """
        max_records = max(100, int(settings.edge_max_dispatch_records))
        async with self.lock:
            total = len(self.dispatches)
            if total < max_records:
                return
            over = total - max_records + 1
            done = [v for v in self.dispatches.values() if v.get("status") in _TERMINAL_STATUSES]
            done.sort(key=lambda x: x.get("updated_at", ""))
            for rec in done[:over]:
                self.dispatches.pop(rec["dispatch_id"], None)

    async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
        """Apply an event received from a device to its dispatch record.

        Handles ``accepted``, ``status``, ``result`` and ``error`` events.
        Messages that are not JSON objects, that lack a known
        ``dispatch_id``, or that refer to a dispatch owned by another
        device are ignored (apart from refreshing ``last_seen``).
        """
        now = utc_now_iso()
        async with self.lock:
            conn = self.connections.get(device_id)
            if conn:
                conn.last_seen = now
            if not isinstance(event, dict):
                # The socket may deliver any JSON value; only objects are events.
                return
            dispatch_id = event.get("dispatch_id")
            if not dispatch_id:
                return
            record = self.dispatches.get(dispatch_id)
            if record is None:
                return
            if record["device_id"] != device_id:
                # Do not let one device alter another device's dispatch or
                # release its own busy flag on that dispatch's behalf.
                return
            evt = event.get("event")
            if evt == "accepted":
                record["status"] = "RUNNING"
            elif evt == "status":
                status_val = event.get("status")
                if status_val:
                    record["status"] = status_val
            elif evt == "result":
                record["status"] = event.get("status") or "SUCCEEDED"
                record["result"] = event
                if conn:
                    conn.busy = False
            elif evt == "error":
                record["status"] = "FAILED"
                record["error"] = event.get("error")
                if conn:
                    conn.busy = False
            record["updated_at"] = now

    async def mark_disconnect_failed(self, device_id: str) -> None:
        """Fail every in-flight dispatch that belongs to ``device_id``."""
        now = utc_now_iso()
        async with self.lock:
            self._fail_in_flight_locked(device_id, now)


manager = EdgeDispatchManager()
app = FastAPI(title="Edge Dispatch Service", version="0.1.0")


@app.get("/health")
async def health() -> dict[str, Any]:
    """Report service status and the number of connected devices."""
    devices = await manager.list_devices()
    return {
        "service": "edge_dispatch",
        "status": "ok",
        "connected_devices": len(devices),
    }


@app.get("/devices")
async def list_devices() -> dict[str, Any]:
    """List all connected edge devices."""
    return {"devices": await manager.list_devices()}


@app.post("/dispatch/generate", response_model=DispatchResponse)
async def dispatch_generate(body: DispatchGenerateRequest) -> DispatchResponse:
    """Dispatch a generate request to an edge device.

    Raises:
        HTTPException: 404/409 from device selection or reservation; 503 if
            the device disconnects or the send raises ``RuntimeError``.
    """
    conn = await manager.select_device(body.device_id)
    try:
        record = await manager.create_dispatch(conn, body.request)
    except WebSocketDisconnect as exc:
        await manager.unregister(conn.device_id)
        raise HTTPException(
            status_code=503,
            detail=f"device disconnected during dispatch: {conn.device_id}",
        ) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    return DispatchResponse(
        dispatch_id=record["dispatch_id"],
        device_id=record["device_id"],
        status=record["status"],
        created_at=record["created_at"],
    )


@app.get("/dispatch/{dispatch_id}")
async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
    """Return the stored record for ``dispatch_id`` or respond 404."""
    record = manager.dispatches.get(dispatch_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    return record


@app.post("/dispatch/{dispatch_id}/artifacts")
async def upload_artifacts(
    dispatch_id: str,
    task_id: str = Form(default=""),
    status: str = Form(default="SUCCEEDED"),
    files: list[UploadFile] = File(default_factory=list),
) -> dict[str, Any]:
    """Upload result artifacts for a dispatch and record the outcome.

    Each file is uploaded through ``oss_uploader`` under its sanitised
    basename; the dispatch record is then updated with the results and the
    given ``status``, and the owning device is marked idle.

    Raises:
        HTTPException: 404 if the dispatch is unknown; 400 if OSS upload is
            disabled or no files were sent.
    """
    record = manager.dispatches.get(dispatch_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    if not oss_uploader.enabled:
        raise HTTPException(status_code=400, detail="OSS upload is disabled, set OSS_ENABLED=true")
    if not files:
        raise HTTPException(status_code=400, detail="no files uploaded")

    uploaded: dict[str, dict[str, str]] = {}
    try:
        for file in files:
            # The filename is client-controlled: strip directory components
            # so the uploaded name and the response key agree and cannot
            # contain path traversal segments.
            name = _safe_artifact_name(file.filename)
            uploaded[name] = await asyncio.to_thread(
                oss_uploader.upload_fileobj, dispatch_id, name, file.file
            )
    finally:
        # Close every upload, including those never reached after a failure.
        for file in files:
            await file.close()

    now = utc_now_iso()
    async with manager.lock:
        record["artifact_urls"] = uploaded
        record["result"] = {
            "event": "result",
            "dispatch_id": dispatch_id,
            "task_id": task_id or None,
            "status": status,
            "artifact_urls": uploaded,
        }
        record["status"] = status
        record["updated_at"] = now
        conn = manager.connections.get(record["device_id"])
        if conn:
            conn.busy = False
    return {
        "dispatch_id": dispatch_id,
        "status": status,
        "artifact_urls": uploaded,
        "updated_at": now,
    }


@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
    """Serve one edge device: register it, then apply its events until it leaves."""
    await websocket.accept()
    await manager.register(device_id, websocket)
    try:
        await websocket.send_json({"event": "registered", "device_id": device_id, "ts": utc_now_iso()})
        while True:
            msg = await websocket.receive_json()
            await manager.mark_event(device_id, msg)
    except WebSocketDisconnect:
        # Normal termination: the peer closed the socket.
        pass
    finally:
        # Runs for every exit path (including malformed JSON), so the device
        # never stays registered with dispatches stuck in flight.
        await manager.handle_disconnect(device_id, websocket)