"""Edge agent bridging a dispatch server (WebSocket) and a local generation worker (HTTP).

The agent keeps a WebSocket connection open to ``DISPATCH_WS_URL``. For every
``generate`` event it:

1. submits the request to the worker (``POST /generate``);
2. polls ``GET /tasks/{task_id}`` and forwards each poll as a ``status`` event;
3. when the task reaches a terminal state, fetches ``/tasks/{task_id}/result``;
4. on success, uploads the artifact files to the dispatch server over HTTP;
5. sends a final ``result`` event.

The WebSocket URL may be overridden by the first command-line argument.
"""

import asyncio
from contextlib import ExitStack
import json
import os
from pathlib import Path
import sys
import time
from urllib.parse import urlparse

import requests
import websockets

DISPATCH_WS_URL = os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
# Trailing slashes are stripped so f"{BASE}{path}" never yields "//path".
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000").rstrip("/")
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))
RECONNECT_DELAY = float(os.getenv("EDGE_RECONNECT_DELAY_SEC", "3"))
DISPATCH_HTTP_BASE = os.getenv("DISPATCH_HTTP_BASE", "")

# Worker task states after which polling stops.
TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})

# Fields of the worker's task result that may point at local artifact files.
ARTIFACT_FIELDS = ("video_path", "first_frame_path", "metadata_path", "log_path")

# Failures while talking to the worker or uploading artifacts. They are
# reported to the dispatcher as a FAILED job instead of tearing down the
# WebSocket connection. requests.RequestException is an OSError subclass;
# KeyError / ValueError cover missing response fields and non-JSON bodies.
JOB_ERRORS = (requests.RequestException, OSError, KeyError, ValueError)

if len(sys.argv) > 1:
    DISPATCH_WS_URL = sys.argv[1]


def infer_http_base(ws_url: str) -> str:
    """Derive the dispatch server's HTTP origin from its WebSocket URL.

    ``wss://`` maps to ``https://`` and any other scheme maps to ``http://``.
    Only ``host[:port]`` is kept; the WebSocket path is dropped.
    """
    parsed = urlparse(ws_url)
    scheme = "https" if parsed.scheme == "wss" else "http"
    return f"{scheme}://{parsed.netloc}"


if not DISPATCH_HTTP_BASE:
    DISPATCH_HTTP_BASE = infer_http_base(DISPATCH_WS_URL)
DISPATCH_HTTP_BASE = DISPATCH_HTTP_BASE.rstrip("/")


def worker_post(path: str, payload: dict):
    """POST *payload* as JSON to the worker and return the decoded JSON response.

    Raises ``requests.HTTPError`` on a non-2xx status.
    """
    r = requests.post(f"{WORKER_BASE_URL}{path}", json=payload, timeout=30)
    r.raise_for_status()
    return r.json()


def worker_get(path: str):
    """GET *path* from the worker and return the decoded JSON response.

    Raises ``requests.HTTPError`` on a non-2xx status.
    """
    r = requests.get(f"{WORKER_BASE_URL}{path}", timeout=20)
    r.raise_for_status()
    return r.json()


def upload_artifacts(dispatch_id: str, task_id: str, result: dict) -> dict:
    """Upload the task's artifact files to the dispatch server.

    Every field in ``ARTIFACT_FIELDS`` whose value names an existing regular
    file is sent as a multipart ``files`` part to
    ``{DISPATCH_HTTP_BASE}/dispatch/{dispatch_id}/artifacts``.

    Returns:
        The ``artifact_urls`` object from the server's JSON response, or ``{}``
        if there was nothing to upload or the key is absent. Its exact shape is
        defined by the dispatch server (presumably field/file name -> URL;
        verify against the server).

    Raises:
        requests.RequestException: on transport errors or a non-2xx response.
        OSError: if a file cannot be opened.
    """
    paths = []
    for field in ARTIFACT_FIELDS:
        value = result.get(field)
        # is_file(), not exists(): a directory path would make open("rb") fail.
        if value and Path(value).is_file():
            paths.append(Path(value))
    if not paths:
        return {}

    # ExitStack closes every handle opened so far, even if a later open fails.
    with ExitStack() as stack:
        files = [
            ("files", (path.name, stack.enter_context(path.open("rb")), "application/octet-stream"))
            for path in paths
        ]
        data = {
            "task_id": task_id,
            "status": result.get("status", "SUCCEEDED"),
        }
        resp = requests.post(
            f"{DISPATCH_HTTP_BASE}/dispatch/{dispatch_id}/artifacts",
            data=data,
            files=files,
            timeout=600,
        )
        resp.raise_for_status()
        return resp.json().get("artifact_urls", {})


async def _send(ws, payload: dict) -> None:
    """Serialize *payload* as JSON (keeping non-ASCII characters) and send it on *ws*."""
    await ws.send(json.dumps(payload, ensure_ascii=False))


async def _submit(ws, dispatch_id, request_body):
    """Create the worker task for *request_body*, send ``accepted``, and return the task id."""
    created = await asyncio.to_thread(worker_post, "/generate", request_body)
    task_id = created["task_id"]
    await _send(ws, {"event": "accepted", "dispatch_id": dispatch_id, "task_id": task_id})
    return task_id


async def _poll_until_done(ws, dispatch_id, task_id) -> dict:
    """Forward worker status until the task is terminal, then build the ``result`` payload."""
    while True:
        status = await asyncio.to_thread(worker_get, f"/tasks/{task_id}")
        state = status["status"]
        await _send(
            ws,
            {
                "event": "status",
                "dispatch_id": dispatch_id,
                "task_id": task_id,
                "status": state,
                "progress": status.get("progress", 0.0),
            },
        )
        if state in TERMINAL_STATUSES:
            break
        await asyncio.sleep(POLL_INTERVAL)

    result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
    payload = {
        "event": "result",
        "dispatch_id": dispatch_id,
        "task_id": task_id,
        "status": result.get("status", state),
    }
    if state == "SUCCEEDED":
        # upload_artifacts does blocking file and network I/O, so run it off the event loop.
        payload["artifact_urls"] = await asyncio.to_thread(
            upload_artifacts, dispatch_id, task_id, result
        )
    else:
        payload["error"] = result.get("error")
    return payload


async def handle_generate(ws, data: dict):
    """Run one ``generate`` job on the worker and report it to the dispatcher.

    Sends ``accepted`` once the worker assigns a task id, one ``status`` event
    per poll, and a final ``result`` event. If the worker or the artifact
    upload fails (see ``JOB_ERRORS``), a ``result`` event with status
    ``FAILED`` and an ``error`` string is sent instead. Without this, the
    dispatch would never be answered. WebSocket errors are not caught here;
    they propagate so the caller can reconnect.
    """
    dispatch_id = data["dispatch_id"]
    task_id = None
    try:
        task_id = await _submit(ws, dispatch_id, data["request"])
        result_payload = await _poll_until_done(ws, dispatch_id, task_id)
    except JOB_ERRORS as exc:
        print("job failed:", dispatch_id, exc)
        # NOTE(review): task_id is None when submission itself failed;
        # confirm the dispatch server accepts a null task_id here.
        result_payload = {
            "event": "result",
            "dispatch_id": dispatch_id,
            "task_id": task_id,
            "status": "FAILED",
            "error": f"edge agent error: {exc!r}",
        }
    await _send(ws, result_payload)


async def main() -> None:
    """Serve dispatcher events forever, reconnecting after any connection-level error.

    Jobs are handled one at a time: the next message is not read until the
    current ``generate`` job has finished.
    """
    while True:
        try:
            async with websockets.connect(DISPATCH_WS_URL, max_size=2**22) as ws:
                print(f"connected: {DISPATCH_WS_URL}")
                while True:
                    raw = await ws.recv()
                    try:
                        data = json.loads(raw)
                    except json.JSONDecodeError as exc:
                        # A single bad frame should not drop the whole connection.
                        print("ignoring malformed message:", exc)
                        continue
                    if not isinstance(data, dict):
                        print("ignoring non-object message:", data)
                        continue
                    event = data.get("event")
                    if event == "generate":
                        await handle_generate(ws, data)
                    elif event == "registered":
                        print("registered", data)
        except Exception as exc:
            print(f"connection error, retry in {RECONNECT_DELAY:g}s:", exc)
            # asyncio.sleep, not time.sleep: never block the running event loop.
            await asyncio.sleep(RECONNECT_DELAY)


if __name__ == "__main__":
    asyncio.run(main())