fix: 新增数据模块
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.oss_client import oss_uploader
|
||||
from app.schemas import GenerateRequest
|
||||
from app.settings import settings
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
|
||||
@@ -86,6 +89,7 @@ class EdgeDispatchManager:
|
||||
raise HTTPException(status_code=409, detail="no idle edge device available")
|
||||
|
||||
async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
|
||||
await self._prune_if_needed()
|
||||
dispatch_id = uuid4().hex
|
||||
now = utc_now_iso()
|
||||
record = {
|
||||
@@ -97,6 +101,7 @@ class EdgeDispatchManager:
|
||||
"updated_at": now,
|
||||
"result": None,
|
||||
"error": None,
|
||||
"artifact_urls": {},
|
||||
}
|
||||
|
||||
async with self.lock:
|
||||
@@ -125,6 +130,18 @@ class EdgeDispatchManager:
|
||||
raise
|
||||
return record
|
||||
|
||||
async def _prune_if_needed(self) -> None:
|
||||
max_records = max(100, int(settings.edge_max_dispatch_records))
|
||||
async with self.lock:
|
||||
total = len(self.dispatches)
|
||||
if total < max_records:
|
||||
return
|
||||
over = total - max_records + 1
|
||||
done = [v for v in self.dispatches.values() if v.get("status") in {"SUCCEEDED", "FAILED"}]
|
||||
done.sort(key=lambda x: x.get("updated_at", ""))
|
||||
for rec in done[:over]:
|
||||
self.dispatches.pop(rec["dispatch_id"], None)
|
||||
|
||||
async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
|
||||
now = utc_now_iso()
|
||||
async with self.lock:
|
||||
@@ -217,6 +234,54 @@ async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
|
||||
return record
|
||||
|
||||
|
||||
@app.post("/dispatch/{dispatch_id}/artifacts")
|
||||
async def upload_artifacts(
|
||||
dispatch_id: str,
|
||||
task_id: str = Form(default=""),
|
||||
status: str = Form(default="SUCCEEDED"),
|
||||
files: list[UploadFile] = File(default_factory=list),
|
||||
) -> dict[str, Any]:
|
||||
record = manager.dispatches.get(dispatch_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
|
||||
if not oss_uploader.enabled:
|
||||
raise HTTPException(status_code=400, detail="OSS upload is disabled, set OSS_ENABLED=true")
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="no files uploaded")
|
||||
|
||||
uploaded: dict[str, dict[str, str]] = {}
|
||||
for file in files:
|
||||
name = file.filename or "artifact.bin"
|
||||
try:
|
||||
result = await asyncio.to_thread(oss_uploader.upload_fileobj, dispatch_id, name, file.file)
|
||||
uploaded[Path(name).name] = result
|
||||
finally:
|
||||
await file.close()
|
||||
|
||||
now = utc_now_iso()
|
||||
async with manager.lock:
|
||||
record["artifact_urls"] = uploaded
|
||||
record["result"] = {
|
||||
"event": "result",
|
||||
"dispatch_id": dispatch_id,
|
||||
"task_id": task_id or None,
|
||||
"status": status,
|
||||
"artifact_urls": uploaded,
|
||||
}
|
||||
record["status"] = status
|
||||
record["updated_at"] = now
|
||||
conn = manager.connections.get(record["device_id"])
|
||||
if conn:
|
||||
conn.busy = False
|
||||
|
||||
return {
|
||||
"dispatch_id": dispatch_id,
|
||||
"status": status,
|
||||
"artifact_urls": uploaded,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
|
||||
@app.websocket("/ws/edge/{device_id}")
|
||||
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
61
video_worker/app/oss_client.py
Normal file
61
video_worker/app/oss_client.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import oss2
|
||||
|
||||
from app.settings import settings
|
||||
|
||||
|
||||
class OSSUploader:
|
||||
def __init__(self) -> None:
|
||||
self.enabled = bool(settings.oss_enabled)
|
||||
if not self.enabled:
|
||||
self.bucket = None
|
||||
return
|
||||
|
||||
if not all([
|
||||
settings.oss_endpoint,
|
||||
settings.oss_bucket,
|
||||
settings.oss_access_key_id,
|
||||
settings.oss_access_key_secret,
|
||||
]):
|
||||
raise RuntimeError("OSS is enabled but endpoint/bucket/ak/sk is not fully configured")
|
||||
|
||||
auth = oss2.Auth(settings.oss_access_key_id, settings.oss_access_key_secret)
|
||||
self.bucket = oss2.Bucket(auth, settings.oss_endpoint, settings.oss_bucket)
|
||||
|
||||
@staticmethod
|
||||
def _safe_name(name: str) -> str:
|
||||
return name.replace("\\", "_").replace("/", "_")
|
||||
|
||||
def _key(self, dispatch_id: str, filename: str) -> str:
|
||||
date_part = datetime.now(timezone.utc).strftime("%Y%m%d")
|
||||
safe_file = self._safe_name(Path(filename).name)
|
||||
return f"{settings.oss_prefix.strip('/')}/{date_part}/{dispatch_id}/{safe_file}"
|
||||
|
||||
def _public_url(self, key: str) -> str:
|
||||
if settings.oss_public_base_url:
|
||||
return f"{settings.oss_public_base_url.rstrip('/')}/{key}"
|
||||
|
||||
endpoint = settings.oss_endpoint.rstrip("/")
|
||||
if endpoint.startswith("http://") or endpoint.startswith("https://"):
|
||||
return f"{endpoint}/{key}"
|
||||
return f"https://{endpoint}/{key}"
|
||||
|
||||
def upload_fileobj(self, dispatch_id: str, filename: str, fileobj) -> dict[str, str]:
|
||||
if not self.enabled or self.bucket is None:
|
||||
raise RuntimeError("OSS uploader is not enabled")
|
||||
|
||||
key = self._key(dispatch_id, filename)
|
||||
fileobj.seek(0)
|
||||
self.bucket.put_object(key, fileobj)
|
||||
return {
|
||||
"filename": Path(filename).name,
|
||||
"object_key": key,
|
||||
"url": self._public_url(key),
|
||||
}
|
||||
|
||||
|
||||
oss_uploader = OSSUploader()
|
||||
@@ -13,6 +13,15 @@ class Settings(BaseSettings):
|
||||
ws_poll_interval_sec: float = Field(default=1.0, alias="WS_POLL_INTERVAL_SEC")
|
||||
edge_dispatch_host: str = Field(default="0.0.0.0", alias="EDGE_DISPATCH_HOST")
|
||||
edge_dispatch_port: int = Field(default=8020, alias="EDGE_DISPATCH_PORT")
|
||||
edge_max_dispatch_records: int = Field(default=2000, alias="EDGE_MAX_DISPATCH_RECORDS")
|
||||
|
||||
oss_enabled: bool = Field(default=False, alias="OSS_ENABLED")
|
||||
oss_endpoint: str = Field(default="", alias="OSS_ENDPOINT")
|
||||
oss_bucket: str = Field(default="", alias="OSS_BUCKET")
|
||||
oss_access_key_id: str = Field(default="", alias="OSS_ACCESS_KEY_ID")
|
||||
oss_access_key_secret: str = Field(default="", alias="OSS_ACCESS_KEY_SECRET")
|
||||
oss_public_base_url: str = Field(default="", alias="OSS_PUBLIC_BASE_URL")
|
||||
oss_prefix: str = Field(default="video-worker", alias="OSS_PREFIX")
|
||||
|
||||
output_dir: Path = Field(default=Path("./outputs"), alias="OUTPUT_DIR")
|
||||
runtime_dir: Path = Field(default=Path("./runtime"), alias="RUNTIME_DIR")
|
||||
|
||||
Reference in New Issue
Block a user