"""Edge agent: bridges a dispatch WebSocket and a local HTTP worker.

The agent connects to the dispatch server over WebSocket, waits for
``generate`` events, forwards each request to the worker's HTTP API, and
streams ``accepted`` / ``status`` / ``result`` events back over the same
socket.
"""

import asyncio
import json
import os
import sys
import time

import requests
import websockets

# Endpoints and polling cadence; each can be overridden via the environment.
DISPATCH_WS_URL = os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000")
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))

# A first command-line argument takes precedence over the environment value.
if len(sys.argv) > 1:
    DISPATCH_WS_URL = sys.argv[1]


def worker_post(path: str, payload: dict):
    """POST ``payload`` as JSON to the worker at ``path`` and return the decoded JSON body.

    Raises:
        requests.HTTPError: if the worker answers with an error status.
    """
    r = requests.post(f"{WORKER_BASE_URL}{path}", json=payload, timeout=30)
    r.raise_for_status()
    return r.json()


def worker_get(path: str):
    """GET ``path`` from the worker and return the decoded JSON body.

    Raises:
        requests.HTTPError: if the worker answers with an error status.
    """
    r = requests.get(f"{WORKER_BASE_URL}{path}", timeout=20)
    r.raise_for_status()
    return r.json()


async def _send_json(ws, payload: dict) -> None:
    """Serialize ``payload`` to JSON (non-ASCII characters kept as-is) and send it on ``ws``."""
    await ws.send(json.dumps(payload, ensure_ascii=False))


async def handle_generate(ws, data: dict) -> None:
    """Run one ``generate`` request against the worker and report progress on ``ws``.

    Reads ``data["dispatch_id"]`` and ``data["request"]``, submits the request
    to the worker's ``/generate`` endpoint, then polls ``/tasks/{task_id}``
    every ``POLL_INTERVAL`` seconds. Sends an ``accepted`` event once the task
    is created, a ``status`` event after every poll, and a final ``result``
    event when the status becomes ``SUCCEEDED`` or ``FAILED``.

    The blocking ``requests`` calls run via ``asyncio.to_thread`` so the event
    loop stays responsive while waiting on the worker.

    Raises:
        KeyError: if an expected key is missing from ``data`` or a worker reply.
        requests.RequestException: if a worker HTTP call fails.
    """
    dispatch_id = data["dispatch_id"]
    req = data["request"]

    created = await asyncio.to_thread(worker_post, "/generate", req)
    task_id = created["task_id"]
    await _send_json(
        ws,
        {"event": "accepted", "dispatch_id": dispatch_id, "task_id": task_id},
    )

    while True:
        status = await asyncio.to_thread(worker_get, f"/tasks/{task_id}")
        await _send_json(
            ws,
            {
                "event": "status",
                "dispatch_id": dispatch_id,
                "task_id": task_id,
                "status": status["status"],
                "progress": status.get("progress", 0.0),
            },
        )

        if status["status"] in {"SUCCEEDED", "FAILED"}:
            result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
            # The worker's result fields are merged last, so they override
            # same-named envelope keys if the worker returns any.
            await _send_json(
                ws,
                {
                    "event": "result",
                    "dispatch_id": dispatch_id,
                    "task_id": task_id,
                    **result,
                },
            )
            return

        await asyncio.sleep(POLL_INTERVAL)


async def main() -> None:
    """Keep a dispatch connection open forever, reconnecting after any error.

    Each received message is decoded as JSON. ``generate`` events are handled
    inline (one at a time); ``registered`` events are printed; other events
    are ignored. Any exception raised while connected is printed and followed
    by a 3-second pause before the next connection attempt.
    """
    while True:
        try:
            async with websockets.connect(DISPATCH_WS_URL, max_size=2**22) as ws:
                print(f"connected: {DISPATCH_WS_URL}")
                while True:
                    raw = await ws.recv()
                    data = json.loads(raw)
                    event = data.get("event")
                    if event == "generate":
                        await handle_generate(ws, data)
                    elif event == "registered":
                        print("registered", data)
        except Exception as exc:
            print("connection error, retry in 3s:", exc)
            # Use the non-blocking sleep: time.sleep() here would freeze the
            # whole event loop (and delay cancellation) during the back-off.
            await asyncio.sleep(3)


if __name__ == "__main__":
    asyncio.run(main())