fix:优化wsl环境下的边缘设备执行逻辑
This commit is contained in:
@@ -29,6 +29,22 @@ class DispatchResponse(BaseModel):
|
||||
created_at: str
|
||||
|
||||
|
||||
class DeviceCommandRequest(BaseModel):
|
||||
command: str
|
||||
dispatch_id: Optional[str] = None
|
||||
branch: Optional[str] = None
|
||||
shell_command: Optional[str] = None
|
||||
payload: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class DeviceCommandResponse(BaseModel):
|
||||
dispatch_id: str
|
||||
device_id: str
|
||||
command: str
|
||||
status: str
|
||||
created_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class EdgeConnection:
|
||||
device_id: str
|
||||
@@ -42,6 +58,7 @@ class EdgeDispatchManager:
|
||||
def __init__(self) -> None:
|
||||
self.connections: dict[str, EdgeConnection] = {}
|
||||
self.dispatches: dict[str, dict[str, Any]] = {}
|
||||
self.commands: dict[str, dict[str, Any]] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def register(self, device_id: str, websocket: WebSocket) -> EdgeConnection:
|
||||
@@ -130,17 +147,68 @@ class EdgeDispatchManager:
|
||||
raise
|
||||
return record
|
||||
|
||||
async def create_command(self, conn: EdgeConnection, body: DeviceCommandRequest) -> dict[str, Any]:
|
||||
await self._prune_if_needed()
|
||||
command = body.command.strip().lower()
|
||||
allowed = {"update_code", "restart_service", "ping"}
|
||||
if command not in allowed:
|
||||
raise HTTPException(status_code=400, detail=f"unsupported command: {command}")
|
||||
|
||||
dispatch_id = body.dispatch_id or uuid4().hex
|
||||
now = utc_now_iso()
|
||||
record = {
|
||||
"dispatch_id": dispatch_id,
|
||||
"device_id": conn.device_id,
|
||||
"command": command,
|
||||
"status": "DISPATCHED",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"request": body.model_dump(),
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"event": command,
|
||||
"dispatch_id": dispatch_id,
|
||||
}
|
||||
if body.branch:
|
||||
payload["branch"] = body.branch
|
||||
if body.shell_command:
|
||||
payload["command"] = body.shell_command
|
||||
if body.payload:
|
||||
payload.update(body.payload)
|
||||
|
||||
async with self.lock:
|
||||
self.commands[dispatch_id] = record
|
||||
target = self.connections.get(conn.device_id)
|
||||
if target is None:
|
||||
raise HTTPException(status_code=404, detail=f"device disconnected: {conn.device_id}")
|
||||
target.last_seen = now
|
||||
|
||||
try:
|
||||
await conn.websocket.send_json(payload)
|
||||
except Exception as exc:
|
||||
async with self.lock:
|
||||
record["status"] = "FAILED"
|
||||
record["error"] = f"command send failed: {exc}"
|
||||
record["updated_at"] = utc_now_iso()
|
||||
raise
|
||||
return record
|
||||
|
||||
async def _prune_if_needed(self) -> None:
|
||||
max_records = max(100, int(settings.edge_max_dispatch_records))
|
||||
async with self.lock:
|
||||
total = len(self.dispatches)
|
||||
total = len(self.dispatches) + len(self.commands)
|
||||
if total < max_records:
|
||||
return
|
||||
over = total - max_records + 1
|
||||
done = [v for v in self.dispatches.values() if v.get("status") in {"SUCCEEDED", "FAILED"}]
|
||||
done.extend([v for v in self.commands.values() if v.get("status") in {"SUCCEEDED", "FAILED"}])
|
||||
done.sort(key=lambda x: x.get("updated_at", ""))
|
||||
for rec in done[:over]:
|
||||
self.dispatches.pop(rec["dispatch_id"], None)
|
||||
self.commands.pop(rec["dispatch_id"], None)
|
||||
|
||||
async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
|
||||
now = utc_now_iso()
|
||||
@@ -154,6 +222,8 @@ class EdgeDispatchManager:
|
||||
return
|
||||
|
||||
record = self.dispatches.get(dispatch_id)
|
||||
if record is None:
|
||||
record = self.commands.get(dispatch_id)
|
||||
if record is None:
|
||||
return
|
||||
|
||||
@@ -175,6 +245,18 @@ class EdgeDispatchManager:
|
||||
record["error"] = event.get("error")
|
||||
if conn:
|
||||
conn.busy = False
|
||||
elif evt == "command_status":
|
||||
status_val = event.get("status") or "RUNNING"
|
||||
record["status"] = status_val
|
||||
elif evt == "command_result":
|
||||
status_val = event.get("status") or "SUCCEEDED"
|
||||
record["status"] = status_val
|
||||
record["result"] = event
|
||||
if status_val == "FAILED":
|
||||
record["error"] = event.get("error") or event.get("stderr")
|
||||
elif evt == "pong":
|
||||
record["status"] = "SUCCEEDED"
|
||||
record["result"] = event
|
||||
|
||||
record["updated_at"] = now
|
||||
|
||||
@@ -186,6 +268,11 @@ class EdgeDispatchManager:
|
||||
record["status"] = "FAILED"
|
||||
record["error"] = f"device disconnected: {device_id}"
|
||||
record["updated_at"] = now
|
||||
for record in self.commands.values():
|
||||
if record["device_id"] == device_id and record["status"] in {"DISPATCHED", "RUNNING", "PENDING"}:
|
||||
record["status"] = "FAILED"
|
||||
record["error"] = f"device disconnected: {device_id}"
|
||||
record["updated_at"] = now
|
||||
|
||||
|
||||
manager = EdgeDispatchManager()
|
||||
@@ -234,6 +321,34 @@ async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
|
||||
return record
|
||||
|
||||
|
||||
@app.post("/devices/{device_id}/command", response_model=DeviceCommandResponse)
|
||||
async def device_command(device_id: str, body: DeviceCommandRequest) -> DeviceCommandResponse:
|
||||
conn = await manager.select_device(device_id)
|
||||
try:
|
||||
record = await manager.create_command(conn, body)
|
||||
except WebSocketDisconnect as exc:
|
||||
await manager.unregister(conn.device_id)
|
||||
raise HTTPException(status_code=503, detail=f"device disconnected during command: {conn.device_id}") from exc
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
|
||||
return DeviceCommandResponse(
|
||||
dispatch_id=record["dispatch_id"],
|
||||
device_id=record["device_id"],
|
||||
command=record["command"],
|
||||
status=record["status"],
|
||||
created_at=record["created_at"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/commands/{dispatch_id}")
|
||||
async def get_command(dispatch_id: str) -> dict[str, Any]:
|
||||
record = manager.commands.get(dispatch_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"command not found: {dispatch_id}")
|
||||
return record
|
||||
|
||||
|
||||
@app.post("/dispatch/{dispatch_id}/artifacts")
|
||||
async def upload_artifacts(
|
||||
dispatch_id: str,
|
||||
|
||||
Reference in New Issue
Block a user