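# Edge agent: keeps a WebSocket connection to the dispatcher, forwards
# "generate" requests to the local worker's HTTP API, and relays the worker's
# task status and final result back over the socket.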
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
import requests
|
|
import websockets
|
|
|
|
# Runtime configuration, overridable through environment variables.
# The dispatcher URL may also be passed as the first command-line argument,
# which wins over $DISPATCH_WS_URL.
DISPATCH_WS_URL = (
    sys.argv[1]
    if len(sys.argv) > 1
    else os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
)
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000")
# Seconds to wait between task-status polls of the worker.
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))
|
|
|
|
|
|
def worker_post(path: str, payload: dict):
    """POST ``payload`` as JSON to the worker endpoint ``path``.

    ``path`` is appended verbatim to ``WORKER_BASE_URL`` (callers pass paths
    such as "/generate"). Returns the decoded JSON response body.

    Raises:
        requests.HTTPError: the worker answered with a 4xx/5xx status.
        requests.RequestException: connecting or reading exceeded the
            30-second timeout, or another transport error occurred.
    """
    endpoint = f"{WORKER_BASE_URL}{path}"
    response = requests.post(endpoint, json=payload, timeout=30)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
def worker_get(path: str):
    """GET the worker endpoint ``path`` and return its decoded JSON body.

    ``path`` is appended verbatim to ``WORKER_BASE_URL`` (callers pass paths
    such as "/tasks/<task_id>").

    Raises:
        requests.HTTPError: the worker answered with a 4xx/5xx status.
        requests.RequestException: connecting or reading exceeded the
            20-second timeout, or another transport error occurred.
    """
    endpoint = f"{WORKER_BASE_URL}{path}"
    response = requests.get(endpoint, timeout=20)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
async def handle_generate(ws, data: dict):
    """Run one dispatched generation job on the local worker and report back.

    Submits ``data["request"]`` to the worker's ``/generate`` endpoint and sends
    an ``accepted`` event. It then polls ``/tasks/{task_id}`` every
    ``POLL_INTERVAL`` seconds, sending a ``status`` event after each poll. Once
    the worker reports ``SUCCEEDED`` or ``FAILED``, it fetches
    ``/tasks/{task_id}/result``, sends a final ``result`` event, and returns.

    Args:
        ws: Open dispatcher WebSocket connection (used only via ``await ws.send``).
        data: The dispatcher's ``generate`` message. It must contain
            ``dispatch_id`` and ``request``.

    Raises:
        KeyError: ``data`` or a worker response lacks a required key.
        requests.RequestException: any worker HTTP failure propagates to the caller.
    """
    dispatch_id = data["dispatch_id"]
    req = data["request"]

    # The worker helpers use blocking `requests`; run them in a thread so the
    # event loop is not blocked while waiting on HTTP.
    created = await asyncio.to_thread(worker_post, "/generate", req)
    task_id = created["task_id"]
    await ws.send(
        json.dumps(
            {"event": "accepted", "dispatch_id": dispatch_id, "task_id": task_id},
            ensure_ascii=False,
        )
    )

    # NOTE(review): only these two states end the loop. Any other final state
    # the worker might report would keep this coroutine polling forever.
    terminal_states = {"SUCCEEDED", "FAILED"}
    while True:
        status = await asyncio.to_thread(worker_get, f"/tasks/{task_id}")
        await ws.send(
            json.dumps(
                {
                    "event": "status",
                    "dispatch_id": dispatch_id,
                    "task_id": task_id,
                    "status": status["status"],
                    "progress": status.get("progress", 0.0),
                },
                ensure_ascii=False,
            )
        )

        if status["status"] in terminal_states:
            # NOTE(review): for FAILED tasks this assumes the worker still serves
            # /result; an HTTP error here propagates to the caller.
            result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
            envelope = {"event": "result", "dispatch_id": dispatch_id, "task_id": task_id}
            # Worker fields are appended after the routing envelope but may not
            # overwrite it. A stray "event"/"dispatch_id" key in the result would
            # otherwise misroute the message.
            message = {**envelope, **{k: v for k, v in result.items() if k not in envelope}}
            await ws.send(json.dumps(message, ensure_ascii=False))
            return

        await asyncio.sleep(POLL_INTERVAL)
|
|
|
|
|
|
async def main() -> None:
    """Keep a dispatcher connection open forever, reconnecting after failures.

    Connects to ``DISPATCH_WS_URL``, accepting incoming messages of up to
    2**22 bytes (4 MiB). Incoming JSON objects are handled one at a time:

    - ``generate`` is awaited inline via ``handle_generate``, so no further
      messages are read until that job finishes;
    - ``registered`` is printed;
    - any other event is ignored.

    Frames that are not JSON objects are skipped. Any other exception,
    including a failure inside ``handle_generate``, drops the connection and
    retries after 3 seconds.
    """
    while True:
        try:
            async with websockets.connect(DISPATCH_WS_URL, max_size=2**22) as ws:
                print(f"connected: {DISPATCH_WS_URL}")
                while True:
                    raw = await ws.recv()
                    try:
                        data = json.loads(raw)
                    except ValueError as exc:
                        # Covers JSONDecodeError and undecodable bytes frames: one bad
                        # frame should not tear down an otherwise healthy connection.
                        print("ignoring malformed message:", exc)
                        continue
                    if not isinstance(data, dict):
                        print("ignoring non-object message of type", type(data).__name__)
                        continue

                    event = data.get("event")
                    if event == "generate":
                        await handle_generate(ws, data)
                    elif event == "registered":
                        print("registered", data)
        except Exception as exc:
            print("connection error, retry in 3s:", exc)
            # Must be the async sleep: time.sleep() would block the event loop.
            await asyncio.sleep(3)


if __name__ == "__main__":
    asyncio.run(main())
|