fix: 新增数据模块

This commit is contained in:
Daniel
2026-04-07 01:34:49 +08:00
parent e606b3dcd6
commit f529aa3279
11 changed files with 319 additions and 10 deletions

View File

@@ -1,13 +1,16 @@
import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from app.oss_client import oss_uploader
from app.schemas import GenerateRequest
from app.settings import settings
def utc_now_iso() -> str:
@@ -86,6 +89,7 @@ class EdgeDispatchManager:
raise HTTPException(status_code=409, detail="no idle edge device available")
async def create_dispatch(self, conn: EdgeConnection, req: GenerateRequest) -> dict[str, Any]:
await self._prune_if_needed()
dispatch_id = uuid4().hex
now = utc_now_iso()
record = {
@@ -97,6 +101,7 @@ class EdgeDispatchManager:
"updated_at": now,
"result": None,
"error": None,
"artifact_urls": {},
}
async with self.lock:
@@ -125,6 +130,18 @@ class EdgeDispatchManager:
raise
return record
async def _prune_if_needed(self) -> None:
max_records = max(100, int(settings.edge_max_dispatch_records))
async with self.lock:
total = len(self.dispatches)
if total < max_records:
return
over = total - max_records + 1
done = [v for v in self.dispatches.values() if v.get("status") in {"SUCCEEDED", "FAILED"}]
done.sort(key=lambda x: x.get("updated_at", ""))
for rec in done[:over]:
self.dispatches.pop(rec["dispatch_id"], None)
async def mark_event(self, device_id: str, event: dict[str, Any]) -> None:
now = utc_now_iso()
async with self.lock:
@@ -217,6 +234,54 @@ async def get_dispatch(dispatch_id: str) -> dict[str, Any]:
return record
@app.post("/dispatch/{dispatch_id}/artifacts")
async def upload_artifacts(
dispatch_id: str,
task_id: str = Form(default=""),
status: str = Form(default="SUCCEEDED"),
files: list[UploadFile] = File(default_factory=list),
) -> dict[str, Any]:
record = manager.dispatches.get(dispatch_id)
if record is None:
raise HTTPException(status_code=404, detail=f"dispatch not found: {dispatch_id}")
if not oss_uploader.enabled:
raise HTTPException(status_code=400, detail="OSS upload is disabled, set OSS_ENABLED=true")
if not files:
raise HTTPException(status_code=400, detail="no files uploaded")
uploaded: dict[str, dict[str, str]] = {}
for file in files:
name = file.filename or "artifact.bin"
try:
result = await asyncio.to_thread(oss_uploader.upload_fileobj, dispatch_id, name, file.file)
uploaded[Path(name).name] = result
finally:
await file.close()
now = utc_now_iso()
async with manager.lock:
record["artifact_urls"] = uploaded
record["result"] = {
"event": "result",
"dispatch_id": dispatch_id,
"task_id": task_id or None,
"status": status,
"artifact_urls": uploaded,
}
record["status"] = status
record["updated_at"] = now
conn = manager.connections.get(record["device_id"])
if conn:
conn.busy = False
return {
"dispatch_id": dispatch_id,
"status": status,
"artifact_urls": uploaded,
"updated_at": now,
}
@app.websocket("/ws/edge/{device_id}")
async def edge_socket(websocket: WebSocket, device_id: str) -> None:
await websocket.accept()

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
import oss2
from app.settings import settings
class OSSUploader:
def __init__(self) -> None:
self.enabled = bool(settings.oss_enabled)
if not self.enabled:
self.bucket = None
return
if not all([
settings.oss_endpoint,
settings.oss_bucket,
settings.oss_access_key_id,
settings.oss_access_key_secret,
]):
raise RuntimeError("OSS is enabled but endpoint/bucket/ak/sk is not fully configured")
auth = oss2.Auth(settings.oss_access_key_id, settings.oss_access_key_secret)
self.bucket = oss2.Bucket(auth, settings.oss_endpoint, settings.oss_bucket)
@staticmethod
def _safe_name(name: str) -> str:
return name.replace("\\", "_").replace("/", "_")
def _key(self, dispatch_id: str, filename: str) -> str:
date_part = datetime.now(timezone.utc).strftime("%Y%m%d")
safe_file = self._safe_name(Path(filename).name)
return f"{settings.oss_prefix.strip('/')}/{date_part}/{dispatch_id}/{safe_file}"
def _public_url(self, key: str) -> str:
if settings.oss_public_base_url:
return f"{settings.oss_public_base_url.rstrip('/')}/{key}"
endpoint = settings.oss_endpoint.rstrip("/")
if endpoint.startswith("http://") or endpoint.startswith("https://"):
return f"{endpoint}/{key}"
return f"https://{endpoint}/{key}"
def upload_fileobj(self, dispatch_id: str, filename: str, fileobj) -> dict[str, str]:
if not self.enabled or self.bucket is None:
raise RuntimeError("OSS uploader is not enabled")
key = self._key(dispatch_id, filename)
fileobj.seek(0)
self.bucket.put_object(key, fileobj)
return {
"filename": Path(filename).name,
"object_key": key,
"url": self._public_url(key),
}
oss_uploader = OSSUploader()

View File

@@ -13,6 +13,15 @@ class Settings(BaseSettings):
ws_poll_interval_sec: float = Field(default=1.0, alias="WS_POLL_INTERVAL_SEC")
edge_dispatch_host: str = Field(default="0.0.0.0", alias="EDGE_DISPATCH_HOST")
edge_dispatch_port: int = Field(default=8020, alias="EDGE_DISPATCH_PORT")
edge_max_dispatch_records: int = Field(default=2000, alias="EDGE_MAX_DISPATCH_RECORDS")
oss_enabled: bool = Field(default=False, alias="OSS_ENABLED")
oss_endpoint: str = Field(default="", alias="OSS_ENDPOINT")
oss_bucket: str = Field(default="", alias="OSS_BUCKET")
oss_access_key_id: str = Field(default="", alias="OSS_ACCESS_KEY_ID")
oss_access_key_secret: str = Field(default="", alias="OSS_ACCESS_KEY_SECRET")
oss_public_base_url: str = Field(default="", alias="OSS_PUBLIC_BASE_URL")
oss_prefix: str = Field(default="video-worker", alias="OSS_PREFIX")
output_dir: Path = Field(default=Path("./outputs"), alias="OUTPUT_DIR")
runtime_dir: Path = Field(default=Path("./runtime"), alias="RUNTIME_DIR")