Files
AI_A4000/video_worker/scripts/edge_device_client.py
2026-04-07 01:00:56 +08:00

89 lines
2.7 KiB
Python

import asyncio
import json
import os
import sys
import time
import requests
import websockets
# Connection settings for the local video worker (HTTP) and the dispatcher
# (WebSocket). Each one can be overridden from the environment.
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000")
# Delay, in seconds, between successive task-status polls.
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))
# A first positional command-line argument takes precedence over the
# environment for the dispatcher URL.
DISPATCH_WS_URL = (
    sys.argv[1]
    if len(sys.argv) > 1
    else os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
)
def worker_post(path: str, payload: dict):
    """POST ``payload`` as JSON to the worker and return the decoded JSON reply.

    ``path`` is appended verbatim to ``WORKER_BASE_URL``; callers in this file
    pass paths such as ``"/generate"``. The request uses a 30-second timeout.

    Raises:
        requests.HTTPError: when the worker answers with a 4xx/5xx status.
        Errors from ``Response.json()`` when the body is not valid JSON.
    """
    url = f"{WORKER_BASE_URL}{path}"
    response = requests.post(url, json=payload, timeout=30)
    response.raise_for_status()
    decoded = response.json()
    return decoded
def worker_get(path: str):
    """GET ``path`` from the worker and return the decoded JSON body.

    ``path`` is appended verbatim to ``WORKER_BASE_URL`` (for example
    ``"/tasks/<id>"``). The request uses a 20-second timeout.

    Raises:
        requests.HTTPError: when the worker answers with a 4xx/5xx status.
        Errors from ``Response.json()`` when the body is not valid JSON.
    """
    url = f"{WORKER_BASE_URL}{path}"
    response = requests.get(url, timeout=20)
    response.raise_for_status()
    decoded = response.json()
    return decoded
# Worker task states after which polling stops and the final result is fetched.
_TERMINAL_STATUSES = frozenset({"SUCCEEDED", "FAILED"})


async def _send_event(ws, payload: dict) -> None:
    """Serialize ``payload`` as JSON (non-ASCII kept verbatim) and send it on ``ws``."""
    await ws.send(json.dumps(payload, ensure_ascii=False))


async def handle_generate(ws, data: dict):
    """Run one dispatched ``generate`` request on the local worker and report back.

    Submits ``data["request"]`` to the worker's ``/generate`` endpoint, then
    sends the following messages to the dispatcher over ``ws``, each tagged
    with ``dispatch_id`` and ``task_id``:

    1. ``accepted`` -- once the worker has returned a ``task_id``.
    2. ``status`` -- after every poll of ``/tasks/{task_id}``, carrying the
       worker's ``status`` and ``progress`` (``0.0`` if absent). Polls are
       ``POLL_INTERVAL`` seconds apart.
    3. ``result`` -- the body of ``/tasks/{task_id}/result`` once the task
       reaches ``SUCCEEDED`` or ``FAILED``. The coroutine then returns.

    Args:
        ws: open connection to the dispatcher; only its async ``send`` is used.
        data: the ``generate`` message. It must contain ``dispatch_id`` and
            ``request``.

    Raises:
        KeyError: if ``data`` or a worker response lacks a required key.
        Exceptions from the worker HTTP helpers and from ``ws.send``
        propagate unchanged.

    NOTE(review): polling has no overall deadline, so a task that never
    reaches a terminal state keeps this coroutine running indefinitely.
    """
    dispatch_id = data["dispatch_id"]
    req = data["request"]
    # The HTTP helpers are blocking, so run them in a thread to keep the
    # event loop (and the WebSocket connection) responsive.
    created = await asyncio.to_thread(worker_post, "/generate", req)
    task_id = created["task_id"]
    await _send_event(
        ws, {"event": "accepted", "dispatch_id": dispatch_id, "task_id": task_id}
    )

    while True:
        status = await asyncio.to_thread(worker_get, f"/tasks/{task_id}")
        state = status["status"]
        await _send_event(
            ws,
            {
                "event": "status",
                "dispatch_id": dispatch_id,
                "task_id": task_id,
                "status": state,
                "progress": status.get("progress", 0.0),
            },
        )
        if state in _TERMINAL_STATUSES:
            result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
            envelope = {"event": "result", "dispatch_id": dispatch_id, "task_id": task_id}
            # The envelope is spread both before and after the worker payload.
            # It stays first in key order, and its routing fields always win:
            # a worker result that happens to contain "event" / "dispatch_id" /
            # "task_id" can no longer misroute the message.
            await _send_event(ws, {**envelope, **result, **envelope})
            return
        await asyncio.sleep(POLL_INTERVAL)
async def main() -> None:
    """Keep a WebSocket session with the dispatcher open and serve its requests.

    Connects to ``DISPATCH_WS_URL`` and reads messages forever:

    * ``generate`` -- handled by ``handle_generate``. Messages are handled
      one at a time: the next message is not read until the current
      generation has finished.
    * ``registered`` -- printed.
    * any other event -- ignored.

    Frames that are not JSON objects are logged and skipped instead of
    dropping the connection. Any other error (connection loss, a failing
    worker call, ...) ends the session. The loop then waits 3 seconds,
    without blocking the event loop, and reconnects.
    """
    while True:
        try:
            # max_size raises the inbound message limit to 4 MiB.
            async with websockets.connect(DISPATCH_WS_URL, max_size=2**22) as ws:
                print(f"connected: {DISPATCH_WS_URL}")
                while True:
                    raw = await ws.recv()
                    try:
                        data = json.loads(raw)
                    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
                        print("ignoring malformed message:", exc)
                        continue
                    if not isinstance(data, dict):
                        print("ignoring non-object message:", data)
                        continue
                    event = data.get("event")
                    if event == "generate":
                        await handle_generate(ws, data)
                    elif event == "registered":
                        print("registered", data)
        except Exception as exc:
            print("connection error, retry in 3s:", exc)
            # asyncio.sleep, not time.sleep: a blocking sleep would freeze the
            # whole event loop for the duration of the back-off.
            await asyncio.sleep(3)


if __name__ == "__main__":
    asyncio.run(main())