fix: 优化架构
This commit is contained in:
60
engine/adapters/image/replicate_adapter.py
Normal file
60
engine/adapters/image/replicate_adapter.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from engine.config import AppConfig
|
||||
|
||||
from .base import BaseImageGen
|
||||
|
||||
|
||||
class ReplicateAdapter(BaseImageGen):
|
||||
def __init__(self, cfg: AppConfig):
|
||||
self.cfg = cfg
|
||||
# Expected: image.replicate.model
|
||||
self.model = str(cfg.get("image.replicate.model", cfg.get("image.model", ""))).strip()
|
||||
if not self.model:
|
||||
raise ValueError("ReplicateAdapter requires `image.replicate.model` (or `image.model`).")
|
||||
|
||||
# Import lazily so that environments without replicate installed can still run with mock/comfy.
|
||||
import replicate # type: ignore
|
||||
|
||||
self.replicate = replicate
|
||||
|
||||
def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
input_payload: dict[str, Any] = {
|
||||
"prompt": prompt.get("positive", ""),
|
||||
"negative_prompt": prompt.get("negative", ""),
|
||||
}
|
||||
|
||||
# replicate.run is synchronous when wait is handled by the SDK version.
|
||||
output = self.replicate.run(self.model, input=input_payload)
|
||||
|
||||
# Common shapes: [url, ...] or dict-like.
|
||||
image_url = None
|
||||
if isinstance(output, list) and output:
|
||||
image_url = output[0]
|
||||
elif isinstance(output, dict):
|
||||
image_url = output.get("image") or output.get("output") or output.get("url")
|
||||
if not isinstance(image_url, str) or not image_url:
|
||||
raise RuntimeError(f"Unexpected Replicate output shape: {type(output)}")
|
||||
|
||||
r = requests.get(image_url, timeout=60)
|
||||
r.raise_for_status()
|
||||
|
||||
# Always output PNG to satisfy downstream validation `outputs/{task_id}/*.png`.
|
||||
out_path = output_dir / f"shot_{uuid.uuid4().hex}.png"
|
||||
# Pillow doesn't provide open_bytes; wrap content into a buffer.
|
||||
from io import BytesIO
|
||||
|
||||
img = Image.open(BytesIO(r.content)).convert("RGB")
|
||||
img.save(str(out_path), format="PNG")
|
||||
return str(out_path)
|
||||
|
||||
Reference in New Issue
Block a user