Files
AiVideo/engine/adapters/image/openai_image_adapter.py
2026-03-25 19:35:37 +08:00

84 lines
3.1 KiB
Python

from __future__ import annotations

import base64
import os
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any

import requests
from PIL import Image

from engine.config import AppConfig
from .base import BaseImageGen


class OpenAIImageAdapter(BaseImageGen):
    """
    Optional image provider adapter using OpenAI Images API (or OpenAI-compatible gateways).

    Requires the `openai` python package (imported lazily in __init__) and an API key,
    given either as the name of an environment variable or, for quick local POCs, as a
    literal `sk-...` key in the `openai.api_key_env` config field.
    """

    # Fallbacks used when the corresponding config key is missing, None, or empty.
    _DEFAULT_API_KEY_ENV = "OPENAI_API_KEY"
    _DEFAULT_BASE_URL = "https://api.openai.com/v1"
    # Seconds to wait when downloading a URL-style image result.
    _DOWNLOAD_TIMEOUT_S = 60

    def __init__(self, cfg: AppConfig):
        """Resolve model, API key and base URL from `cfg` and create the OpenAI client.

        Config keys (reuses the existing engine/script_gen `openai.*` fields):
          - image.openai.model, falling back to image.model: required model name.
          - openai.api_key_env: env var name holding the key, or a literal `sk-...` key.
            Defaults to OPENAI_API_KEY.
          - openai.base_url_env: a literal base URL, or the name of an env var that
            holds one. Defaults to https://api.openai.com/v1.

        Raises:
            ValueError: no model is configured.
            RuntimeError: the API key cannot be resolved.
        """
        self.cfg = cfg

        # `or` chains also cover keys that exist but are None/empty; plain str() would
        # otherwise turn None into the bogus value "None".
        model = cfg.get("image.openai.model", "") or cfg.get("image.model", "") or ""
        self.model = str(model).strip()
        if not self.model:
            raise ValueError("OpenAIImageAdapter requires `image.openai.model` (or `image.model`).")

        api_key_env_or_literal = str(
            cfg.get("openai.api_key_env", self._DEFAULT_API_KEY_ENV) or self._DEFAULT_API_KEY_ENV
        ).strip()
        # Support both:
        # - env var name (e.g. OPENAI_API_KEY)
        # - literal API key (e.g. starts with `sk-...`) for quick local POCs.
        if api_key_env_or_literal.startswith("sk-"):
            api_key = api_key_env_or_literal
        else:
            api_key = os.environ.get(api_key_env_or_literal)
        if not api_key:
            raise RuntimeError(f"OpenAIImageAdapter missing API key: `{api_key_env_or_literal}`")
        self.api_key = api_key

        base_url_env_or_literal = str(
            cfg.get("openai.base_url_env", self._DEFAULT_BASE_URL) or ""
        ).strip()
        if base_url_env_or_literal and "://" not in base_url_env_or_literal:
            # Not a URL: treat it as an env var name. If that variable is unset, keep the
            # raw value so previously-accepted configs behave exactly as before.
            base_url_env_or_literal = (
                os.environ.get(base_url_env_or_literal, "").strip() or base_url_env_or_literal
            )
        self.base_url = base_url_env_or_literal.rstrip("/") or self._DEFAULT_BASE_URL

        # Lazy import to avoid hard dependency for mock/comfy users.
        from openai import OpenAI  # type: ignore

        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def generate(self, prompt: dict[str, str], output_dir: str | Path) -> str:
        """Generate one image for `prompt` and save it as an RGB PNG in `output_dir`.

        Args:
            prompt: mapping with optional "positive" and "negative" prompt texts.
            output_dir: destination directory; created (with parents) if missing.

        Returns:
            The written file path, as a string, named `shot_<uuid hex>.png`.

        Raises:
            RuntimeError: the API response carried neither an image URL nor base64 data.
            requests.HTTPError: downloading a URL-style result failed.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        prompt_text = self._compose_prompt_text(prompt)
        result = self.client.images.generate(model=self.model, prompt=prompt_text)
        image_bytes = self._extract_image_bytes(result)

        out_path = output_dir / f"shot_{uuid.uuid4().hex}.png"
        # Re-encode as RGB PNG so callers always get the same on-disk format; the
        # context manager releases the decoder's resources once saved.
        with Image.open(BytesIO(image_bytes)) as img:
            img.convert("RGB").save(str(out_path), format="PNG")
        return str(out_path)

    @staticmethod
    def _compose_prompt_text(prompt: dict[str, str]) -> str:
        """Build the single prompt string sent to the API.

        The Images API generally doesn't expose a dedicated negative_prompt field, so to
        keep the adapter interface consistent the negative hints are appended as text.
        """
        positive = prompt.get("positive", "")
        negative = prompt.get("negative", "")
        if negative:
            return f"{positive}\nNegative prompt: {negative}"
        return positive

    def _extract_image_bytes(self, result: Any) -> bytes:
        """Return the raw bytes of the first image in an Images API response.

        Prefers `data[0].url` (downloaded over HTTP); falls back to `data[0].b64_json`,
        which some models/gateways return instead of a URL.

        Raises:
            RuntimeError: the response has no first item, or that item has neither field.
            requests.HTTPError: the URL download returned an error status.
        """
        try:
            item = result.data[0]
        except (AttributeError, IndexError, TypeError) as exc:
            raise RuntimeError("OpenAIImageAdapter unexpected response: missing image url") from exc

        url: str | None = getattr(item, "url", None)
        if url:
            resp = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT_S)
            resp.raise_for_status()
            return resp.content

        b64_data: str | None = getattr(item, "b64_json", None)
        if b64_data:
            return base64.b64decode(b64_data)

        raise RuntimeError("OpenAIImageAdapter unexpected response: missing image url")