# Edge agent: registers with the dispatch server over WebSocket and proxies
# generation requests to a local worker's HTTP API, streaming status back.
import asyncio
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
import websockets
|
|
|
|
def _truthy_env(name: str, default: str = "true") -> bool:
    """Interpret an environment variable as a boolean flag."""
    return os.getenv(name, default).lower() in {"1", "true", "yes", "on"}


# Dispatch-server WebSocket endpoint this edge agent registers against.
DISPATCH_WS_URL = os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
# Local generation worker's HTTP API.
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000")
# Seconds between task-status polls against the worker.
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))
# HTTP base of the dispatch server; derived from the WS URL when left empty.
DISPATCH_HTTP_BASE = os.getenv("DISPATCH_HTTP_BASE", "")
# Safety switches for remotely triggered code update / service restart.
ALLOW_REMOTE_UPDATE = _truthy_env("EDGE_ALLOW_REMOTE_UPDATE")
ALLOW_REMOTE_RESTART = _truthy_env("EDGE_ALLOW_REMOTE_RESTART")

SCRIPT_DIR = Path(__file__).resolve().parent
ROOT_DIR = SCRIPT_DIR.parent
RESTART_SCRIPT = ROOT_DIR / "scripts" / "restart_edge_device_local.sh"

# A positional CLI argument overrides the WS URL from the environment.
if len(sys.argv) > 1:
    DISPATCH_WS_URL = sys.argv[1]
|
|
|
|
|
|
def infer_http_base(ws_url: str) -> str:
    """Derive the dispatch server's HTTP origin from its WebSocket URL.

    ws:// maps to http:// and wss:// maps to https://; the host:port part is
    kept verbatim and any path/query component is dropped.
    """
    parts = urlparse(ws_url)
    http_scheme = "https" if parts.scheme == "wss" else "http"
    return f"{http_scheme}://{parts.netloc}"
|
|
|
|
|
|
# Fall back to deriving the dispatch HTTP base from the WS endpoint.
DISPATCH_HTTP_BASE = DISPATCH_HTTP_BASE or infer_http_base(DISPATCH_WS_URL)
|
|
|
|
|
|
def worker_post(path: str, payload: dict):
    """POST *payload* as JSON to the local worker API and return the parsed body.

    Raises requests.HTTPError on any non-2xx response.
    """
    url = f"{WORKER_BASE_URL}{path}"
    response = requests.post(url, json=payload, timeout=30)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
def worker_get(path: str):
    """GET a path from the local worker API and return the parsed JSON body.

    Raises requests.HTTPError on any non-2xx response.
    """
    url = f"{WORKER_BASE_URL}{path}"
    response = requests.get(url, timeout=20)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
def run_shell(command: str, timeout_sec: int = 1200) -> tuple[int, str, str]:
    """Run *command* through the shell from the repository root.

    Returns (exit_code, stdout, stderr) with surrounding whitespace stripped.
    Raises subprocess.TimeoutExpired when the command outlives *timeout_sec*.
    """
    completed = subprocess.run(
        command,
        shell=True,  # commands arrive as full shell strings (pipes, &&)
        cwd=str(ROOT_DIR),
        capture_output=True,
        text=True,
        timeout=timeout_sec,
    )
    return completed.returncode, completed.stdout.strip(), completed.stderr.strip()
|
|
|
|
|
|
def upload_artifacts(dispatch_id: str, task_id: str, result: dict) -> dict:
    """Upload the task's output files to the dispatch server.

    Looks up the well-known path fields in *result*, skips any that do not
    exist on disk, and POSTs the rest as one multipart request. Returns the
    server's artifact-URL mapping, or {} when there is nothing to upload.
    """
    paths = [
        Path(value)
        for field in ("video_path", "first_frame_path", "metadata_path", "log_path")
        if (value := result.get(field)) and Path(value).exists()
    ]
    if not paths:
        return {}

    handles = []
    try:
        parts = []
        for path in paths:
            handle = path.open("rb")
            handles.append(handle)
            parts.append(("files", (path.name, handle, "application/octet-stream")))

        form = {
            "task_id": task_id,
            "status": result.get("status", "SUCCEEDED"),
        }
        response = requests.post(
            f"{DISPATCH_HTTP_BASE}/dispatch/{dispatch_id}/artifacts",
            data=form,
            files=parts,
            timeout=600,
        )
        response.raise_for_status()
        return response.json().get("artifact_urls", {})
    finally:
        # Close every file we managed to open, even if the upload failed.
        for handle in handles:
            handle.close()
|
|
|
|
|
|
def _short_text(text: str, limit: int = 1500) -> str:
|
|
if len(text) <= limit:
|
|
return text
|
|
return text[:limit] + "...(truncated)"
|
|
|
|
|
|
async def handle_generate(ws, data: dict):
    """Run one generation request on the local worker and stream progress.

    Submits the request, acknowledges with an "accepted" event, then polls
    the worker every POLL_INTERVAL seconds, forwarding a "status" event per
    poll. On a terminal state a single "result" event is sent: artifact URLs
    on success, or the worker's error on failure.
    """
    dispatch_id = data["dispatch_id"]
    request_body = data["request"]

    created = await asyncio.to_thread(worker_post, "/generate", request_body)
    task_id = created["task_id"]
    await ws.send(json.dumps(
        {"event": "accepted", "dispatch_id": dispatch_id, "task_id": task_id},
        ensure_ascii=False,
    ))

    while True:
        status = await asyncio.to_thread(worker_get, f"/tasks/{task_id}")
        state = status["status"]
        await ws.send(json.dumps(
            {
                "event": "status",
                "dispatch_id": dispatch_id,
                "task_id": task_id,
                "status": state,
                "progress": status.get("progress", 0.0),
            },
            ensure_ascii=False,
        ))

        if state not in {"SUCCEEDED", "FAILED"}:
            await asyncio.sleep(POLL_INTERVAL)
            continue

        result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
        payload = {
            "event": "result",
            "dispatch_id": dispatch_id,
            "task_id": task_id,
            "status": result.get("status", state),
        }
        if state == "SUCCEEDED":
            payload["artifact_urls"] = await asyncio.to_thread(
                upload_artifacts, dispatch_id, task_id, result
            )
        else:
            payload["error"] = result.get("error")
        await ws.send(json.dumps(payload, ensure_ascii=False))
        return
|
|
|
|
|
|
async def handle_update_code(ws, data: dict):
    """Handle a remote "update_code" command by running git in the repo root.

    Refuses immediately when EDGE_ALLOW_REMOTE_UPDATE is off. Otherwise
    reports RUNNING, executes the update command (30-minute timeout), and
    reports a "command_result" with the exit code and truncated output.
    """
    import shlex  # local import: only needed on this remote-admin path

    dispatch_id = data.get("dispatch_id", "")
    if not ALLOW_REMOTE_UPDATE:
        await ws.send(
            json.dumps(
                {
                    "event": "command_result",
                    "dispatch_id": dispatch_id,
                    "command": "update_code",
                    "status": "FAILED",
                    "error": "EDGE_ALLOW_REMOTE_UPDATE=false",
                },
                ensure_ascii=False,
            )
        )
        return

    branch = data.get("branch", "master")
    # BUG FIX: quote the branch so a malicious dispatch payload cannot inject
    # extra shell commands via the branch name (this runs with shell=True).
    # NOTE(review): an explicit "command" is still executed verbatim — that is
    # this feature's contract, gated by EDGE_ALLOW_REMOTE_UPDATE.
    safe_branch = shlex.quote(branch)
    git_command = data.get("command") or (
        f"git fetch --all && git checkout {safe_branch} && git pull --ff-only origin {safe_branch}"
    )

    await ws.send(
        json.dumps(
            {"event": "command_status", "dispatch_id": dispatch_id, "command": "update_code", "status": "RUNNING"},
            ensure_ascii=False,
        )
    )

    code, out, err = await asyncio.to_thread(run_shell, git_command, 1800)
    payload = {
        "event": "command_result",
        "dispatch_id": dispatch_id,
        "command": "update_code",
        "status": "SUCCEEDED" if code == 0 else "FAILED",
        "exit_code": code,
        "stdout": _short_text(out),
        "stderr": _short_text(err),
    }
    await ws.send(json.dumps(payload, ensure_ascii=False))
|
|
|
|
|
|
async def handle_restart_service(ws, data: dict):
    """Handle a remote "restart_service" command.

    Sends a FAILED command_result when remote restarts are disabled or the
    restart script is missing. Otherwise reports RUNNING, pre-acknowledges
    SUCCEEDED (the socket dies once the restart starts), detaches the restart
    script into its own session, and exits this process via SystemExit(0).
    """
    dispatch_id = data.get("dispatch_id", "")

    def _result(**extra) -> str:
        """Serialize a command_result event with the given extra fields."""
        payload = {
            "event": "command_result",
            "dispatch_id": dispatch_id,
            "command": "restart_service",
        }
        payload.update(extra)
        return json.dumps(payload, ensure_ascii=False)

    if not ALLOW_REMOTE_RESTART:
        await ws.send(_result(status="FAILED", error="EDGE_ALLOW_REMOTE_RESTART=false"))
        return

    if not RESTART_SCRIPT.exists():
        await ws.send(_result(status="FAILED", error=f"restart script missing: {RESTART_SCRIPT}"))
        return

    await ws.send(
        json.dumps(
            {"event": "command_status", "dispatch_id": dispatch_id, "command": "restart_service", "status": "RUNNING"},
            ensure_ascii=False,
        )
    )
    await ws.send(_result(status="SUCCEEDED", message="restart script launched"))

    # Detach into a new session so the restart survives this process exiting;
    # the script's output is deliberately discarded.
    subprocess.Popen(
        ["bash", str(RESTART_SCRIPT)],
        cwd=str(ROOT_DIR),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        start_new_session=True,
    )
    # main() catches SystemExit and returns, letting the process terminate.
    raise SystemExit(0)
|
|
|
|
|
|
async def handle_ping(ws, data: dict):
    """Reply to a dispatch-server ping with a pong carrying the worker URL."""
    pong = {
        "event": "pong",
        "dispatch_id": data.get("dispatch_id", ""),
        "status": "ok",
        "worker_base_url": WORKER_BASE_URL,
    }
    await ws.send(json.dumps(pong, ensure_ascii=False))
|
|
|
|
|
|
async def main() -> None:
    """Connect to the dispatch server and service incoming events forever.

    Reconnects after a 3-second pause on any connection/protocol error.
    SystemExit (raised by handle_restart_service) ends the loop cleanly.
    """
    # Dispatch table: event name -> async handler(ws, data).
    handlers = {
        "generate": handle_generate,
        "update_code": handle_update_code,
        "restart_service": handle_restart_service,
        "ping": handle_ping,
    }
    while True:
        try:
            async with websockets.connect(DISPATCH_WS_URL, max_size=2**22) as ws:
                print(f"connected: {DISPATCH_WS_URL}")
                while True:
                    data = json.loads(await ws.recv())
                    event = data.get("event")
                    handler = handlers.get(event)
                    if handler is not None:
                        await handler(ws, data)
                    elif event == "registered":
                        print("registered", data)
        except SystemExit:
            return
        except Exception as exc:  # top-level reconnect boundary: log and retry
            print("connection error, retry in 3s:", exc)
            # BUG FIX: was time.sleep(3), which blocks the event loop inside
            # an async function; asyncio.sleep yields control instead.
            await asyncio.sleep(3)
|
|
|
|
|
|
# Script entry point: run the reconnect loop until SystemExit.
if __name__ == "__main__":
    asyncio.run(main())
|