"""Edge agent: bridge between a dispatch server (WebSocket) and a local worker (HTTP).

The agent keeps a WebSocket connection to ``DISPATCH_WS_URL`` open and reacts to
events sent by the dispatcher:

* ``generate``        -- submit ``data["request"]`` to the worker, stream status
                         updates, and upload result artifacts over HTTP.
* ``update_code``     -- run a git update (or a dispatcher-supplied shell command).
* ``restart_service`` -- launch ``scripts/restart_edge_device_local.sh`` and exit.
* ``ping``            -- reply with ``pong``.

The WebSocket URL may also be passed as the first command-line argument.
"""

import asyncio
import contextlib
import json
import os
from pathlib import Path
import shlex
import subprocess
import sys
import time
from urllib.parse import urlparse

import requests
import websockets

DISPATCH_WS_URL = os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000")
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))
DISPATCH_HTTP_BASE = os.getenv("DISPATCH_HTTP_BASE", "")

_TRUTHY_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str, default: str) -> bool:
    """Return True if env var *name* (falling back to *default*) is a truthy string."""
    return os.getenv(name, default).lower() in _TRUTHY_VALUES


# NOTE(review): both remote-control flags default to enabled; consider making them
# opt-in on devices reachable by untrusted dispatchers.
ALLOW_REMOTE_UPDATE = _env_flag("EDGE_ALLOW_REMOTE_UPDATE", "true")
ALLOW_REMOTE_RESTART = _env_flag("EDGE_ALLOW_REMOTE_RESTART", "true")

SCRIPT_DIR = Path(__file__).resolve().parent
ROOT_DIR = SCRIPT_DIR.parent
RESTART_SCRIPT = ROOT_DIR / "scripts" / "restart_edge_device_local.sh"

# Worker task states after which status polling stops.
TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})
# Fields of a task result that may reference local files to upload.
ARTIFACT_FIELDS = ("video_path", "first_frame_path", "metadata_path", "log_path")
# Maximum run time, in seconds, of an ``update_code`` shell command.
UPDATE_TIMEOUT_SEC = 1800
# Delay, in seconds, between WebSocket reconnection attempts.
RECONNECT_DELAY_SEC = 3

if len(sys.argv) > 1:
    DISPATCH_WS_URL = sys.argv[1]


def infer_http_base(ws_url: str) -> str:
    """Derive the dispatcher's HTTP base URL from its WebSocket URL.

    ``wss://`` maps to ``https://``; any other scheme maps to ``http://``.
    Only the network location (host[:port]) is kept; the path is dropped.
    """
    parsed = urlparse(ws_url)
    scheme = "https" if parsed.scheme == "wss" else "http"
    return f"{scheme}://{parsed.netloc}"


if not DISPATCH_HTTP_BASE:
    DISPATCH_HTTP_BASE = infer_http_base(DISPATCH_WS_URL)


def worker_post(path: str, payload: dict):
    """POST *payload* as JSON to the local worker and return the decoded JSON body.

    Raises:
        requests.RequestException: on connection errors, timeouts, or non-2xx status.
    """
    r = requests.post(f"{WORKER_BASE_URL}{path}", json=payload, timeout=30)
    r.raise_for_status()
    return r.json()


def worker_get(path: str):
    """GET *path* from the local worker and return the decoded JSON body.

    Raises:
        requests.RequestException: on connection errors, timeouts, or non-2xx status.
    """
    r = requests.get(f"{WORKER_BASE_URL}{path}", timeout=20)
    r.raise_for_status()
    return r.json()


def run_shell(command: str, timeout_sec: int = 1200) -> tuple[int, str, str]:
    """Run *command* through the shell with ROOT_DIR as cwd, capturing its output.

    Returns:
        ``(exit_code, stripped_stdout, stripped_stderr)``.

    Raises:
        subprocess.TimeoutExpired: if the command runs longer than *timeout_sec*.
    """
    # shell=True is intentional: update commands are shell pipelines such as
    # "git fetch && git pull". Callers must quote any untrusted fragments.
    proc = subprocess.run(
        command,
        cwd=str(ROOT_DIR),
        shell=True,
        text=True,
        capture_output=True,
        timeout=timeout_sec,
    )
    return proc.returncode, proc.stdout.strip(), proc.stderr.strip()


def upload_artifacts(dispatch_id: str, task_id: str, result: dict) -> dict:
    """Upload a finished task's local artifact files to the dispatcher.

    Every path named by ``ARTIFACT_FIELDS`` that is set in *result* and exists on
    disk is sent as a multipart ``files`` part to
    ``{DISPATCH_HTTP_BASE}/dispatch/{dispatch_id}/artifacts``.

    Returns:
        The ``artifact_urls`` mapping from the dispatcher's JSON response; ``{}``
        if the response has none or if there was nothing to upload.

    Raises:
        requests.RequestException: on connection errors, timeouts, or non-2xx status.
        OSError: if an artifact file cannot be opened.
    """
    existing_paths = []
    for field in ARTIFACT_FIELDS:
        p = result.get(field)
        if p and Path(p).exists():
            existing_paths.append(Path(p))
    if not existing_paths:
        return {}

    # ExitStack closes every file already opened, even if a later open() or the
    # upload itself raises.
    with contextlib.ExitStack() as stack:
        files = [
            ("files", (path.name, stack.enter_context(path.open("rb")), "application/octet-stream"))
            for path in existing_paths
        ]
        data = {
            "task_id": task_id,
            "status": result.get("status", "SUCCEEDED"),
        }
        resp = requests.post(
            f"{DISPATCH_HTTP_BASE}/dispatch/{dispatch_id}/artifacts",
            data=data,
            files=files,
            timeout=600,
        )
        resp.raise_for_status()
        return resp.json().get("artifact_urls", {})


def _short_text(text: str, limit: int = 1500) -> str:
    """Truncate *text* to *limit* characters, appending a marker if it was cut."""
    if len(text) <= limit:
        return text
    return text[:limit] + "...(truncated)"


def _as_text(value) -> str:
    """Normalize captured process output (None, bytes, or str) to stripped text.

    ``subprocess.TimeoutExpired.stdout``/``stderr`` are bytes even when the
    process was run with ``text=True``, so they must be decoded here.
    """
    if value is None:
        return ""
    if isinstance(value, bytes):
        value = value.decode("utf-8", errors="replace")
    return value.strip()


async def _send_json(ws, payload: dict) -> None:
    """Serialize *payload* to JSON (keeping non-ASCII characters) and send it on *ws*."""
    await ws.send(json.dumps(payload, ensure_ascii=False))


def _command_result(dispatch_id: str, command: str, status: str, **fields) -> dict:
    """Build a ``command_result`` event; extra *fields* follow the fixed keys in order."""
    return {
        "event": "command_result",
        "dispatch_id": dispatch_id,
        "command": command,
        "status": status,
        **fields,
    }


def _command_running(dispatch_id: str, command: str) -> dict:
    """Build the ``command_status`` event announcing that *command* has started."""
    return {"event": "command_status", "dispatch_id": dispatch_id, "command": command, "status": "RUNNING"}


async def handle_generate(ws, data: dict):
    """Run a ``generate`` request on the worker and report progress to the dispatcher.

    Events sent on *ws*:

    1. ``accepted`` once the worker has created the task;
    2. one ``status`` per poll, until the task reaches a state in TERMINAL_STATUSES;
    3. a final ``result``, carrying ``artifact_urls`` on success or ``error`` otherwise.

    If talking to the worker or uploading artifacts fails, a ``FAILED`` ``result``
    is sent instead of letting the error tear down the WebSocket connection.
    ``task_id`` in that result is None if the failure happened before the task
    was created.
    """
    dispatch_id = data["dispatch_id"]
    req = data["request"]
    task_id = None
    try:
        created = await asyncio.to_thread(worker_post, "/generate", req)
        task_id = created["task_id"]
        await _send_json(ws, {"event": "accepted", "dispatch_id": dispatch_id, "task_id": task_id})

        while True:
            status = await asyncio.to_thread(worker_get, f"/tasks/{task_id}")
            state = status["status"]
            await _send_json(
                ws,
                {
                    "event": "status",
                    "dispatch_id": dispatch_id,
                    "task_id": task_id,
                    "status": state,
                    "progress": status.get("progress", 0.0),
                },
            )
            if state in TERMINAL_STATUSES:
                break
            await asyncio.sleep(POLL_INTERVAL)

        result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
        result_payload = {
            "event": "result",
            "dispatch_id": dispatch_id,
            "task_id": task_id,
            "status": result.get("status", state),
        }
        if state == "SUCCEEDED":
            result_payload["artifact_urls"] = await asyncio.to_thread(
                upload_artifacts, dispatch_id, task_id, result
            )
        else:
            result_payload["error"] = result.get("error")
    except (requests.RequestException, OSError) as exc:
        # requests.RequestException already derives from OSError; both are named
        # for clarity. OSError also covers artifact files that vanish before open().
        result_payload = {
            "event": "result",
            "dispatch_id": dispatch_id,
            "task_id": task_id,
            "status": "FAILED",
            "error": f"edge error: {exc}",
        }
    await _send_json(ws, result_payload)


async def handle_update_code(ws, data: dict):
    """Update the checkout in ROOT_DIR and report the outcome as ``command_result``.

    Runs ``data["command"]`` if it is given and non-empty. Otherwise runs a git
    fetch / checkout / fast-forward pull of ``data["branch"]``, which defaults to
    ``"master"``. The command is refused when ALLOW_REMOTE_UPDATE is false.
    """
    dispatch_id = data.get("dispatch_id", "")
    if not ALLOW_REMOTE_UPDATE:
        await _send_json(
            ws, _command_result(dispatch_id, "update_code", "FAILED", error="EDGE_ALLOW_REMOTE_UPDATE=false")
        )
        return

    # SECURITY: a dispatcher-supplied "command" is executed verbatim in a shell.
    # The branch name is quoted so that it cannot inject extra shell syntax.
    branch = shlex.quote(str(data.get("branch", "master")))
    git_command = data.get("command") or (
        f"git fetch --all && git checkout {branch} && git pull --ff-only origin {branch}"
    )
    await _send_json(ws, _command_running(dispatch_id, "update_code"))

    try:
        code, out, err = await asyncio.to_thread(run_shell, git_command, UPDATE_TIMEOUT_SEC)
    except subprocess.TimeoutExpired as exc:
        # Report the timeout rather than dropping the connection with no result.
        await _send_json(
            ws,
            _command_result(
                dispatch_id,
                "update_code",
                "FAILED",
                error=f"command timed out after {exc.timeout}s",
                stdout=_short_text(_as_text(exc.stdout)),
                stderr=_short_text(_as_text(exc.stderr)),
            ),
        )
        return

    await _send_json(
        ws,
        _command_result(
            dispatch_id,
            "update_code",
            "SUCCEEDED" if code == 0 else "FAILED",
            exit_code=code,
            stdout=_short_text(out),
            stderr=_short_text(err),
        ),
    )


async def handle_restart_service(ws, data: dict):
    """Launch RESTART_SCRIPT detached, then stop this agent via SystemExit.

    The command is refused when ALLOW_REMOTE_RESTART is false or the script
    does not exist.

    Raises:
        SystemExit: after the script is launched; main() catches it and returns.
    """
    dispatch_id = data.get("dispatch_id", "")
    if not ALLOW_REMOTE_RESTART:
        await _send_json(
            ws, _command_result(dispatch_id, "restart_service", "FAILED", error="EDGE_ALLOW_REMOTE_RESTART=false")
        )
        return
    if not RESTART_SCRIPT.exists():
        await _send_json(
            ws,
            _command_result(dispatch_id, "restart_service", "FAILED", error=f"restart script missing: {RESTART_SCRIPT}"),
        )
        return

    await _send_json(ws, _command_running(dispatch_id, "restart_service"))
    # Both messages are sent before launching the script. The script presumably
    # stops or replaces this process (TODO confirm), so anything sent afterwards
    # might never arrive.
    await _send_json(
        ws, _command_result(dispatch_id, "restart_service", "SUCCEEDED", message="restart script launched")
    )
    # start_new_session=True calls setsid() in the child, detaching it from this
    # process's session and process group.
    subprocess.Popen(
        ["bash", str(RESTART_SCRIPT)],
        cwd=str(ROOT_DIR),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        start_new_session=True,
    )
    raise SystemExit(0)


async def handle_ping(ws, data: dict):
    """Reply to a ``ping`` with a ``pong`` that echoes the dispatch id and the worker URL."""
    await _send_json(
        ws,
        {
            "event": "pong",
            "dispatch_id": data.get("dispatch_id", ""),
            "status": "ok",
            "worker_base_url": WORKER_BASE_URL,
        },
    )


# Dispatcher event name -> coroutine handling it. Each handler takes (ws, data).
EVENT_HANDLERS = {
    "generate": handle_generate,
    "update_code": handle_update_code,
    "restart_service": handle_restart_service,
    "ping": handle_ping,
}


def _parse_message(raw):
    """Decode one WebSocket message into a dict, or return None if it is unusable.

    Invalid JSON, invalid UTF-8, and JSON that is not an object are logged and
    skipped. Previously such messages raised an error and dropped the connection.
    """
    try:
        data = json.loads(raw)
    except ValueError as exc:  # covers json.JSONDecodeError and UnicodeDecodeError
        print("ignoring malformed message:", exc)
        return None
    if not isinstance(data, dict):
        print("ignoring non-object message:", data)
        return None
    return data


async def main() -> None:
    """Connect to the dispatcher and process events forever, reconnecting on errors.

    Handlers run one at a time; the next message is not read until the current
    handler returns. The function returns only when a handler raises SystemExit,
    which handle_restart_service does.
    """
    while True:
        try:
            # max_size caps incoming messages at 2**22 bytes (4 MiB).
            async with websockets.connect(DISPATCH_WS_URL, max_size=2**22) as ws:
                print(f"connected: {DISPATCH_WS_URL}")
                while True:
                    data = _parse_message(await ws.recv())
                    if data is None:
                        continue
                    event = data.get("event")
                    handler = EVENT_HANDLERS.get(event)
                    if handler is not None:
                        await handler(ws, data)
                    elif event == "registered":
                        print("registered", data)
        except SystemExit:
            return
        except Exception as exc:
            print(f"connection error, retry in {RECONNECT_DELAY_SEC}s:", exc)
            # Non-blocking sleep: time.sleep() here would freeze the event loop.
            await asyncio.sleep(RECONNECT_DELAY_SEC)


if __name__ == "__main__":
    asyncio.run(main())