refactor: 优化架构
This commit is contained in:
1
engine/adapters/__init__.py
Normal file
1
engine/adapters/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
engine/adapters/image/__init__.py
Normal file
1
engine/adapters/image/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
9
engine/adapters/image/base.py
Normal file
9
engine/adapters/image/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class BaseImageGen:
    """Interface for image-generation backends: render one still image per shot."""

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Render one image described by *prompt* into *output_dir*.

        Args:
            prompt: Mapping of prompt text; implementations read the
                ``positive`` and ``negative`` keys.
            output_dir: Directory the backend writes the image file into.

        Returns:
            Path of the generated image file as ``str``.
        """
        raise NotImplementedError
|
||||
|
||||
36
engine/adapters/image/comfy_adapter.py
Normal file
36
engine/adapters/image/comfy_adapter.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from engine.comfy_client import generate_image as comfy_generate_image
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseImageGen
|
||||
from .mock_adapter import MockImageGen
|
||||
|
||||
|
||||
class ComfyAdapter(BaseImageGen):
    """Image backend that renders via a ComfyUI server.

    Errors are deliberately not caught here: render_pipeline applies the
    configured fallback strategy when generation fails.
    """

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        # Kept for callers that want a local fallback; this adapter itself
        # never uses it — render_pipeline owns the fallback policy.
        self.fallback = MockImageGen()

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Generate one image for *prompt* via ComfyUI and return its file path.

        Raises whatever `comfy_generate_image` raises; the previous
        `try/except Exception as e: raise` wrapper was a no-op (and left `e`
        unused), so it has been removed — exceptions propagate unchanged and
        render_pipeline does the configured fallback.
        """
        positive = str(prompt.get("positive", "") or "")
        negative = str(prompt.get("negative", "") or "")
        return str(
            comfy_generate_image(
                positive,
                output_dir,
                negative_text=negative or None,
                cfg=self.cfg,
                timeout_s=60,
                retry=2,
                filename_prefix="shot",
            )
        )
|
||||
|
||||
45
engine/adapters/image/mock_adapter.py
Normal file
45
engine/adapters/image/mock_adapter.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from urllib.request import urlopen
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .base import BaseImageGen
|
||||
|
||||
|
||||
# Location of the cached demo still used by the offline mock image backend.
ASSETS_DIR = "assets"
DEMO_IMAGE = os.path.join(ASSETS_DIR, "demo.jpg")


def ensure_demo_image() -> None:
    """Download the demo placeholder once; subsequent calls are no-ops."""
    os.makedirs(ASSETS_DIR, exist_ok=True)
    if os.path.exists(DEMO_IMAGE):
        # Already cached from a previous run — nothing to do.
        return

    url = "https://picsum.photos/1280/720"
    with urlopen(url, timeout=30) as resp:
        payload = resp.read()
    with open(DEMO_IMAGE, "wb") as sink:
        sink.write(payload)
|
||||
|
||||
|
||||
class MockImageGen(BaseImageGen):
    """Offline image backend: every shot is a PNG copy of the cached demo picture."""

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        # prompt is accepted for interface consistency; mock uses only demo.jpg.
        del prompt
        ensure_demo_image()

        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        target = target_dir / f"shot_{uuid.uuid4().hex}.png"

        try:
            # Convert to PNG so verification criteria can match *.png.
            Image.open(DEMO_IMAGE).convert("RGB").save(str(target), format="PNG")
        except Exception:
            # Last-resort: if PNG conversion fails, still write a best-effort copy.
            target.write_bytes(Path(DEMO_IMAGE).read_bytes())
        return str(target)
|
||||
|
||||
83
engine/adapters/image/openai_image_adapter.py
Normal file
83
engine/adapters/image/openai_image_adapter.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseImageGen
|
||||
|
||||
|
||||
class OpenAIImageAdapter(BaseImageGen):
    """
    Optional image provider adapter using OpenAI Images API (or OpenAI-compatible gateways).
    Requires `openai` python package and a configured API key via environment variables.
    """

    def __init__(self, cfg: AppConfig):
        """Resolve model / API key / base URL from config and build the client.

        Raises:
            ValueError: when no image model is configured.
            RuntimeError: when the API key cannot be resolved.
        """
        self.cfg = cfg
        # Expected keys (configurable):
        # - image.openai.model
        # - openai.api_key_env / openai.base_url_env (reuses existing engine/script_gen config fields)
        self.model = str(cfg.get("image.openai.model", cfg.get("image.model", ""))).strip()
        if not self.model:
            raise ValueError("OpenAIImageAdapter requires `image.openai.model` (or `image.model`).")

        api_key_env_or_literal = str(cfg.get("openai.api_key_env", "OPENAI_API_KEY") or "OPENAI_API_KEY").strip()
        # Support both:
        # - env var name (e.g. OPENAI_API_KEY)
        # - literal API key (e.g. starts with `sk-...`) for quick local POCs.
        if api_key_env_or_literal.startswith("sk-"):
            api_key = api_key_env_or_literal
        else:
            api_key = os.environ.get(api_key_env_or_literal)
            if not api_key:
                raise RuntimeError(f"OpenAIImageAdapter missing API key: `{api_key_env_or_literal}`")
        self.api_key = api_key

        # NOTE(review): despite the `_env` suffix, this config value is used as a
        # *literal* base URL and is never looked up in the environment — confirm
        # this matches how engine/script_gen interprets the same key.
        base_url_env_or_literal = str(cfg.get("openai.base_url_env", "https://api.openai.com/v1")).strip()
        self.base_url = base_url_env_or_literal.rstrip("/") if base_url_env_or_literal else "https://api.openai.com/v1"

        # Lazy import to avoid hard dependency for mock/comfy users.
        from openai import OpenAI  # type: ignore

        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Generate one image, download it, and save it as PNG under *output_dir*.

        Returns:
            Path of the saved PNG as ``str``.

        Raises:
            RuntimeError: when the API response carries no image URL.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        positive = prompt.get("positive", "")
        negative = prompt.get("negative", "")
        # OpenAI Images API generally doesn't expose a dedicated negative_prompt field.
        # To keep interface consistency, embed negative hints into the prompt text.
        if negative:
            prompt_text = f"{positive}\nNegative prompt: {negative}"
        else:
            prompt_text = positive

        result = self.client.images.generate(model=self.model, prompt=prompt_text)

        # OpenAI SDK: result.data[0].url
        url: str | None = None
        try:
            url = result.data[0].url  # type: ignore[attr-defined]
        except Exception:
            # Fall through to the explicit missing-url error below.
            pass
        if not url:
            raise RuntimeError("OpenAIImageAdapter unexpected response: missing image url")

        r = requests.get(url, timeout=60)
        r.raise_for_status()

        # Re-encode to PNG so downstream *.png validation matches.
        out_path = output_dir / f"shot_{uuid.uuid4().hex}.png"
        img = Image.open(BytesIO(r.content)).convert("RGB")
        img.save(str(out_path), format="PNG")
        return str(out_path)
|
||||
|
||||
60
engine/adapters/image/replicate_adapter.py
Normal file
60
engine/adapters/image/replicate_adapter.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseImageGen
|
||||
|
||||
|
||||
class ReplicateAdapter(BaseImageGen):
    """Image backend backed by the Replicate API; always writes PNG files."""

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        # Expected: image.replicate.model
        self.model = str(cfg.get("image.replicate.model", cfg.get("image.model", ""))).strip()
        if not self.model:
            raise ValueError("ReplicateAdapter requires `image.replicate.model` (or `image.model`).")

        # Import lazily so that environments without replicate installed can still run with mock/comfy.
        import replicate  # type: ignore

        self.replicate = replicate

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        payload: dict[str, Any] = {
            "prompt": prompt.get("positive", ""),
            "negative_prompt": prompt.get("negative", ""),
        }

        # replicate.run is synchronous when wait is handled by the SDK version.
        result = self.replicate.run(self.model, input=payload)

        # Common shapes: [url, ...] or dict-like.
        image_url = None
        if isinstance(result, list) and result:
            image_url = result[0]
        elif isinstance(result, dict):
            image_url = result.get("image") or result.get("output") or result.get("url")
        if not isinstance(image_url, str) or not image_url:
            raise RuntimeError(f"Unexpected Replicate output shape: {type(result)}")

        resp = requests.get(image_url, timeout=60)
        resp.raise_for_status()

        # Always output PNG to satisfy downstream validation `outputs/{task_id}/*.png`.
        # Pillow doesn't provide open_bytes; wrap content into a buffer.
        from io import BytesIO

        png_path = out_dir / f"shot_{uuid.uuid4().hex}.png"
        Image.open(BytesIO(resp.content)).convert("RGB").save(str(png_path), format="PNG")
        return str(png_path)
|
||||
|
||||
21
engine/adapters/image/stability_adapter.py
Normal file
21
engine/adapters/image/stability_adapter.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseImageGen
|
||||
|
||||
|
||||
class StabilityAdapter(BaseImageGen):
    """
    Placeholder for Stability AI image generation.
    Add implementation + dependencies when needed.
    """

    def __init__(self, cfg: AppConfig):
        # Config is stored now so the constructor matches the other image
        # adapters; it is unused until the backend is implemented.
        self.cfg = cfg

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        # Interface stub: selecting this provider today is a configuration error.
        raise NotImplementedError("StabilityAdapter not implemented yet")
|
||||
|
||||
1
engine/adapters/llm/__init__.py
Normal file
1
engine/adapters/llm/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
12
engine/adapters/llm/base.py
Normal file
12
engine/adapters/llm/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseLLM:
    """Interface for script-writing LLM backends."""

    def generate_script(self, prompt: str, context: dict[str, Any] | None = None) -> Any:
        """Produce a scene list (script) from a free-form user prompt.

        *context* carries backend-specific extras and may be ignored.
        """
        raise NotImplementedError

    def refine_scene(self, scene: Any, context: dict[str, Any] | None = None) -> Any:
        """Return an improved version of a single scene.

        *context* carries whatever state the backend needs (e.g. the full
        scene list and target index for the OpenAI adapter).
        """
        raise NotImplementedError
|
||||
|
||||
25
engine/adapters/llm/mock_adapter.py
Normal file
25
engine/adapters/llm/mock_adapter.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from engine.types import Scene
|
||||
|
||||
from .base import BaseLLM
|
||||
|
||||
|
||||
class MockLLM(BaseLLM):
    """Deterministic offline stand-in for a real script-writing LLM."""

    def generate_script(self, prompt: str, context: dict[str, Any] | None = None) -> list[Scene]:
        # Simple deterministic scenes for offline development.
        theme = (prompt or "").strip() or "a warm city night"
        scene_specs = [
            (f"{theme},城市夜景,霓虹灯,电影感", "缓慢推进镜头,轻微摇镜", "夜色温柔落在街灯上"),
            (f"{theme},咖啡店窗边,暖光,细雨", "侧向平移,人物轻轻抬头", "雨声里藏着一段回忆"),
            (f"{theme},桥上远景,车流光轨,温暖", "拉远全景,光轨流动", "我们在光里学会告别"),
        ]
        return [
            Scene(image_prompt=ip, video_motion=vm, narration=nr)
            for ip, vm, nr in scene_specs
        ]

    def refine_scene(self, scene: Scene, context: dict[str, Any] | None = None) -> Scene:
        # Minimal polish: append a hint, then cap narration at 30 characters.
        polished = (scene.narration + "(更凝练)")[:30]
        return Scene(
            image_prompt=scene.image_prompt,
            video_motion=scene.video_motion,
            narration=polished,
        )
|
||||
|
||||
29
engine/adapters/llm/openai_adapter.py
Normal file
29
engine/adapters/llm/openai_adapter.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from engine.config import AppConfig
|
||||
from engine.script_gen import generate_scenes, refine_scene
|
||||
|
||||
from .base import BaseLLM
|
||||
|
||||
|
||||
class OpenAIAdapter(BaseLLM):
    """Thin bridge from the BaseLLM interface to engine.script_gen."""

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg

    def generate_script(self, prompt: str, context: dict[str, Any] | None = None):
        # Existing script_gen already enforces JSON schema and length constraints.
        return generate_scenes(prompt, self.cfg)

    def refine_scene(self, scene: Any, context: dict[str, Any] | None = None):
        # Context carries needed values to call refine_scene in script_gen.
        ctx = context or {}
        scenes = ctx.get("scenes")
        original_prompt = ctx.get("prompt")
        target_index = ctx.get("target_index")
        if scenes is None or original_prompt is None or target_index is None:
            raise ValueError("OpenAIAdapter.refine_scene missing context: scenes/prompt/target_index")
        return refine_scene(
            prompt=original_prompt,
            scenes=scenes,
            target_index=int(target_index),
            cfg=self.cfg,
        )
|
||||
|
||||
1
engine/adapters/tts/__init__.py
Normal file
1
engine/adapters/tts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
9
engine/adapters/tts/base.py
Normal file
9
engine/adapters/tts/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class BaseTTS:
    """Interface for text-to-speech backends."""

    def generate(self, text: str, output_path: str | Path) -> str:
        """Synthesize *text* into an audio file at *output_path*.

        Returns the resulting audio path as ``str`` (implementations may
        return an empty string when no usable audio is produced, so the
        video stage can skip muxing).
        """
        raise NotImplementedError
|
||||
|
||||
28
engine/adapters/tts/edge_adapter.py
Normal file
28
engine/adapters/tts/edge_adapter.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from engine.audio_gen import synthesize_one
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseTTS
|
||||
|
||||
|
||||
class EdgeTTS(BaseTTS):
    """TTS backend delegating to engine.audio_gen (edge-tts based)."""

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg

    def generate(self, text: str, output_path: str | Path) -> str:
        # Substitute a single space for empty text before synthesis.
        speech_text = text or " "
        target = Path(output_path)

        # Voice parameters come from config, with sane Chinese defaults.
        voice = str(self.cfg.get("tts.voice", "zh-CN-XiaoxiaoNeural"))
        rate = str(self.cfg.get("tts.rate", "+0%"))
        volume = str(self.cfg.get("tts.volume", "+0%"))

        async def _synthesize() -> str:
            asset = await synthesize_one(speech_text, target, voice, rate, volume)
            return str(asset.path)

        # Drive the async helper from this synchronous interface.
        return asyncio.run(_synthesize())
|
||||
|
||||
15
engine/adapters/tts/mock_adapter.py
Normal file
15
engine/adapters/tts/mock_adapter.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .base import BaseTTS
|
||||
|
||||
|
||||
class MockTTS(BaseTTS):
    """No-op TTS for offline tests: produces no audio at all."""

    def generate(self, text: str, output_path: str | Path) -> str:
        """Skip synthesis and return an empty path so the video stage skips audio.

        The previous implementation wrote a zero-byte file and returned its
        path, contradicting its own comment: the MoviePy adapter gates audio
        on `audio_path and os.path.exists(audio_path)`, so the empty file
        *exists* and would be handed to `AudioFileClip`, which cannot decode
        a zero-byte file. Returning "" makes the gate falsy, as documented.
        """
        _ = text
        # Still create the parent directory so callers that expect the output
        # location to exist keep working.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        return ""
|
||||
|
||||
1
engine/adapters/video/__init__.py
Normal file
1
engine/adapters/video/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
9
engine/adapters/video/base.py
Normal file
9
engine/adapters/video/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class BaseVideoGen:
    """Interface for video-clip backends: turn one still image into one clip."""

    def generate(self, image_path: str, prompt: dict, output_path: str | Path) -> str:
        """Render a clip from *image_path* per *prompt* into *output_path*.

        Implementations read render parameters (duration, fps, audio, size)
        from the *prompt* dict; returns the written video path as ``str``.
        """
        raise NotImplementedError
|
||||
|
||||
18
engine/adapters/video/ltx_adapter.py
Normal file
18
engine/adapters/video/ltx_adapter.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseVideoGen
|
||||
|
||||
|
||||
class LTXVideoGen(BaseVideoGen):
    """Placeholder for direct image-to-video generation (LTX / diffusion video)."""

    def __init__(self, cfg: AppConfig):
        # Stored for interface parity with the other video adapters; unused
        # until this backend is implemented.
        self.cfg = cfg

    def generate(self, image_path: str, prompt: dict, output_path: str | Path) -> str:
        # Reserved for future: direct image->video generation (LTX / diffusion video).
        # Current project keeps clip generation via MoviePy for stability.
        raise NotImplementedError("LTXVideoGen is not implemented yet")
|
||||
|
||||
81
engine/adapters/video/moviepy_adapter.py
Normal file
81
engine/adapters/video/moviepy_adapter.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from moviepy import AudioFileClip, VideoClip
|
||||
from PIL import Image
|
||||
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseVideoGen
|
||||
|
||||
|
||||
class MoviePyVideoGen(BaseVideoGen):
    """Render a single shot clip from a still image with MoviePy.

    Applies a subtle zoom (up to +3% over the clip's duration) via a frame
    callback, and optionally muxes an audio track when `prompt["audio_path"]`
    points at an existing file.
    """

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg

    def generate(self, image_path: str, prompt: dict, output_path: str | Path) -> str:
        """Write an H.264 MP4 for one shot and return its path.

        Args:
            image_path: Source still image for the shot.
            prompt: Render parameters: ``duration_s`` (default 3), ``fps``
                (falls back to config ``video.mock_fps``, default 24),
                optional ``audio_path``, optional ``size`` as a 2-sequence
                ``(w, h)`` (falls back to config ``video.mock_size``).
            output_path: Target video file; parent directories are created.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Required prompt fields for shot rendering.
        duration_s = float(prompt.get("duration_s", 3))
        fps = int(prompt.get("fps", self.cfg.get("video.mock_fps", 24)))
        audio_path = prompt.get("audio_path")

        # Clip resolution.
        size = prompt.get("size")
        if isinstance(size, (list, tuple)) and len(size) == 2:
            w, h = int(size[0]), int(size[1])
        else:
            mock_size = self.cfg.get("video.mock_size", [1024, 576])
            w, h = int(mock_size[0]), int(mock_size[1])

        # Decoded once and reused by every frame callback invocation.
        base_img = Image.open(image_path).convert("RGB")

        def make_frame(t: float):
            # Linear zoom from 1.0 to 1.03 across the clip, then a centered
            # crop brings the frame back to exactly (w, h).
            progress = float(t) / max(duration_s, 1e-6)  # guard zero duration
            progress = max(0.0, min(1.0, progress))
            scale = 1.0 + 0.03 * progress
            new_w = max(w, int(w * scale))
            new_h = max(h, int(h * scale))
            frame = base_img.resize((new_w, new_h), Image.LANCZOS)
            left = (new_w - w) // 2
            top = (new_h - h) // 2
            frame = frame.crop((left, top, left + w, top + h))
            return np.array(frame)

        video = VideoClip(make_frame, duration=duration_s, has_constant_size=True)

        # Optional audio.
        if audio_path and os.path.exists(str(audio_path)):
            a = AudioFileClip(str(audio_path))
            video = video.with_audio(a)
        else:
            a = None

        try:
            video.write_videofile(
                str(output_path),
                fps=fps,
                codec="libx264",
                audio_codec="aac",
                preset="veryfast",
                threads=2,
            )
        finally:
            # Best-effort cleanup of ffmpeg reader/writer handles; close
            # failures must not mask a write_videofile error.
            try:
                video.close()
            except Exception:
                pass
            if a is not None:
                try:
                    a.close()
                except Exception:
                    pass

        return str(output_path)
|
||||
|
||||
Reference in New Issue
Block a user