fix: 优化架构

This commit is contained in:
Daniel
2026-03-25 19:35:37 +08:00
parent 34786b37c7
commit 508c28ce31
184 changed files with 2199 additions and 241 deletions

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,9 @@
from __future__ import annotations
from pathlib import Path
class BaseImageGen:
    """Common interface shared by the image-generation backends."""

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Generate an image for ``prompt`` inside ``output_dir``.

        The base class carries no implementation: calling it always raises
        ``NotImplementedError``. Concrete backends override this method.
        """
        raise NotImplementedError()

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
from engine.comfy_client import generate_image as comfy_generate_image
from engine.config import AppConfig
from .base import BaseImageGen
from .mock_adapter import MockImageGen
class ComfyAdapter(BaseImageGen):
    """Image backend that delegates to ``engine.comfy_client.generate_image``."""

    def __init__(self, cfg: AppConfig):
        """Store the app config and create a ``MockImageGen`` as ``fallback``."""
        self.cfg = cfg
        # Exposed as an attribute only; this class never calls it itself.
        self.fallback = MockImageGen()

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Render one image through ComfyUI and return the result as a string.

        Args:
            prompt: Mapping read for the ``"positive"`` and ``"negative"`` keys;
                missing or falsy values are treated as empty strings.
            output_dir: Directory handed to the ComfyUI client for output.

        Returns:
            ``str()`` of whatever ``comfy_generate_image`` returns.

        Raises:
            Any exception raised by ``comfy_generate_image`` propagates
            unchanged.
        """
        positive = str(prompt.get("positive", "") or "")
        negative = str(prompt.get("negative", "") or "")
        # Deliberately no try/except: the previous catch-and-re-raise was a
        # no-op. Errors bubble up so render_pipeline can do its configured
        # fallback.
        return str(
            comfy_generate_image(
                positive,
                output_dir,
                # An empty negative prompt is passed as None rather than "".
                negative_text=negative or None,
                cfg=self.cfg,
                timeout_s=60,
                retry=2,
                filename_prefix="shot",
            )
        )

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import os
import uuid
from pathlib import Path
from urllib.request import urlopen
from PIL import Image
from .base import BaseImageGen
ASSETS_DIR = "assets"
DEMO_IMAGE = os.path.join(ASSETS_DIR, "demo.jpg")


def ensure_demo_image(url: str = "https://picsum.photos/1280/720") -> None:
    """Make sure the demo image exists at ``DEMO_IMAGE``, fetching it once.

    Creates ``ASSETS_DIR`` if needed. When ``DEMO_IMAGE`` already exists the
    function returns immediately without touching the network.

    Args:
        url: Location to download the image from when it is not cached yet.
            Defaults to the picsum.photos placeholder used previously.

    Raises:
        Whatever ``urlopen`` or the file operations raise (for example
        ``OSError`` / ``urllib.error.URLError``); nothing is swallowed.
    """
    os.makedirs(ASSETS_DIR, exist_ok=True)
    if os.path.exists(DEMO_IMAGE):
        return

    with urlopen(url, timeout=30) as resp:
        data = resp.read()

    # Write to a uniquely named sibling file and rename it into place. An
    # interrupted write must never leave a truncated DEMO_IMAGE behind,
    # because the existence check above would then treat it as valid forever.
    tmp_path = f"{DEMO_IMAGE}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(data)
        os.replace(tmp_path, DEMO_IMAGE)
    finally:
        # Still present only if the write or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
class MockImageGen(BaseImageGen):
    """Stand-in backend that emits a copy of the demo image for every request."""

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Write the demo image into ``output_dir`` and return its path.

        Args:
            prompt: Ignored; accepted only to match the backend interface.
            output_dir: Directory for the output file (created if missing).

        Returns:
            Path of the written ``shot_<uuid>.png`` file, as a string.

        Raises:
            Errors from ``ensure_demo_image`` or from the filesystem
            operations propagate; only the PNG conversion step is best-effort.
        """
        # prompt is accepted for interface consistency; mock uses only demo.jpg.
        _ = prompt
        ensure_demo_image()
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"shot_{uuid.uuid4().hex}.png"
        try:
            # Convert to PNG so verification criteria can match *.png.
            # The context manager releases the source file handle promptly.
            with Image.open(DEMO_IMAGE) as src:
                src.convert("RGB").save(str(out_path), format="PNG")
        except Exception:
            # Last resort: if PNG conversion fails, still write a best-effort
            # raw copy. The file keeps the .png name but holds the demo
            # file's original bytes, so record the failure instead of hiding it.
            import logging

            logging.getLogger(__name__).warning(
                "PNG conversion of %s failed; writing raw copy to %s",
                DEMO_IMAGE,
                out_path,
                exc_info=True,
            )
            out_path.write_bytes(Path(DEMO_IMAGE).read_bytes())
        return str(out_path)

View File

@@ -0,0 +1,83 @@
from __future__ import annotations
import os
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any
import requests
from PIL import Image
from engine.config import AppConfig
from .base import BaseImageGen
class OpenAIImageAdapter(BaseImageGen):
    """
    Optional image provider adapter using OpenAI Images API (or OpenAI-compatible gateways).

    Requires the `openai` python package and an API key, supplied either through
    an environment variable or as a literal `sk-...` value in the config.
    """

    _DEFAULT_BASE_URL = "https://api.openai.com/v1"

    def __init__(self, cfg: AppConfig):
        """Read model, API key and base URL from ``cfg`` and build the client.

        Raises:
            ValueError: No model is configured.
            RuntimeError: No API key could be resolved.
            ImportError: The ``openai`` package is not installed.
        """
        self.cfg = cfg
        # Expected keys (configurable):
        # - image.openai.model (falls back to image.model)
        # - openai.api_key_env / openai.base_url_env (reuses existing engine/script_gen config fields)
        # `or` chaining keeps a null/blank config value from becoming the
        # literal model name "None".
        model = cfg.get("image.openai.model", "") or cfg.get("image.model", "") or ""
        self.model = str(model).strip()
        if not self.model:
            raise ValueError("OpenAIImageAdapter requires `image.openai.model` (or `image.model`).")

        api_key_env_or_literal = str(cfg.get("openai.api_key_env", "OPENAI_API_KEY") or "OPENAI_API_KEY").strip()
        # Support both:
        # - env var name (e.g. OPENAI_API_KEY)
        # - literal API key (e.g. starts with `sk-...`) for quick local POCs.
        if api_key_env_or_literal.startswith("sk-"):
            api_key = api_key_env_or_literal
        else:
            api_key = os.environ.get(api_key_env_or_literal)
        if not api_key:
            raise RuntimeError(f"OpenAIImageAdapter missing API key: `{api_key_env_or_literal}`")
        self.api_key = api_key

        # Same dual convention as the API key: a value containing a scheme is
        # a literal URL, anything else is looked up as an env var name.
        base_url_env_or_literal = str(cfg.get("openai.base_url_env", self._DEFAULT_BASE_URL) or "").strip()
        if "://" in base_url_env_or_literal:
            base_url = base_url_env_or_literal
        elif base_url_env_or_literal:
            base_url = (os.environ.get(base_url_env_or_literal) or "").strip()
        else:
            base_url = ""
        self.base_url = base_url.rstrip("/") or self._DEFAULT_BASE_URL

        # Lazy import to avoid hard dependency for mock/comfy users.
        from openai import OpenAI  # type: ignore

        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Generate one image and save it as PNG under ``output_dir``.

        Args:
            prompt: Mapping read for ``"positive"`` and ``"negative"``; missing
                or falsy values are treated as empty strings.
            output_dir: Destination directory (created if missing).

        Returns:
            Path of the written ``shot_<uuid>.png`` file, as a string.

        Raises:
            RuntimeError: The response carries neither a URL nor base64 data.
            requests.HTTPError: Downloading the image URL failed.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        positive = str(prompt.get("positive", "") or "")
        negative = str(prompt.get("negative", "") or "")
        # OpenAI Images API generally doesn't expose a dedicated negative_prompt field.
        # To keep interface consistency, embed negative hints into the prompt text.
        prompt_text = f"{positive}\nNegative prompt: {negative}" if negative else positive

        result = self.client.images.generate(model=self.model, prompt=prompt_text)
        image_bytes = self._read_image_bytes(result)

        out_path = output_dir / f"shot_{uuid.uuid4().hex}.png"
        with Image.open(BytesIO(image_bytes)) as src:
            src.convert("RGB").save(str(out_path), format="PNG")
        return str(out_path)

    @staticmethod
    def _read_image_bytes(result: Any) -> bytes:
        """Return the raw bytes of the first image in an Images API response."""
        data = getattr(result, "data", None)
        try:
            first = data[0] if data else None
        except (IndexError, KeyError, TypeError):
            first = None

        url = getattr(first, "url", None)
        if url:
            r = requests.get(url, timeout=60)
            r.raise_for_status()
            return r.content

        # Some models/gateways return inline base64 instead of a hosted URL.
        b64_payload = getattr(first, "b64_json", None)
        if b64_payload:
            import base64

            return base64.b64decode(b64_payload)

        raise RuntimeError("OpenAIImageAdapter unexpected response: missing image url and b64_json")

View File

@@ -0,0 +1,60 @@
from __future__ import annotations
import uuid
from pathlib import Path
from typing import Any
import requests
from PIL import Image
from engine.config import AppConfig
from .base import BaseImageGen
class ReplicateAdapter(BaseImageGen):
    """Image backend that runs a model through the ``replicate`` package."""

    def __init__(self, cfg: AppConfig):
        """Resolve the model name from ``cfg`` and import ``replicate``.

        Raises:
            ValueError: No model is configured.
            ImportError: The ``replicate`` package is not installed.
        """
        self.cfg = cfg
        # Expected: image.replicate.model (falls back to image.model).
        # `or` chaining keeps a null/blank config value from becoming the
        # literal model name "None".
        model = cfg.get("image.replicate.model", "") or cfg.get("image.model", "") or ""
        self.model = str(model).strip()
        if not self.model:
            raise ValueError("ReplicateAdapter requires `image.replicate.model` (or `image.model`).")
        # Import lazily so that environments without replicate installed can still run with mock/comfy.
        import replicate  # type: ignore

        self.replicate = replicate

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Run the model, download the first image and save it as PNG.

        Args:
            prompt: Mapping read for ``"positive"`` and ``"negative"``; missing
                or falsy values are sent as empty strings.
            output_dir: Destination directory (created if missing).

        Returns:
            Path of the written ``shot_<uuid>.png`` file, as a string.

        Raises:
            RuntimeError: No image URL could be found in the model output.
            requests.HTTPError: Downloading the image failed.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        input_payload: dict[str, Any] = {
            "prompt": str(prompt.get("positive", "") or ""),
            "negative_prompt": str(prompt.get("negative", "") or ""),
        }
        # replicate.run is synchronous when wait is handled by the SDK version.
        output = self.replicate.run(self.model, input=input_payload)

        image_url = self._first_image_url(output)
        if not image_url:
            raise RuntimeError(f"Unexpected Replicate output shape: {type(output)}")

        r = requests.get(image_url, timeout=60)
        r.raise_for_status()

        # Always output PNG to satisfy downstream validation `outputs/{task_id}/*.png`.
        out_path = output_dir / f"shot_{uuid.uuid4().hex}.png"
        # Pillow doesn't provide open_bytes; wrap content into a buffer.
        from io import BytesIO

        with Image.open(BytesIO(r.content)) as src:
            src.convert("RGB").save(str(out_path), format="PNG")
        return str(out_path)

    @staticmethod
    def _first_image_url(output: Any) -> str | None:
        """Extract an image URL from the shapes ``replicate.run`` may return."""
        candidate = output
        if isinstance(output, (list, tuple)):
            candidate = output[0] if output else None
        elif isinstance(output, dict):
            candidate = output.get("image") or output.get("output") or output.get("url")

        if isinstance(candidate, str):
            return candidate or None
        # NOTE(review): newer replicate SDK versions appear to return file-like
        # output objects exposing the link as `.url` - confirm against the
        # installed version.
        url = getattr(candidate, "url", None)
        return url if isinstance(url, str) and url else None

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from pathlib import Path
from engine.config import AppConfig
from .base import BaseImageGen
class StabilityAdapter(BaseImageGen):
    """
    Stub backend reserved for Stability AI image generation.

    Only the configuration is stored; the implementation and its
    dependencies are to be added when needed.
    """

    def __init__(self, cfg: AppConfig):
        """Keep a reference to the application config."""
        self.cfg = cfg

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Always raise ``NotImplementedError``; no image is produced."""
        raise NotImplementedError("StabilityAdapter not implemented yet")