402 lines
15 KiB
Python
402 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Iterable
|
|
|
|
import httpx
|
|
|
|
from .config import AppConfig
|
|
|
|
|
|
@dataclass(frozen=True)
class ComfyResult:
    """Immutable outcome of a ComfyUI job run through ``ComfyClient.run_workflow``."""

    # Identifier the ComfyUI ``/prompt`` endpoint assigned to the queued job.
    prompt_id: str
    # Local paths derived from the job's history outputs.
    output_files: list[Path]
|
|
|
|
|
|
class ComfyClient:
    """Async client that submits API-format workflows to a ComfyUI server.

    Settings are read from ``cfg`` through ``cfg.get(dotted_key, default)``:
    ``app.comfy_base_url``, ``app.comfy_output_dir`` and the
    ``comfy_workflow.*`` keys (workflow file path, node ids, input keys and
    fallback class-type lists).
    """

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        self.base_url = str(cfg.get("app.comfy_base_url", "http://127.0.0.1:8188")).rstrip("/")
        # Directory that history file entries are resolved against when checking
        # whether a job has finished. NOTE(review): presumably ComfyUI's own
        # output dir on a shared filesystem -- confirm for remote deployments.
        self.output_dir = Path(str(cfg.get("app.comfy_output_dir", "./ComfyUI/output")))
        self.workflow_path = Path(str(cfg.get("comfy_workflow.workflow_path", "./workflow_api.json")))
        # Generated once per instance and sent with every /prompt submission.
        self._client_id = str(uuid.uuid4())

    def load_workflow(self) -> dict[str, Any]:
        """Read and parse the API-format workflow JSON at ``self.workflow_path``.

        Raises:
            FileNotFoundError: the workflow file does not exist.
            ValueError: the JSON root is not an object (malformed JSON raises
                ``json.JSONDecodeError``, itself a ``ValueError``).
        """
        if not self.workflow_path.exists():
            raise FileNotFoundError(f"workflow file not found: {self.workflow_path}")
        raw = json.loads(self.workflow_path.read_text(encoding="utf-8"))
        if not isinstance(raw, dict):
            raise ValueError(f"workflow_api.json root must be dict, got {type(raw)}")
        return raw

    def _nodes(self, workflow: dict[str, Any]) -> dict[str, Any]:
        """Return the node mapping of ``workflow`` (currently the dict itself)."""
        # ComfyUI API workflow exports typically use { node_id: {class_type, inputs, ...}, ... }
        return workflow

    def _find_node_id_by_class_type(self, workflow: dict[str, Any], class_types: Iterable[str]) -> str | None:
        """Return the id of the first node whose ``class_type`` is in ``class_types``.

        Entries of ``class_types`` are stringified and stripped; falsy or blank
        entries are ignored. Returns None when no usable class type is given
        or no node matches. Non-dict nodes are skipped.
        """
        want = {str(c).strip() for c in class_types if c and str(c).strip()}
        if not want:
            return None
        for node_id, node in self._nodes(workflow).items():
            if not isinstance(node, dict):
                continue
            ct = node.get("class_type")
            # isinstance guard also keeps unhashable values out of the set lookup.
            if isinstance(ct, str) and ct in want:
                return str(node_id)
        return None

    def _resolve_node_id(self, workflow: dict[str, Any], configured_id: Any, fallback_class_types_key: str) -> str:
        """Pick a node id: the configured one if set, else the first class-type match.

        Args:
            workflow: API-format workflow to search.
            configured_id: explicit node id from config; None/blank means "not set".
            fallback_class_types_key: name of the ``comfy_workflow.*`` config key
                holding the list of class types to search for.

        Raises:
            KeyError: the configured id is absent, or no node matches the class types.
            ValueError: the fallback config value is not a list.
        """
        if configured_id is not None and str(configured_id).strip():
            node_id = str(configured_id).strip()
            if node_id not in self._nodes(workflow):
                raise KeyError(f"Configured node_id {node_id} not found in workflow")
            return node_id

        class_types = self.cfg.get(f"comfy_workflow.{fallback_class_types_key}", []) or []
        if not isinstance(class_types, list):
            raise ValueError(f"Config comfy_workflow.{fallback_class_types_key} must be list")
        found = self._find_node_id_by_class_type(workflow, [str(x) for x in class_types])
        if not found:
            raise KeyError(f"Cannot resolve node by class types: {class_types}")
        return found

    def inject_params(self, workflow: dict[str, Any], image_prompt: str, seed: int, motion_prompt: str | None = None) -> dict[str, Any]:
        """Return a copy of ``workflow`` with prompt, seed and optional motion prompt set.

        The prompt and seed nodes are resolved via ``comfy_workflow.prompt_node_id``
        / ``seed_node_id`` (or their ``*_class_types`` fallbacks); input keys come
        from ``comfy_workflow.prompt_input_key`` (default "text") and
        ``seed_input_key`` (default "seed"). The motion prompt is written only
        when ``motion_prompt`` is truthy and ``comfy_workflow.motion_node_id`` is set.
        The input ``workflow`` is never mutated.

        Raises:
            KeyError / ValueError / TypeError: propagated from node resolution
                and ``_set_input``.
        """
        # JSON round-trip deep copy (also normalizes the structure to plain JSON types).
        wf = json.loads(json.dumps(workflow))

        prompt_node_id = self._resolve_node_id(
            wf,
            self.cfg.get("comfy_workflow.prompt_node_id", None),
            "prompt_node_class_types",
        )
        prompt_key = str(self.cfg.get("comfy_workflow.prompt_input_key", "text"))
        self._set_input(wf, prompt_node_id, prompt_key, image_prompt)

        seed_node_id = self._resolve_node_id(
            wf,
            self.cfg.get("comfy_workflow.seed_node_id", None),
            "seed_node_class_types",
        )
        seed_key = str(self.cfg.get("comfy_workflow.seed_input_key", "seed"))
        self._set_input(wf, seed_node_id, seed_key, int(seed))

        motion_node_id = self.cfg.get("comfy_workflow.motion_node_id", None)
        if motion_prompt and motion_node_id is not None and str(motion_node_id).strip():
            motion_key = str(self.cfg.get("comfy_workflow.motion_input_key", "text"))
            self._set_input(wf, str(motion_node_id).strip(), motion_key, motion_prompt)

        return wf

    def _set_input(self, workflow: dict[str, Any], node_id: str, key: str, value: Any) -> None:
        """Set ``workflow[node_id]["inputs"][key] = value`` in place.

        A missing/None ``inputs`` entry is created as an empty dict.

        Raises:
            KeyError: the node is missing or not a dict.
            TypeError: the node's ``inputs`` exists but is not a dict.
        """
        node = self._nodes(workflow).get(str(node_id))
        if not isinstance(node, dict):
            raise KeyError(f"Node {node_id} not found")
        inputs = node.get("inputs")
        if inputs is None:
            inputs = {}
            node["inputs"] = inputs
        if not isinstance(inputs, dict):
            raise TypeError(f"Node {node_id} inputs must be dict, got {type(inputs)}")
        inputs[key] = value

    async def _post_prompt(self, client: httpx.AsyncClient, workflow: dict[str, Any]) -> str:
        """Queue ``workflow`` via ``POST {base_url}/prompt`` and return its prompt id.

        Raises:
            httpx.HTTPStatusError: the server answered with a non-2xx status.
            RuntimeError: the body is not a JSON object carrying a non-empty
                string ``prompt_id`` (or ``PROMPT_ID``).
        """
        url = f"{self.base_url}/prompt"
        payload = {"prompt": workflow, "client_id": self._client_id}
        r = await client.post(url, json=payload)
        r.raise_for_status()
        data = r.json()
        pid = None
        # Guard: a non-object body must surface as the RuntimeError below, not AttributeError.
        if isinstance(data, dict):
            pid = data.get("prompt_id") or data.get("PROMPT_ID")
        if not isinstance(pid, str) or not pid:
            raise RuntimeError(f"Unexpected /prompt response: {data}")
        return pid

    async def _get_history(self, client: httpx.AsyncClient, prompt_id: str) -> dict[str, Any] | None:
        """Fetch the history entry for ``prompt_id``, or None if not available.

        Tries ``/history/{prompt_id}`` first, then the full ``/history`` listing.
        A 404, a transport failure or an unparsable JSON body moves on to the
        next endpoint. From the first endpoint that answers with JSON: returns
        ``data[prompt_id]`` when it is a dict; otherwise the whole object for the
        per-prompt endpoint (possibly empty), or None for the listing.

        Raises:
            httpx.HTTPStatusError: a non-404 error status (deliberately not swallowed).
        """
        for url in (f"{self.base_url}/history/{prompt_id}", f"{self.base_url}/history"):
            try:
                r = await client.get(url)
            except httpx.RequestError:
                # Transient transport failure; the caller polls again anyway.
                continue
            if r.status_code == 404:
                continue
            r.raise_for_status()
            try:
                data = r.json()
            except ValueError:
                # Malformed body (JSONDecodeError / UnicodeDecodeError are ValueErrors).
                continue
            if not isinstance(data, dict):
                return None
            entry = data.get(prompt_id)
            if isinstance(entry, dict):
                return entry
            if url.endswith(f"/{prompt_id}"):
                return data
            return None
        return None

    def _extract_output_files(self, history_item: dict[str, Any]) -> list[Path]:
        """Collect local paths for every file entry under ``history_item["outputs"]``.

        Any nested dict with a non-blank string ``"filename"`` counts as a file
        entry and maps to ``output_dir / subfolder / filename`` (an empty or
        missing ``subfolder`` maps straight into ``output_dir``; the ``type``
        field is ignored). Order of first appearance is kept; duplicates dropped.
        Returns [] when ``outputs`` is missing or not a dict.
        """
        outputs = history_item.get("outputs")
        if not isinstance(outputs, dict):
            return []

        found: list[Path] = []
        seen: set[str] = set()

        def walk(v: Any) -> None:
            if isinstance(v, dict):
                # ComfyUI tends to store files like {"filename":"x.mp4","subfolder":"","type":"output"}
                fn = v.get("filename")
                if isinstance(fn, str) and fn.strip():
                    sub = v.get("subfolder")
                    # Joining "" is a no-op, so root-level files keep their old path.
                    p = self.output_dir / (sub if isinstance(sub, str) else "") / fn
                    key = str(p)
                    if key not in seen:
                        seen.add(key)
                        found.append(p)
                for vv in v.values():
                    walk(vv)
            elif isinstance(v, list):
                for vv in v:
                    walk(vv)

        walk(outputs)
        return found

    async def run_workflow(self, workflow: dict[str, Any], *, poll_interval_s: float = 1.0, timeout_s: float = 300.0) -> ComfyResult:
        """Queue ``workflow`` and poll its history until an output file exists locally.

        Completion heuristic: the job counts as done once any path from
        ``_extract_output_files`` exists on disk.

        Args:
            workflow: API-format workflow (e.g. from ``inject_params``).
            poll_interval_s: delay between history polls.
            timeout_s: overall deadline measured after the prompt is queued.

        Raises:
            TimeoutError: no output file appeared before the deadline.
            httpx.HTTPStatusError / RuntimeError: from submission or polling.
        """
        loop = asyncio.get_running_loop()
        async with httpx.AsyncClient(timeout=30.0) as client:
            prompt_id = await self._post_prompt(client, workflow)
            deadline = loop.time() + timeout_s
            while loop.time() <= deadline:
                item = await self._get_history(client, prompt_id)
                if isinstance(item, dict):
                    files = self._extract_output_files(item)
                    if any(p.exists() for p in files):
                        return ComfyResult(prompt_id=prompt_id, output_files=files)
                await asyncio.sleep(poll_interval_s)
            raise TimeoutError(f"ComfyUI job timeout: {prompt_id}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Minimal "text->image" helpers (used by shot rendering)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_simple_workflow(
|
|
prompt_text: str,
|
|
*,
|
|
seed: int,
|
|
ckpt_name: str,
|
|
width: int,
|
|
height: int,
|
|
steps: int = 20,
|
|
cfg: float = 8.0,
|
|
sampler_name: str = "euler",
|
|
scheduler: str = "normal",
|
|
denoise: float = 1.0,
|
|
filename_prefix: str = "shot",
|
|
negative_text: str = "low quality, blurry",
|
|
) -> dict[str, Any]:
|
|
# Best-effort workflow. If your ComfyUI nodes/models differ, generation must fallback.
|
|
return {
|
|
"3": {
|
|
"class_type": "KSampler",
|
|
"inputs": {
|
|
"seed": int(seed),
|
|
"steps": int(steps),
|
|
"cfg": float(cfg),
|
|
"sampler_name": sampler_name,
|
|
"scheduler": scheduler,
|
|
"denoise": float(denoise),
|
|
"model": ["4", 0],
|
|
"positive": ["6", 0],
|
|
"negative": ["7", 0],
|
|
"latent_image": ["5", 0],
|
|
},
|
|
},
|
|
"4": {
|
|
"class_type": "CheckpointLoaderSimple",
|
|
"inputs": {
|
|
"ckpt_name": ckpt_name,
|
|
},
|
|
},
|
|
"5": {
|
|
"class_type": "EmptyLatentImage",
|
|
"inputs": {
|
|
"width": int(width),
|
|
"height": int(height),
|
|
"batch_size": 1,
|
|
},
|
|
},
|
|
"6": {
|
|
"class_type": "CLIPTextEncode",
|
|
"inputs": {
|
|
"text": prompt_text,
|
|
"clip": ["4", 1],
|
|
},
|
|
},
|
|
"7": {
|
|
"class_type": "CLIPTextEncode",
|
|
"inputs": {
|
|
"text": negative_text,
|
|
"clip": ["4", 1],
|
|
},
|
|
},
|
|
"8": {
|
|
"class_type": "VAEDecode",
|
|
"inputs": {
|
|
"samples": ["3", 0],
|
|
"vae": ["4", 2],
|
|
},
|
|
},
|
|
"9": {
|
|
"class_type": "SaveImage",
|
|
"inputs": {
|
|
"images": ["8", 0],
|
|
"filename_prefix": filename_prefix,
|
|
},
|
|
},
|
|
}
|
|
|
|
|
|
def _queue_prompt(base_url: str, workflow: dict[str, Any], client_id: str) -> str:
    """Queue ``workflow`` via a synchronous ``POST {base_url}/prompt``.

    Args:
        base_url: ComfyUI server root; a trailing "/" is tolerated.
        workflow: API-format workflow graph.
        client_id: client identifier sent alongside the prompt.

    Returns:
        The ``prompt_id`` assigned by the server.

    Raises:
        httpx.HTTPStatusError: the server answered with a non-2xx status.
        RuntimeError: the body is not a JSON object with a non-empty string
            ``prompt_id``.
    """
    r = httpx.post(
        base_url.rstrip("/") + "/prompt",
        json={"prompt": workflow, "client_id": client_id},
        timeout=30.0,
    )
    r.raise_for_status()
    data = r.json()
    # Guard: a non-object body must surface as the RuntimeError below, not AttributeError.
    pid = data.get("prompt_id") if isinstance(data, dict) else None
    if not isinstance(pid, str) or not pid:
        raise RuntimeError(f"Unexpected /prompt response: {data}")
    return pid
|
|
|
|
|
|
def _get_history_item(base_url: str, prompt_id: str) -> dict[str, Any] | None:
    """Fetch the history entry for ``prompt_id`` synchronously (best-effort).

    Tries ``/history/{prompt_id}`` first, then the full ``/history`` listing.
    A 404, a failed request, an HTTP error status or an unparsable JSON body
    moves on to the next endpoint. From the first endpoint that answers with
    JSON: returns ``data[prompt_id]`` when it is a dict; otherwise the whole
    object for the per-prompt endpoint (possibly empty), or None for the
    listing. Returns None when neither endpoint produced anything usable.

    HTTP/JSON errors are swallowed on purpose: callers poll this in a loop,
    so a transient failure should read as "not ready yet".
    """
    root = base_url.rstrip("/")
    for url in (f"{root}/history/{prompt_id}", f"{root}/history"):
        try:
            r = httpx.get(url, timeout=30.0)
            if r.status_code == 404:
                continue
            r.raise_for_status()
            data = r.json()
        except (httpx.HTTPError, ValueError):
            # Transport failure, error status, or malformed JSON: try the next endpoint.
            continue
        if not isinstance(data, dict):
            return None
        entry = data.get(prompt_id)
        if isinstance(entry, dict):
            return entry
        if url.endswith(f"/{prompt_id}"):
            return data
        return None
    return None
|
|
|
|
|
|
def _extract_first_image_view_target(history_item: dict[str, Any]) -> tuple[str, str] | None:
|
|
outputs = history_item.get("outputs")
|
|
if not isinstance(outputs, dict):
|
|
return None
|
|
|
|
def walk(v: Any) -> list[dict[str, Any]]:
|
|
found: list[dict[str, Any]] = []
|
|
if isinstance(v, dict):
|
|
if isinstance(v.get("filename"), str) and v.get("filename").strip():
|
|
found.append(v)
|
|
for vv in v.values():
|
|
found.extend(walk(vv))
|
|
elif isinstance(v, list):
|
|
for vv in v:
|
|
found.extend(walk(vv))
|
|
return found
|
|
|
|
candidates = walk(outputs)
|
|
for c in candidates:
|
|
fn = str(c.get("filename", "")).strip()
|
|
sf = str(c.get("subfolder", "") or "").strip()
|
|
if fn:
|
|
return fn, sf
|
|
return None
|
|
|
|
|
|
def generate_image(
    prompt_text: str,
    output_dir: str | Path,
    *,
    cfg: AppConfig | None = None,
    timeout_s: int = 60,
    retry: int = 2,
    width: int | None = None,
    height: int | None = None,
    filename_prefix: str = "shot",
    ckpt_candidates: list[str] | None = None,
    negative_text: str | None = None,
) -> Path:
    """Render one image through a minimal ComfyUI text-to-image workflow.

    For ``max(1, retry)`` rounds, each checkpoint in ``ckpt_candidates`` is
    tried in order. Each try queues a workflow from
    ``_build_simple_workflow`` with a random seed. It then polls the job
    history about once per second for up to ``timeout_s`` seconds and
    downloads the first reported output file via ``/view`` into
    ``output_dir``. Any failure moves on to the next candidate.

    Args:
        prompt_text: positive prompt.
        output_dir: local directory for the downloaded image (created if needed).
        cfg: app config; when None, ``./configs/config.yaml`` is loaded.
        timeout_s: polling window per queued job, in seconds.
        retry: number of rounds over the checkpoint candidates (at least 1).
        width, height: image size; a missing/zero value falls back to
            ``video.mock_size`` from config (default ``[1024, 576]``).
        filename_prefix: SaveImage filename prefix on the server.
        ckpt_candidates: checkpoint names to try; defaults to a few SD 1.5 names.
        negative_text: negative prompt (default "low quality, blurry").

    Returns:
        Path of the written image inside ``output_dir``.

    Raises:
        RuntimeError: every attempt failed or timed out; the last underlying
            error is chained as ``__cause__``.
    """
    cfg2 = cfg if cfg is not None else AppConfig.load("./configs/config.yaml")
    base_url = str(cfg2.get("app.comfy_base_url", "http://comfyui:8188")).rstrip("/")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if width is None or height is None:
        mock_size = cfg2.get("video.mock_size", [1024, 576])
        width = int(width or mock_size[0])
        height = int(height or mock_size[1])

    if negative_text is None:
        negative_text = "low quality, blurry"

    if ckpt_candidates is None:
        ckpt_candidates = [
            "v1-5-pruned-emaonly.ckpt",
            "v1-5-pruned-emaonly.safetensors",
            "sd-v1-5-tiny.safetensors",
        ]

    last_err: Exception | None = None
    for _attempt in range(max(1, retry)):
        for ckpt_name in ckpt_candidates:
            client_id = str(uuid.uuid4())
            seed = int(uuid.uuid4().int % 2_147_483_647)
            workflow = _build_simple_workflow(
                prompt_text,
                seed=seed,
                ckpt_name=ckpt_name,
                width=width,
                height=height,
                filename_prefix=filename_prefix,
                negative_text=negative_text,
            )
            try:
                prompt_id = _queue_prompt(base_url, workflow, client_id)
                # monotonic clock: elapsed time is immune to wall-clock adjustments.
                start = time.monotonic()
                while time.monotonic() - start < timeout_s:
                    item = _get_history_item(base_url, prompt_id)
                    img_target = _extract_first_image_view_target(item) if isinstance(item, dict) else None
                    if img_target:
                        filename, subfolder = img_target
                        # Let httpx URL-encode the query; names may contain spaces, '&', '#', ...
                        img_resp = httpx.get(
                            f"{base_url}/view",
                            params={"filename": filename, "subfolder": subfolder},
                            timeout=60.0,
                        )
                        img_resp.raise_for_status()
                        # Keep only the basename so a server-supplied name cannot escape out_dir.
                        image_path = out_dir / Path(filename).name
                        image_path.write_bytes(img_resp.content)
                        return image_path
                    time.sleep(1.0)
                # Record the timeout so the final error is not reported as "None".
                last_err = TimeoutError(
                    f"ComfyUI job {prompt_id} ({ckpt_name}) produced no image within {timeout_s}s"
                )
            except Exception as e:
                # Deliberate best-effort: any failure falls through to the next candidate/round.
                last_err = e

    raise RuntimeError(f"ComfyUI image generation failed after retries: {last_err}") from last_err
|