"""Edge dispatch service.

Edge devices connect over a WebSocket (``/ws/edge/{device_id}``). HTTP clients
dispatch generation jobs or maintenance commands to a connected device and
poll the resulting record. All state is kept in memory in a single
``EdgeDispatchManager`` instance.
"""

import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4

from fastapi import (
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from pydantic import BaseModel

from app.oss_client import oss_uploader
from app.schemas import GenerateRequest
from app.settings import settings

# Commands that may be forwarded to a device via ``create_command``.
_ALLOWED_COMMANDS = frozenset({"update_code", "restart_service", "ping"})
# Statuses that make a record eligible for pruning.
_TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})
# Statuses that are turned into FAILED when the owning device disconnects.
_IN_FLIGHT_STATUSES = frozenset({"DISPATCHED", "RUNNING", "PENDING"})
# Envelope keys of a command message that a caller-supplied payload must not
# overwrite (doing so would bypass the allowed-command check).
_RESERVED_PAYLOAD_KEYS = frozenset({"event", "dispatch_id"})


def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


class DispatchGenerateRequest(BaseModel):
    """Body of ``POST /dispatch/generate``."""

    # Target device; when omitted the first idle device is used.
    device_id: Optional[str] = None
    request: GenerateRequest


class DispatchResponse(BaseModel):
    """Response of ``POST /dispatch/generate``."""

    dispatch_id: str
    device_id: str
    status: str
    created_at: str


class DeviceCommandRequest(BaseModel):
    """Body of ``POST /devices/{device_id}/command``."""

    command: str
    # Optional caller-chosen id; a random one is generated when omitted.
    dispatch_id: Optional[str] = None
    # Forwarded to the device as ``branch`` when set.
    branch: Optional[str] = None
    # Forwarded to the device as ``command`` when set.
    shell_command: Optional[str] = None
    # Extra keys merged into the message sent to the device.
    payload: Optional[dict[str, Any]] = None


class DeviceCommandResponse(BaseModel):
    """Response of ``POST /devices/{device_id}/command``."""

    dispatch_id: str
    device_id: str
    command: str
    status: str
    created_at: str


@dataclass
class EdgeConnection:
    """A registered edge device and its WebSocket."""

    device_id: str
    websocket: WebSocket
    connected_at: str
    last_seen: str
    busy: bool = False


class EdgeDispatchManager:
    """In-memory registry of edge connections, dispatches and commands.

    Mutations of ``connections``, ``dispatches`` and ``commands`` performed by
    the methods of this class happen while holding ``self.lock``.
    """

    def __init__(self) -> None:
        self.connections: dict[str, EdgeConnection] = {}
        self.dispatches: dict[str, dict[str, Any]] = {}
        self.commands: dict[str, dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
        """Register ``websocket`` as the connection for ``device_id``.

        An existing connection with the same id is replaced.
        """
        now = utc_now_iso()
        conn = EdgeConnection(
            device_id=device_id,
            websocket=websocket,
            connected_at=now,
            last_seen=now,
            busy=False,
        )
        async with self.lock:
            self.connections[device_id] = conn
        return conn

    async def unregister(self, device_id: str, websocket: Optional[WebSocket] = None) -> None:
        """Remove the connection registered for ``device_id``, if any.

        When ``websocket`` is given, the connection is removed only if it is
        still backed by that exact socket. This keeps a late cleanup for an
        old socket from evicting a connection that re-registered under the
        same id in the meantime.
        """
        async with self.lock:
            current = self.connections.get(device_id)
            if current is None:
                return
            if websocket is not None and current.websocket is not websocket:
                return
            del self.connections[device_id]

    async def handle_disconnect(self, device_id: str, websocket: WebSocket) -> bool:
        """Clean up after ``websocket`` for ``device_id`` has closed.

        If a different socket is currently registered for the device, the
        closed socket is stale and nothing is changed. Otherwise the device's
        in-flight records are marked FAILED and the connection (if still
        present) is removed, all under one lock acquisition.

        Returns:
            True if cleanup was performed, False if the socket was stale.
        """
        async with self.lock:
            current = self.connections.get(device_id)
            if current is not None and current.websocket is not websocket:
                return False
            self._fail_in_flight(device_id, utc_now_iso())
            if current is not None:
                del self.connections[device_id]
            return True

    async def list_devices(self) -> list[dict[str, Any]]:
        """Return a snapshot of every registered connection."""
        async with self.lock:
            return [
                {
                    "device_id": conn.device_id,
                    "connected_at": conn.connected_at,
                    "last_seen": conn.last_seen,
                    "busy": conn.busy,
                }
                for conn in self.connections.values()
            ]

    async def select_device(self, preferred: Optional[str]) -> EdgeConnection:
        """Pick the device to talk to.

        Args:
            preferred: Explicit device id, or a falsy value to take the first
                idle device.

        Raises:
            HTTPException: 404 if ``preferred`` is not connected; 409 if it is
                busy, or if no idle device exists.

        The device is not reserved here; ``create_dispatch`` re-checks and
        sets ``busy`` under the lock.
        """
        async with self.lock:
            if preferred:
                conn = self.connections.get(preferred)
                if conn is None:
                    raise HTTPException(status_code=404, detail=f"device not found: {preferred}")
                if conn.busy:
                    raise HTTPException(status_code=409, detail=f"device is busy: {preferred}")
                return conn
            for conn in self.connections.values():
                if not conn.busy:
                    return conn
        raise HTTPException(status_code=409, detail="no idle edge device available")

    async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
        """Reserve the device, record the dispatch and send it to the device.

        Args:
            conn: Connection previously returned by ``select_device``.
            req: Generation request to forward.

        Returns:
            The stored dispatch record.

        Raises:
            HTTPException: 404 if the device is no longer connected; 409 if
                it became busy since it was selected.
            Exception: Whatever ``send_json`` raised; the record is marked
                FAILED and the device is released before re-raising.
        """
        await self._prune_if_needed()
        dispatch_id = uuid4().hex
        now = utc_now_iso()
        request_data = req.model_dump()
        record: dict[str, Any] = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "status": "DISPATCHED",
            "request": request_data,
            "created_at": now,
            "updated_at": now,
            "result": None,
            "error": None,
            "artifact_urls": {},
        }
        async with self.lock:
            # Validate before storing so a rejected dispatch leaves no record
            # stuck in DISPATCHED (only terminal records are ever pruned).
            target = self.connections.get(conn.device_id)
            if target is None:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            # Re-check under the lock: another request may have claimed the
            # device after select_device returned.
            if target.busy:
                raise HTTPException(status_code=409, detail=f"device is busy: {conn.device_id}")
            target.busy = True
            target.last_seen = now
            self.dispatches[dispatch_id] = record
            # Use the currently registered socket; `conn` may have been
            # replaced by a re-registration since it was selected.
            websocket = target.websocket

        payload = {
            "event": "generate",
            "dispatch_id": dispatch_id,
            "request": request_data,
        }
        try:
            await websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"dispatch send failed: {exc}"
                record["updated_at"] = utc_now_iso()
                current = self.connections.get(conn.device_id)
                if current is not None:
                    current.busy = False
            raise
        return record

    async def create_command(self, conn: EdgeConnection, body: DeviceCommandRequest) -> dict[str, Any]:
        """Record a maintenance command and send it to the device.

        Args:
            conn: Connection previously returned by ``select_device``.
            body: Command request; ``command`` is stripped and lower-cased
                before being checked against the allowed set.

        Returns:
            The stored command record.

        Raises:
            HTTPException: 400 for an unsupported command; 404 if the device
                is no longer connected.
            Exception: Whatever ``send_json`` raised; the record is marked
                FAILED before re-raising.
        """
        await self._prune_if_needed()
        command = body.command.strip().lower()
        if command not in _ALLOWED_COMMANDS:
            raise HTTPException(status_code=400, detail=f"unsupported command: {command}")

        dispatch_id = body.dispatch_id or uuid4().hex
        now = utc_now_iso()
        record: dict[str, Any] = {
            "dispatch_id": dispatch_id,
            "device_id": conn.device_id,
            "command": command,
            "status": "DISPATCHED",
            "created_at": now,
            "updated_at": now,
            "request": body.model_dump(),
            "result": None,
            "error": None,
        }

        payload: dict[str, Any] = {
            "event": command,
            "dispatch_id": dispatch_id,
        }
        if body.branch:
            payload["branch"] = body.branch
        if body.shell_command:
            payload["command"] = body.shell_command
        if body.payload:
            # Extra keys may override branch/command as before, but never the
            # envelope: a caller-supplied "event" would bypass the whitelist.
            payload.update(
                {key: value for key, value in body.payload.items() if key not in _RESERVED_PAYLOAD_KEYS}
            )

        async with self.lock:
            # Validate before storing so a rejected command leaves no record
            # stuck in DISPATCHED.
            target = self.connections.get(conn.device_id)
            if target is None:
                raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
            target.last_seen = now
            self.commands[dispatch_id] = record
            websocket = target.websocket

        try:
            await websocket.send_json(payload)
        except Exception as exc:
            async with self.lock:
                record["status"] = "FAILED"
                record["error"] = f"command send failed: {exc}"
                record["updated_at"] = utc_now_iso()
            raise
        return record

    async def _prune_if_needed(self) -> None:
        """Evict the oldest finished records once the store is full.

        The limit is ``settings.edge_max_dispatch_records`` with a floor of
        100, counted over dispatches and commands together. Only records in a
        terminal status are evicted, oldest ``updated_at`` first, and enough
        are removed to leave room for one new record.
        """
        max_records = max(100, int(settings.edge_max_dispatch_records))
        async with self.lock:
            total = len(self.dispatches) + len(self.commands)
            if total < max_records:
                return
            over = total - max_records + 1
            done = [
                rec
                for table in (self.dispatches, self.commands)
                for rec in table.values()
                if rec.get("status") in _TERMINAL_STATUSES
            ]
            done.sort(key=lambda rec: rec.get("updated_at", ""))
            for rec in done[:over]:
                self.dispatches.pop(rec["dispatch_id"], None)
                self.commands.pop(rec["dispatch_id"], None)

    async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
        """Apply a message received from a device to the matching record.

        The device's ``last_seen`` is refreshed whenever it is registered.
        The record is looked up by ``event["dispatch_id"]``, first among
        dispatches and then among commands; messages without an id or with an
        unknown id are otherwise ignored.
        """
        now = utc_now_iso()
        async with self.lock:
            conn = self.connections.get(device_id)
            if conn is not None:
                conn.last_seen = now

            dispatch_id = event.get("dispatch_id")
            if not dispatch_id:
                return
            record = self.dispatches.get(dispatch_id)
            if record is None:
                record = self.commands.get(dispatch_id)
            if record is None:
                return

            evt = event.get("event")
            if evt == "accepted":
                record["status"] = "RUNNING"
            elif evt == "status":
                status_val = event.get("status")
                if status_val:
                    record["status"] = status_val
            elif evt == "result":
                record["status"] = event.get("status") or "SUCCEEDED"
                record["result"] = event
                if conn is not None:
                    conn.busy = False
            elif evt == "error":
                record["status"] = "FAILED"
                record["error"] = event.get("error")
                if conn is not None:
                    conn.busy = False
            elif evt == "command_status":
                record["status"] = event.get("status") or "RUNNING"
            elif evt == "command_result":
                status_val = event.get("status") or "SUCCEEDED"
                record["status"] = status_val
                record["result"] = event
                if status_val == "FAILED":
                    record["error"] = event.get("error") or event.get("stderr")
            elif evt == "pong":
                record["status"] = "SUCCEEDED"
                record["result"] = event
            # Any message for a known record, even an unrecognised event,
            # counts as an update.
            record["updated_at"] = now

    async def mark_disconnect_failed(self, device_id: str) -> None:
        """Mark every in-flight dispatch and command of ``device_id`` FAILED."""
        now = utc_now_iso()
        async with self.lock:
            self._fail_in_flight(device_id, now)

    def _fail_in_flight(self, device_id: str, now: str) -> None:
        """Fail in-flight records of ``device_id``; caller must hold ``self.lock``."""
        error = f"device disconnected: {device_id}"
        for table in (self.dispatches, self.commands):
            for record in table.values():
                if record["device_id"] == device_id and record["status"] in _IN_FLIGHT_STATUSES:
                    record["status"] = "FAILED"
                    record["error"] = error
                    record["updated_at"] = now


manager = EdgeDispatchManager()
app = FastAPI(title="Edge Dispatch Service", version="0.1.0")


@app.get("/health")
async def health() -> dict[str, Any]:
    """Report service liveness and the number of connected devices."""
    devices = await manager.list_devices()
    return {
        "service": "edge_dispatch",
        "status": "ok",
        "connected_devices": len(devices),
    }


@app.get("/devices")
async def list_devices() -> dict[str, Any]:
    """List the currently registered devices."""
    return {"devices": await manager.list_devices()}


@app.post("/dispatch/generate", response_model=DispatchResponse)
async def dispatch_generate(body: DispatchGenerateRequest) -> DispatchResponse:
    """Dispatch a generation request to the chosen (or first idle) device.

    Raises:
        HTTPException: 404/409 from device selection or reservation; 503 if
            sending fails with ``WebSocketDisconnect`` or ``RuntimeError``.
    """
    conn = await manager.select_device(body.device_id)
    try:
        record = await manager.create_dispatch(conn, body.request)
    except WebSocketDisconnect as exc:
        # Identity-checked so a connection that re-registered under the same
        # id is not removed by mistake.
        await manager.unregister(conn.device_id, conn.websocket)
        raise HTTPException(
            status_code=503,
            detail=f"device disconnected during dispatch: {conn.device_id}",
        ) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    return DispatchResponse(
        dispatch_id=record["dispatch_id"],
        device_id=record["device_id"],
        status=record["status"],
        created_at=record["created_at"],
    )


@app.get("/dispatch/{dispatch_id}")
async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
    """Return the stored dispatch record, or 404 if the id is unknown."""
    record = manager.dispatches.get(dispatch_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    return record


@app.post("/devices/{device_id}/command", response_model=DeviceCommandResponse)
async def device_command(device_id: str, body: DeviceCommandRequest) -> DeviceCommandResponse:
    """Send a maintenance command to ``device_id``.

    Raises:
        HTTPException: 400/404/409 from validation or device selection; 503
            if sending fails with ``WebSocketDisconnect`` or ``RuntimeError``.
    """
    conn = await manager.select_device(device_id)
    try:
        record = await manager.create_command(conn, body)
    except WebSocketDisconnect as exc:
        await manager.unregister(conn.device_id, conn.websocket)
        raise HTTPException(
            status_code=503,
            detail=f"device disconnected during command: {conn.device_id}",
        ) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    return DeviceCommandResponse(
        dispatch_id=record["dispatch_id"],
        device_id=record["device_id"],
        command=record["command"],
        status=record["status"],
        created_at=record["created_at"],
    )


@app.get("/commands/{dispatch_id}")
async def get_command(dispatch_id: str) -> dict[str, Any]:
    """Return the stored command record, or 404 if the id is unknown."""
    record = manager.commands.get(dispatch_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"command not found: {dispatch_id}")
    return record


@app.post("/dispatch/{dispatch_id}/artifacts")
async def upload_artifacts(
    dispatch_id: str,
    task_id: str = Form(default=""),
    status: str = Form(default="SUCCEEDED"),
    files: list[UploadFile] = File(default_factory=list),
) -> dict[str, Any]:
    """Upload result files for a dispatch and record them as its result.

    Each file is passed to ``oss_uploader.upload_fileobj`` in a worker thread.
    The record then takes the submitted ``status`` and the owning device, if
    still connected, is marked idle.

    Raises:
        HTTPException: 404 for an unknown dispatch; 400 if OSS upload is
            disabled or no files were sent.
    """
    record = manager.dispatches.get(dispatch_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
    if not oss_uploader.enabled:
        raise HTTPException(status_code=400, detail="OSS upload is disabled, set OSS_ENABLED=true")
    if not files:
        raise HTTPException(status_code=400, detail="no files uploaded")

    uploaded: dict[str, dict[str, str]] = {}
    for file in files:
        name = file.filename or "artifact.bin"
        try:
            result = await asyncio.to_thread(oss_uploader.upload_fileobj, dispatch_id, name, file.file)
            # Key by base name only, dropping any directory part of the
            # client-supplied filename.
            uploaded[Path(name).name] = result
        finally:
            await file.close()

    now = utc_now_iso()
    async with manager.lock:
        record["artifact_urls"] = uploaded
        record["result"] = {
            "event": "result",
            "dispatch_id": dispatch_id,
            "task_id": task_id or None,
            "status": status,
            "artifact_urls": uploaded,
        }
        record["status"] = status
        record["updated_at"] = now
        conn = manager.connections.get(record["device_id"])
        if conn is not None:
            conn.busy = False

    return {
        "dispatch_id": dispatch_id,
        "status": status,
        "artifact_urls": uploaded,
        "updated_at": now,
    }


@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
    """Serve one edge device: register it, then apply every message it sends.

    Cleanup runs in ``finally`` so that any exit from the receive loop, not
    just ``WebSocketDisconnect``, fails the device's in-flight records and
    removes its connection.
    """
    await websocket.accept()
    await manager.register(device_id, websocket)
    try:
        await websocket.send_json({"event": "registered", "device_id": device_id, "ts": utc_now_iso()})
        while True:
            msg = await websocket.receive_json()
            # mark_event expects a JSON object; skip anything else instead of
            # crashing the connection on `.get`.
            if not isinstance(msg, dict):
                continue
            await manager.mark_event(device_id, msg)
    except WebSocketDisconnect:
        # Normal termination of the receive loop; cleanup happens below.
        pass
    finally:
        await manager.handle_disconnect(device_id, websocket)