from __future__ import annotations import uuid from pathlib import Path from typing import Any import requests from PIL import Image from engine.config import AppConfig from .base import BaseImageGen class ReplicateAdapter(BaseImageGen): def __init__(self, cfg: AppConfig): self.cfg = cfg # Expected: image.replicate.model self.model = str(cfg.get("image.replicate.model", cfg.get("image.model", ""))).strip() if not self.model: raise ValueError("ReplicateAdapter requires `image.replicate.model` (or `image.model`).") # Import lazily so that environments without replicate installed can still run with mock/comfy. import replicate # type: ignore self.replicate = replicate def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str: output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) input_payload: dict[str, Any] = { "prompt": prompt.get("positive", ""), "negative_prompt": prompt.get("negative", ""), } # replicate.run is synchronous when wait is handled by the SDK version. output = self.replicate.run(self.model, input=input_payload) # Common shapes: [url, ...] or dict-like. image_url = None if isinstance(output, list) and output: image_url = output[0] elif isinstance(output, dict): image_url = output.get("image") or output.get("output") or output.get("url") if not isinstance(image_url, str) or not image_url: raise RuntimeError(f"Unexpected Replicate output shape: {type(output)}") r = requests.get(image_url, timeout=60) r.raise_for_status() # Always output PNG to satisfy downstream validation `outputs/{task_id}/*.png`. out_path = output_dir / f"shot_{uuid.uuid4().hex}.png" # Pillow doesn't provide open_bytes; wrap content into a buffer. from io import BytesIO img = Image.open(BytesIO(r.content)).convert("RGB") img.save(str(out_path), format="PNG") return str(out_path)