84 lines
3.1 KiB
Python
84 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import uuid
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import requests
|
|
from PIL import Image
|
|
|
|
from engine.config import AppConfig
|
|
|
|
from .base import BaseImageGen
|
|
|
|
|
|
class OpenAIImageAdapter(BaseImageGen):
    """
    Optional image provider adapter using the OpenAI Images API (or an OpenAI-compatible gateway).

    Requires the `openai` python package (imported lazily in ``__init__``) and an API key,
    supplied either through an environment variable or, for quick local POCs, as a literal
    ``sk-...`` value in config.
    """

    _DEFAULT_BASE_URL = "https://api.openai.com/v1"

    def __init__(self, cfg: AppConfig):
        """
        Resolve model, credentials and endpoint from config and build the OpenAI client.

        Config keys read (reusing the existing engine/script_gen ``openai.*`` fields):
            image.openai.model: model name (falls back to ``image.model``); required.
            openai.api_key_env: name of the env var holding the API key (default
                ``OPENAI_API_KEY``), or a literal key starting with ``sk-``.
            openai.base_url_env: a literal ``http(s)://`` base URL, or the name of an env
                var holding one; falls back to the public OpenAI endpoint when unset/empty.

        Raises:
            ValueError: if no model is configured.
            RuntimeError: if the API key cannot be resolved.
            ImportError: if the ``openai`` package is not installed.
        """
        self.cfg = cfg

        # A None/empty value counts as "not configured"; plain str() would turn None
        # into the bogus name "None" and bypass the emptiness check below.
        self.model = self._cfg_str(cfg, "image.openai.model") or self._cfg_str(cfg, "image.model")
        if not self.model:
            raise ValueError("OpenAIImageAdapter requires `image.openai.model` (or `image.model`).")

        key_ref = self._cfg_str(cfg, "openai.api_key_env") or "OPENAI_API_KEY"
        self.api_key = self._resolve_api_key(key_ref)
        self.base_url = self._resolve_base_url(self._cfg_str(cfg, "openai.base_url_env"))

        # Lazy import to avoid a hard dependency for mock/comfy users.
        try:
            from openai import OpenAI  # type: ignore
        except ImportError as err:
            raise ImportError(
                "OpenAIImageAdapter requires the `openai` package (pip install openai)."
            ) from err

        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    @staticmethod
    def _cfg_str(cfg: AppConfig, key: str) -> str:
        """Return config value ``key`` as a stripped string; ``None`` (or the ``""`` default) gives ``""``."""
        value = cfg.get(key, "")
        return "" if value is None else str(value).strip()

    @staticmethod
    def _resolve_api_key(key_ref: str) -> str:
        """
        Return the API key referenced by ``key_ref``.

        ``key_ref`` is either a literal key (starts with ``sk-``) or the name of an
        environment variable holding the key.

        Raises:
            RuntimeError: if the named environment variable is unset or empty.
        """
        if key_ref.startswith("sk-"):
            return key_ref
        api_key = (os.environ.get(key_ref) or "").strip()
        if not api_key:
            raise RuntimeError(f"OpenAIImageAdapter missing API key: `{key_ref}`")
        return api_key

    @classmethod
    def _resolve_base_url(cls, url_ref: str) -> str:
        """
        Return the API base URL (without trailing slash) referenced by ``url_ref``.

        ``url_ref`` may be a literal ``http(s)://`` URL or the name of an environment
        variable holding one, mirroring how ``openai.api_key_env`` accepts either form.
        Empty input, or an unset/empty environment variable, yields ``_DEFAULT_BASE_URL``.
        """
        if url_ref.lower().startswith(("http://", "https://")):
            url = url_ref
        elif url_ref:
            url = (os.environ.get(url_ref) or "").strip()
        else:
            url = ""
        return url.rstrip("/") or cls._DEFAULT_BASE_URL

    @staticmethod
    def _fetch_image_bytes(result: Any) -> bytes:
        """
        Extract raw image bytes from an Images API ``generate`` response.

        Prefers an inline base64 payload (``data[0].b64_json``); otherwise downloads the
        hosted image at ``data[0].url``.

        Raises:
            RuntimeError: if the response has no image entry, or the entry has neither field.
            requests.HTTPError: if downloading the hosted image fails.
        """
        try:
            item = result.data[0]
        except (AttributeError, IndexError, TypeError) as err:
            raise RuntimeError("OpenAIImageAdapter unexpected response: no image data") from err

        # Some models return only base64 data and never a URL.
        b64_payload = getattr(item, "b64_json", None)
        if b64_payload:
            import base64

            return base64.b64decode(b64_payload)

        url = getattr(item, "url", None)
        if not url:
            raise RuntimeError("OpenAIImageAdapter unexpected response: missing image url or b64_json")

        response = requests.get(url, timeout=60)
        response.raise_for_status()
        return response.content

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """
        Generate one image for ``prompt`` and save it as a PNG under ``output_dir``.

        Args:
            prompt: mapping with optional ``positive`` and ``negative`` prompt text.
            output_dir: directory to write into; created (with parents) if missing.

        Returns:
            Path, as ``str``, of the written ``shot_<uuid4 hex>.png`` file.

        Raises:
            RuntimeError: if the API response carries no usable image.
            requests.HTTPError: if downloading a URL-hosted image fails.
            Errors raised by the OpenAI client itself propagate unchanged.
        """
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        positive = prompt.get("positive", "")
        negative = prompt.get("negative", "")
        # The Images API has no dedicated negative_prompt field; to keep the interface
        # consistent with other providers, embed negative hints in the prompt text.
        prompt_text = f"{positive}\nNegative prompt: {negative}" if negative else positive

        result = self.client.images.generate(model=self.model, prompt=prompt_text)
        image_bytes = self._fetch_image_bytes(result)

        out_path = out_dir / f"shot_{uuid.uuid4().hex}.png"
        # Close the decoded source image once it has been converted and saved.
        with Image.open(BytesIO(image_bytes)) as img:
            img.convert("RGB").save(str(out_path), format="PNG")
        return str(out_path)
|
|
|