Files
AiVideo/engine/comfy_client.py
2026-03-18 17:36:07 +08:00

189 lines
7.8 KiB
Python

from __future__ import annotations
import asyncio
import json
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
import httpx
from .config import AppConfig
@dataclass(frozen=True)
class ComfyResult:
    """Immutable outcome of one workflow submission."""

    # Identifier of the submitted job (the value passed as ``prompt_id=``).
    prompt_id: str
    # Paths of the output files collected for that job.
    output_files: list[Path]
class ComfyClient:
    """Async helper that submits an API-format workflow to a ComfyUI server.

    The client reads its settings from ``cfg`` (anything exposing
    ``get(key, default)``), injects prompt/seed values into a workflow dict,
    posts it to ``/prompt`` and polls ``/history`` until an output file shows
    up under ``output_dir``.
    """

    def __init__(self, cfg: AppConfig):
        """Read server URL, output directory and workflow path from *cfg*."""
        self.cfg = cfg
        # Trailing slashes are stripped so URLs can be built with f"{base_url}/...".
        self.base_url = str(cfg.get("app.comfy_base_url", "http://127.0.0.1:8188")).rstrip("/")
        self.output_dir = Path(str(cfg.get("app.comfy_output_dir", "./ComfyUI/output")))
        self.workflow_path = Path(
            str(cfg.get("comfy_workflow.workflow_path", "./workflow_api.json"))
        )
        # Random id sent as "client_id" with every /prompt request.
        self._client_id = str(uuid.uuid4())

    def load_workflow(self) -> dict[str, Any]:
        """Load the workflow JSON file and return its root dict.

        Raises:
            FileNotFoundError: if ``workflow_path`` does not exist.
            ValueError: if the JSON root is not an object.
        """
        if not self.workflow_path.exists():
            raise FileNotFoundError(f"workflow file not found: {self.workflow_path}")
        raw = json.loads(self.workflow_path.read_text(encoding="utf-8"))
        if not isinstance(raw, dict):
            raise ValueError(f"workflow_api.json root must be dict, got {type(raw)}")
        return raw

    def _nodes(self, workflow: dict[str, Any]) -> dict[str, Any]:
        """Return the node mapping of *workflow* (currently the workflow itself)."""
        # ComfyUI API workflow exports typically use { node_id: {class_type, inputs, ...}, ... }
        return workflow

    def _find_node_id_by_class_type(
        self, workflow: dict[str, Any], class_types: Iterable[str]
    ) -> str | None:
        """Return the id of the first node whose ``class_type`` is in *class_types*.

        Blank or falsy entries are ignored; ``None`` is returned when nothing
        matches or when no usable class type was given.
        """
        # str() first so a non-string entry cannot raise AttributeError on .strip().
        stripped = (str(c).strip() for c in class_types if c)
        wanted = {name for name in stripped if name}
        if not wanted:
            return None
        for node_id, node in self._nodes(workflow).items():
            if not isinstance(node, dict):
                continue
            class_type = node.get("class_type")
            if isinstance(class_type, str) and class_type in wanted:
                return str(node_id)
        return None

    def _resolve_node_id(
        self, workflow: dict[str, Any], configured_id: Any, fallback_class_types_key: str
    ) -> str:
        """Pick the node to patch: an explicitly configured id, else a class-type lookup.

        Raises:
            KeyError: if the configured id is not in the workflow, or no node
                matches the configured class types.
            ValueError: if ``comfy_workflow.<fallback_class_types_key>`` is not a list.
        """
        if configured_id is not None and str(configured_id).strip():
            node_id = str(configured_id).strip()
            if node_id not in self._nodes(workflow):
                raise KeyError(f"Configured node_id {node_id} not found in workflow")
            return node_id

        class_types = self.cfg.get(f"comfy_workflow.{fallback_class_types_key}", []) or []
        if not isinstance(class_types, list):
            raise ValueError(f"Config comfy_workflow.{fallback_class_types_key} must be list")
        found = self._find_node_id_by_class_type(workflow, [str(x) for x in class_types])
        if not found:
            raise KeyError(f"Cannot resolve node by class types: {class_types}")
        return found

    def inject_params(
        self,
        workflow: dict[str, Any],
        image_prompt: str,
        seed: int,
        motion_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Return a copy of *workflow* with prompt, seed and optional motion text set.

        The input workflow is not modified. Target nodes and input keys come
        from the ``comfy_workflow.*`` config entries. The motion prompt is only
        written when both *motion_prompt* and ``comfy_workflow.motion_node_id``
        are set.
        """
        # JSON round-trip: deep copy that also guarantees the result is serialisable.
        wf = json.loads(json.dumps(workflow))

        prompt_node_id = self._resolve_node_id(
            wf,
            self.cfg.get("comfy_workflow.prompt_node_id", None),
            "prompt_node_class_types",
        )
        prompt_key = str(self.cfg.get("comfy_workflow.prompt_input_key", "text"))
        self._set_input(wf, prompt_node_id, prompt_key, image_prompt)

        seed_node_id = self._resolve_node_id(
            wf,
            self.cfg.get("comfy_workflow.seed_node_id", None),
            "seed_node_class_types",
        )
        seed_key = str(self.cfg.get("comfy_workflow.seed_input_key", "seed"))
        self._set_input(wf, seed_node_id, seed_key, int(seed))

        motion_node_id = self.cfg.get("comfy_workflow.motion_node_id", None)
        if motion_prompt and motion_node_id is not None and str(motion_node_id).strip():
            motion_key = str(self.cfg.get("comfy_workflow.motion_input_key", "text"))
            self._set_input(wf, str(motion_node_id).strip(), motion_key, motion_prompt)
        return wf

    def _set_input(self, workflow: dict[str, Any], node_id: str, key: str, value: Any) -> None:
        """Set ``inputs[key] = value`` on node *node_id*, creating ``inputs`` if missing.

        Raises:
            KeyError: if the node does not exist (or is not a dict).
            TypeError: if the node's ``inputs`` is present but not a dict.
        """
        node = self._nodes(workflow).get(str(node_id))
        if not isinstance(node, dict):
            raise KeyError(f"Node {node_id} not found")
        inputs = node.get("inputs")
        if inputs is None:
            inputs = {}
            node["inputs"] = inputs
        if not isinstance(inputs, dict):
            raise TypeError(f"Node {node_id} inputs must be dict, got {type(inputs)}")
        inputs[key] = value

    async def _post_prompt(self, client: httpx.AsyncClient, workflow: dict[str, Any]) -> str:
        """POST *workflow* to ``/prompt`` and return the prompt id from the response.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
            RuntimeError: if the response carries no usable prompt id.
        """
        url = f"{self.base_url}/prompt"
        payload = {"prompt": workflow, "client_id": self._client_id}
        r = await client.post(url, json=payload)
        r.raise_for_status()
        data = r.json()
        # A non-dict body must end in the RuntimeError below, not an AttributeError.
        pid = None
        if isinstance(data, dict):
            pid = data.get("prompt_id") or data.get("PROMPT_ID")
        if not isinstance(pid, str) or not pid:
            raise RuntimeError(f"Unexpected /prompt response: {data}")
        return pid

    async def _get_history(
        self, client: httpx.AsyncClient, prompt_id: str
    ) -> dict[str, Any] | None:
        """Fetch the history entry for *prompt_id*, or ``None`` if unavailable.

        Tries ``/history/{prompt_id}`` first and falls back to ``/history``
        when the first endpoint answers 404 or fails at transport/JSON level.

        Raises:
            httpx.HTTPStatusError: for error statuses other than 404.
        """
        # Common endpoints:
        # - /history/{prompt_id}
        # - /history (returns all histories keyed by prompt id)
        for url in (f"{self.base_url}/history/{prompt_id}", f"{self.base_url}/history"):
            try:
                r = await client.get(url)
                if r.status_code == 404:
                    continue
                r.raise_for_status()
                data = r.json()
            except httpx.HTTPStatusError:
                raise
            except (httpx.RequestError, ValueError):
                # Best effort while polling: transport failures and undecodable
                # bodies just move on to the next endpoint. Anything else is a
                # programming error and is allowed to propagate.
                continue

            if isinstance(data, dict):
                entry = data.get(prompt_id)
                if isinstance(entry, dict):
                    return entry
                if url.endswith(f"/{prompt_id}"):
                    # Per-prompt endpoint answered with a dict not keyed by id.
                    return data
            return None
        return None

    def _extract_output_files(self, history_item: dict[str, Any]) -> list[Path]:
        """Collect output file paths referenced anywhere under ``history_item["outputs"]``.

        Every nested dict with a non-blank string ``"filename"`` contributes
        ``output_dir / subfolder / filename`` (``subfolder`` is optional).
        Duplicates are dropped while keeping first-seen order.
        """
        outputs = history_item.get("outputs")
        if not isinstance(outputs, dict):
            return []

        found: list[Path] = []

        def walk(value: Any) -> None:
            if isinstance(value, dict):
                # ComfyUI tends to store files like
                # {"filename":"x.mp4","subfolder":"","type":"output"}
                filename = value.get("filename")
                if isinstance(filename, str) and filename.strip():
                    subfolder = value.get("subfolder")
                    base = self.output_dir
                    # Honour the subfolder so the path matches where the file was written.
                    if isinstance(subfolder, str) and subfolder.strip():
                        base = base / subfolder
                    found.append(base / filename)
                for child in value.values():
                    walk(child)
            elif isinstance(value, list):
                for child in value:
                    walk(child)

        walk(outputs)

        # De-dup (keyed by string form) while preserving order.
        unique: dict[str, Path] = {}
        for path in found:
            unique.setdefault(str(path), path)
        return list(unique.values())

    async def run_workflow(
        self,
        workflow: dict[str, Any],
        *,
        poll_interval_s: float = 1.0,
        timeout_s: float = 300.0,
    ) -> ComfyResult:
        """Submit *workflow* and poll history until an output file exists on disk.

        Args:
            workflow: API-format workflow dict to submit.
            poll_interval_s: seconds to sleep between history polls.
            timeout_s: overall budget, measured on the event-loop clock.

        Raises:
            TimeoutError: if no listed output file exists before the deadline.
            httpx.HTTPStatusError / RuntimeError: propagated from the HTTP helpers.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            prompt_id = await self._post_prompt(client, workflow)
            # get_running_loop() is the supported way to reach the loop from a coroutine.
            loop = asyncio.get_running_loop()
            deadline = loop.time() + timeout_s
            while True:
                if loop.time() > deadline:
                    raise TimeoutError(f"ComfyUI job timeout: {prompt_id}")
                item = await self._get_history(client, prompt_id)
                if isinstance(item, dict):
                    files = self._extract_output_files(item)
                    # Heuristic: if any file exists on disk, treat as done.
                    if files and any(p.exists() for p in files):
                        return ComfyResult(prompt_id=prompt_id, output_files=files)
                # NOTE(review): a job that fails or produces no file is only
                # noticed through the timeout above; consider checking the
                # history entry's status once its schema is confirmed.
                await asyncio.sleep(poll_interval_s)