import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import uuid4

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from app.schemas import GenerateRequest

# Dispatch statuses that count as "still in flight" when a device goes away.
_ACTIVE_STATUSES = frozenset({"DISPATCHED", "RUNNING", "PENDING"})


def utc_now_iso() -> str:
    """Return the current time in UTC as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


class DispatchGenerateRequest(BaseModel):
    """Request body for ``POST /dispatch/generate``."""

    # Device to target; when None, the first idle connected device is chosen.
    device_id: Optional[str] = None
    # Forwarded to the device as ``request.model_dump()``.
    request: GenerateRequest


class DispatchResponse(BaseModel):
    """Summary of a newly created dispatch returned by ``POST /dispatch/generate``."""

    dispatch_id: str
    device_id: str
    status: str
    created_at: str


@dataclass
class EdgeConnection:
    """State kept for one registered edge-device websocket session."""

    device_id: str
    websocket: WebSocket
    connected_at: str  # ISO-8601 UTC time of registration
    last_seen: str  # ISO-8601 UTC time of the last message or dispatch
    busy: bool = False  # True while a dispatch sent to this session is outstanding


class EdgeDispatchManager:
    """In-memory registry of edge-device sessions and the dispatches sent to them.

    Every read-modify-write of ``connections`` / ``dispatches`` happens under
    ``lock`` so coroutines interleaving at ``await`` points see a consistent view.

    NOTE(review): ``dispatches`` is never pruned and grows for the lifetime of
    the process.
    """

    def __init__(self) -> None:
        self.connections: dict[str, EdgeConnection] = {}
        self.dispatches: dict[str, dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
        """Make *websocket* the current (idle) session for *device_id*.

        An existing session for the same device id is replaced.
        NOTE(review): dispatches still in flight on a replaced session are not
        failed here; they are failed when the device's current session ends.
        """
        now = utc_now_iso()
        async with self.lock:
            conn = EdgeConnection(
                device_id=device_id,
                websocket=websocket,
                connected_at=now,
                last_seen=now,
                busy=False,
            )
            self.connections[device_id] = conn
            return conn

    async def unregister(self, device_id: str, websocket: Optional[WebSocket] = None) -> bool:
        """Forget the session registered for *device_id*.

        Args:
            device_id: Device whose session should be removed.
            websocket: If given, remove the entry only when it still belongs to
                this socket, so a stale session cannot evict a newer one.

        Returns:
            True if an entry was removed.
        """
        async with self.lock:
            conn = self.connections.get(device_id)
            if conn is None:
                return False
            if websocket is not None and conn.websocket is not websocket:
                return False
            del self.connections[device_id]
            return True

    async def close_connection(self, device_id: str, websocket: WebSocket) -> None:
        """Clean up after the session on *websocket* for *device_id* has ended.

        Unless a newer session has already replaced it, the device is
        unregistered and its in-flight dispatches are marked FAILED, atomically.
        """
        now = utc_now_iso()
        async with self.lock:
            current = self.connections.get(device_id)
            if current is not None and current.websocket is not websocket:
                # The device re-registered on another socket; leave that session alone.
                return
            self.connections.pop(device_id, None)
            self._fail_active(device_id, now)

    async def list_devices(self) -> list[dict[str, Any]]:
        """Return a snapshot of every registered session as plain dicts."""
        async with self.lock:
            return [
                {
                    "device_id": conn.device_id,
                    "connected_at": conn.connected_at,
                    "last_seen": conn.last_seen,
                    "busy": conn.busy,
                }
                for conn in self.connections.values()
            ]

    async def select_device(self, preferred: Optional[str]) -> EdgeConnection:
        """Pick the session to dispatch to (without reserving it).

        Args:
            preferred: Device id requested by the caller, or None for "any".

        Returns:
            The preferred device's session, or else the first idle session in
            ``connections`` iteration order.

        Raises:
            HTTPException: 404 if *preferred* is not connected; 409 if it is
                busy, or if no session is idle when no preference was given.
        """
        async with self.lock:
            if preferred:
                conn = self.connections.get(preferred)
                if conn is None:
                    raise HTTPException(status_code=404, detail=f"device not found: {preferred}")
                if conn.busy:
                    raise HTTPException(status_code=409, detail=f"device is busy: {preferred}")
                return conn
            for conn in self.connections.values():
                if not conn.busy:
                    return conn
            raise HTTPException(status_code=409, detail="no idle edge device available")

    async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
        """Reserve *conn*, record a new dispatch, and send it a ``generate`` event.

        Returns:
            The stored dispatch record (status ``DISPATCHED``).

        Raises:
            HTTPException: 404 if *conn* is no longer the device's current
                session; 409 if a concurrent request already claimed the device.
            Exception: anything raised by ``send_json``; the record is first
                marked FAILED and *conn* released.
        """
        dispatch_id = uuid4().hex
        now = utc_now_iso()
        request_data = req.model_dump()
        record: dict[str, Any] = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "status": "DISPATCHED",
            "request": request_data,
            "created_at": now,
            "updated_at": now,
            "result": None,
            "error": None,
        }
        async with self.lock:
            # select_device released the lock before we got here, so re-validate:
            # the session may have ended or been replaced, or another request may
            # have claimed the same idle device in the meantime.
            if self.connections.get(conn.device_id) is not conn:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            if conn.busy:
                raise HTTPException(status_code=409, detail=f"device is busy: {conn.device_id}")
            # Store the record only once the dispatch is actually going out, so a
            # rejected attempt does not linger as DISPATCHED.
            self.dispatches[dispatch_id] = record
            conn.busy = True
            conn.last_seen = now

        payload = {
            "event": "generate",
            "dispatch_id": dispatch_id,
            "request": request_data,
        }
        try:
            await conn.websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"dispatch send failed: {exc}"
                record["updated_at"] = utc_now_iso()
                # Release only the session we reserved, never a newer replacement.
                conn.busy = False
            raise
        return record

    async def mark_event(self, device_id: str, event: Any) -> None:
        """Apply a message received from *device_id* to its session and dispatch.

        Any message refreshes the session's ``last_seen``. A message that is a
        JSON object with a ``dispatch_id`` owned by this device updates that
        record:

        - ``accepted``: status becomes RUNNING.
        - ``status``: status becomes the reported ``status``, if non-empty.
        - ``result``: status becomes the reported ``status`` (default
          SUCCEEDED), the whole event is stored as ``result``, and the device is
          marked idle.
        - ``error``: status becomes FAILED, ``error`` is copied, and the device
          is marked idle.

        Other messages are ignored.
        """
        now = utc_now_iso()
        async with self.lock:
            conn = self.connections.get(device_id)
            if conn:
                conn.last_seen = now
            # receive_json can yield any JSON value, not just an object.
            if not isinstance(event, dict):
                return
            dispatch_id = event.get("dispatch_id")
            # Dispatch ids are uuid hex strings; anything else (incl. unhashable
            # values) cannot match a record.
            if not isinstance(dispatch_id, str) or not dispatch_id:
                return
            record = self.dispatches.get(dispatch_id)
            # A device may only report on dispatches that were sent to it.
            if record is None or record["device_id"] != device_id:
                return
            evt = event.get("event")
            if evt == "accepted":
                record["status"] = "RUNNING"
            elif evt == "status":
                status_val = event.get("status")
                if status_val:
                    record["status"] = status_val
            elif evt == "result":
                status_val = event.get("status") or "SUCCEEDED"
                record["status"] = status_val
                record["result"] = event
                if conn:
                    conn.busy = False
            elif evt == "error":
                record["status"] = "FAILED"
                record["error"] = event.get("error")
                if conn:
                    conn.busy = False
            record["updated_at"] = now

    async def mark_disconnect_failed(self, device_id: str) -> None:
        """Mark every in-flight dispatch of *device_id* as FAILED."""
        now = utc_now_iso()
        async with self.lock:
            self._fail_active(device_id, now)

    def _fail_active(self, device_id: str, now: str) -> None:
        """Fail *device_id*'s in-flight dispatches; the caller must hold ``lock``."""
        for record in self.dispatches.values():
            if record["device_id"] == device_id and record["status"] in _ACTIVE_STATUSES:
                record["status"] = "FAILED"
                record["error"] = f"device disconnected: {device_id}"
                record["updated_at"] = now


manager = EdgeDispatchManager()
app = FastAPI(title="Edge Dispatch Service", version="0.1.0")


@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe reporting how many edge devices are connected."""
    devices = await manager.list_devices()
    return {
        "service": "edge_dispatch",
        "status": "ok",
        "connected_devices": len(devices),
    }


@app.get("/devices")
async def list_devices() -> dict[str, Any]:
    """List all connected edge devices."""
    return {"devices": await manager.list_devices()}


@app.post("/dispatch/generate", response_model=DispatchResponse)
async def dispatch_generate(body: DispatchGenerateRequest) -> DispatchResponse:
    """Send a generate request to an edge device and return the new dispatch.

    Raises:
        HTTPException: 404/409 from device selection or reservation; 503 when
            the request could not be delivered over the device's websocket.
    """
    conn = await manager.select_device(body.device_id)
    try:
        record = await manager.create_dispatch(conn, body.request)
    except WebSocketDisconnect as exc:
        # Drop only the session we used; the device may already have reconnected.
        await manager.unregister(conn.device_id, conn.websocket)
        raise HTTPException(status_code=503, detail=f"device disconnected during dispatch: {conn.device_id}") from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    return DispatchResponse(
        dispatch_id=record["dispatch_id"],
        device_id=record["device_id"],
        status=record["status"],
        created_at=record["created_at"],
    )


@app.get("/dispatch/{dispatch_id}")
async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
    """Return the stored record for *dispatch_id*, or 404 if it is unknown."""
    record = manager.dispatches.get(dispatch_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    return record


@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
    """Websocket session for one edge device.

    Registers the device, acknowledges with a ``registered`` event, then feeds
    every received JSON message to ``manager.mark_event`` until the socket
    closes. Cleanup runs on every exit path (including malformed messages), not
    only on a clean ``WebSocketDisconnect``.
    """
    await websocket.accept()
    await manager.register(device_id, websocket)
    try:
        await websocket.send_json({"event": "registered", "device_id": device_id, "ts": utc_now_iso()})
        while True:
            msg = await websocket.receive_json()
            await manager.mark_event(device_id, msg)
    except WebSocketDisconnect:
        pass
    finally:
        await manager.close_connection(device_id, websocket)