fix:优化服务
This commit is contained in:
88
video_worker/scripts/edge_device_client.py
Normal file
88
video_worker/scripts/edge_device_client.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import requests
|
||||
import websockets
|
||||
|
||||
DISPATCH_WS_URL = os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
|
||||
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000")
|
||||
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
DISPATCH_WS_URL = sys.argv[1]
|
||||
|
||||
|
||||
def worker_post(path: str, payload: dict):
|
||||
r = requests.post(f"{WORKER_BASE_URL}{path}", json=payload, timeout=30)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
|
||||
def worker_get(path: str):
|
||||
r = requests.get(f"{WORKER_BASE_URL}{path}", timeout=20)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
|
||||
async def handle_generate(ws, data: dict):
|
||||
dispatch_id = data["dispatch_id"]
|
||||
req = data["request"]
|
||||
|
||||
created = await asyncio.to_thread(worker_post, "/generate", req)
|
||||
task_id = created["task_id"]
|
||||
await ws.send(json.dumps({"event": "accepted", "dispatch_id": dispatch_id, "task_id": task_id}, ensure_ascii=False))
|
||||
|
||||
while True:
|
||||
status = await asyncio.to_thread(worker_get, f"/tasks/{task_id}")
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "status",
|
||||
"dispatch_id": dispatch_id,
|
||||
"task_id": task_id,
|
||||
"status": status["status"],
|
||||
"progress": status.get("progress", 0.0),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
if status["status"] in {"SUCCEEDED", "FAILED"}:
|
||||
result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "result",
|
||||
"dispatch_id": dispatch_id,
|
||||
"task_id": task_id,
|
||||
**result,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
return
|
||||
await asyncio.sleep(POLL_INTERVAL)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
while True:
|
||||
try:
|
||||
async with websockets.connect(DISPATCH_WS_URL, max_size=2**22) as ws:
|
||||
print(f"connected: {DISPATCH_WS_URL}")
|
||||
while True:
|
||||
raw = await ws.recv()
|
||||
data = json.loads(raw)
|
||||
event = data.get("event")
|
||||
if event == "generate":
|
||||
await handle_generate(ws, data)
|
||||
elif event == "registered":
|
||||
print("registered", data)
|
||||
except Exception as exc:
|
||||
print("connection error, retry in 3s:", exc)
|
||||
time.sleep(3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user