feat: 新增数据模块
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import websockets
|
||||
@@ -10,11 +12,23 @@ import websockets
|
||||
# Dispatcher websocket endpoint this edge agent connects to.
DISPATCH_WS_URL = os.getenv("DISPATCH_WS_URL", "ws://127.0.0.1:8020/ws/edge/edge-a4000-01")
# Base URL of the local worker HTTP API that executes tasks.
WORKER_BASE_URL = os.getenv("WORKER_BASE_URL", "http://127.0.0.1:8000")
# Seconds to sleep between task-status polls.
POLL_INTERVAL = float(os.getenv("EDGE_POLL_INTERVAL_SEC", "1.0"))
# Dispatcher HTTP base; empty string means "derive it from DISPATCH_WS_URL".
DISPATCH_HTTP_BASE = os.getenv("DISPATCH_HTTP_BASE", "")

# A positional command-line argument overrides the env-configured websocket URL.
if len(sys.argv) > 1:
    DISPATCH_WS_URL = sys.argv[1]
||||
def infer_http_base(ws_url: str) -> str:
    """Derive the HTTP(S) base URL that corresponds to a websocket URL.

    ``wss://`` maps to ``https://`` and anything else (i.e. ``ws://``) maps
    to ``http://``; the ``host[:port]`` part is carried over unchanged and
    any path component is dropped.
    """
    parts = urlparse(ws_url)
    secure = parts.scheme == "wss"
    return ("https://" if secure else "http://") + parts.netloc
|
||||
|
||||
|
||||
# If no explicit HTTP base was configured, derive it from the websocket URL
# so artifact uploads target the same dispatcher host.
if not DISPATCH_HTTP_BASE:
    DISPATCH_HTTP_BASE = infer_http_base(DISPATCH_WS_URL)
|
||||
|
||||
|
||||
def worker_post(path: str, payload: dict):
|
||||
r = requests.post(f"{WORKER_BASE_URL}{path}", json=payload, timeout=30)
|
||||
r.raise_for_status()
|
||||
@@ -27,6 +41,43 @@ def worker_get(path: str):
|
||||
return r.json()
|
||||
|
||||
|
||||
def upload_artifacts(dispatch_id: str, task_id: str, result: dict) -> dict:
    """Upload any locally existing artifact files for a finished task.

    Looks up the well-known path fields in *result*, posts every file that
    exists on disk to the dispatcher's artifacts endpoint as a multipart
    request, and returns the ``artifact_urls`` mapping from the response.
    Returns an empty dict when there is nothing to upload.  All opened file
    handles are closed even if the upload fails.
    """
    wanted = ("video_path", "first_frame_path", "metadata_path", "log_path")
    # Keep only fields that are set and point at a real file on disk.
    paths = [Path(result[key]) for key in wanted if result.get(key) and Path(result[key]).exists()]
    if not paths:
        return {}

    handles = []
    try:
        files = []
        for artifact in paths:
            fh = artifact.open("rb")
            handles.append(fh)  # tracked so finally can always close it
            files.append(("files", (artifact.name, fh, "application/octet-stream")))

        form = {
            "task_id": task_id,
            "status": result.get("status", "SUCCEEDED"),
        }
        resp = requests.post(
            f"{DISPATCH_HTTP_BASE}/dispatch/{dispatch_id}/artifacts",
            data=form,
            files=files,
            timeout=600,
        )
        resp.raise_for_status()
        return resp.json().get("artifact_urls", {})
    finally:
        for fh in handles:
            fh.close()
|
||||
|
||||
|
||||
async def handle_generate(ws, data: dict):
|
||||
dispatch_id = data["dispatch_id"]
|
||||
req = data["request"]
|
||||
@@ -51,16 +102,20 @@ async def handle_generate(ws, data: dict):
|
||||
)
|
||||
if status["status"] in {"SUCCEEDED", "FAILED"}:
|
||||
result = await asyncio.to_thread(worker_get, f"/tasks/{task_id}/result")
|
||||
artifact_urls = {}
|
||||
result_payload = {
|
||||
"event": "result",
|
||||
"dispatch_id": dispatch_id,
|
||||
"task_id": task_id,
|
||||
"status": result.get("status", status["status"]),
|
||||
}
|
||||
if status["status"] == "SUCCEEDED":
|
||||
artifact_urls = await asyncio.to_thread(upload_artifacts, dispatch_id, task_id, result)
|
||||
result_payload["artifact_urls"] = artifact_urls
|
||||
else:
|
||||
result_payload["error"] = result.get("error")
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "result",
|
||||
"dispatch_id": dispatch_id,
|
||||
"task_id": task_id,
|
||||
**result,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
json.dumps(result_payload, ensure_ascii=False)
|
||||
)
|
||||
return
|
||||
await asyncio.sleep(POLL_INTERVAL)
|
||||
|
||||
Reference in New Issue
Block a user