from __future__ import annotations import asyncio import json import time import uuid from dataclasses import dataclass from pathlib import Path from typing import Any, Iterable import httpx from .config import AppConfig @dataclass(frozen=True) class ComfyResult: prompt_id: str output_files: list[Path] class ComfyClient: def __init__(self, cfg: AppConfig): self.cfg = cfg self.base_url = str(cfg.get("app.comfy_base_url", "http://127.0.0.1:8188")).rstrip("/") self.output_dir = Path(str(cfg.get("app.comfy_output_dir", "./ComfyUI/output"))) self.workflow_path = Path(str(cfg.get("comfy_workflow.workflow_path", "./workflow_api.json"))) self._client_id = str(uuid.uuid4()) def load_workflow(self) -> dict[str, Any]: if not self.workflow_path.exists(): raise FileNotFoundError(f"workflow file not found: {self.workflow_path}") raw = json.loads(self.workflow_path.read_text(encoding="utf-8")) if not isinstance(raw, dict): raise ValueError(f"workflow_api.json root must be dict, got {type(raw)}") return raw def _nodes(self, workflow: dict[str, Any]) -> dict[str, Any]: # ComfyUI API workflow exports typically use { node_id: {class_type, inputs, ...}, ... } return workflow def _find_node_id_by_class_type(self, workflow: dict[str, Any], class_types: Iterable[str]) -> str | None: want = {c.strip() for c in class_types if c and str(c).strip()} if not want: return None for node_id, node in self._nodes(workflow).items(): if not isinstance(node, dict): continue ct = node.get("class_type") if isinstance(ct, str) and ct in want: return str(node_id) return None def _resolve_node_id(self, workflow: dict[str, Any], configured_id: Any, fallback_class_types_key: str) -> str: if configured_id is not None and str(configured_id).strip(): node_id = str(configured_id).strip() if node_id not in self._nodes(workflow): raise KeyError(f"Configured node_id {node_id} not found in workflow") return node_id class_types = self.cfg.get(f"comfy_workflow.{fallback_class_types_key}", []) or [] if not isinstance(class_types, list): raise ValueError(f"Config comfy_workflow.{fallback_class_types_key} must be list") found = self._find_node_id_by_class_type(workflow, [str(x) for x in class_types]) if not found: raise KeyError(f"Cannot resolve node by class types: {class_types}") return found def inject_params(self, workflow: dict[str, Any], image_prompt: str, seed: int, motion_prompt: str | None = None) -> dict[str, Any]: wf = json.loads(json.dumps(workflow)) # deep copy prompt_node_id = self._resolve_node_id( wf, self.cfg.get("comfy_workflow.prompt_node_id", None), "prompt_node_class_types", ) prompt_key = str(self.cfg.get("comfy_workflow.prompt_input_key", "text")) self._set_input(wf, prompt_node_id, prompt_key, image_prompt) seed_node_id = self._resolve_node_id( wf, self.cfg.get("comfy_workflow.seed_node_id", None), "seed_node_class_types", ) seed_key = str(self.cfg.get("comfy_workflow.seed_input_key", "seed")) self._set_input(wf, seed_node_id, seed_key, int(seed)) motion_node_id = self.cfg.get("comfy_workflow.motion_node_id", None) if motion_prompt and motion_node_id is not None and str(motion_node_id).strip(): motion_key = str(self.cfg.get("comfy_workflow.motion_input_key", "text")) self._set_input(wf, str(motion_node_id).strip(), motion_key, motion_prompt) return wf def _set_input(self, workflow: dict[str, Any], node_id: str, key: str, value: Any) -> None: node = self._nodes(workflow).get(str(node_id)) if not isinstance(node, dict): raise KeyError(f"Node {node_id} not found") inputs = node.get("inputs") if inputs is None: inputs = {} node["inputs"] = inputs if not isinstance(inputs, dict): raise TypeError(f"Node {node_id} inputs must be dict, got {type(inputs)}") inputs[key] = value async def _post_prompt(self, client: httpx.AsyncClient, workflow: dict[str, Any]) -> str: url = f"{self.base_url}/prompt" payload = {"prompt": workflow, "client_id": self._client_id} r = await client.post(url, json=payload) r.raise_for_status() data = r.json() pid = data.get("prompt_id") or data.get("prompt_id".upper()) if not isinstance(pid, str) or not pid: raise RuntimeError(f"Unexpected /prompt response: {data}") return pid async def _get_history(self, client: httpx.AsyncClient, prompt_id: str) -> dict[str, Any] | None: # Common endpoints: # - /history/{prompt_id} # - /history (returns all histories keyed by prompt id) for url in (f"{self.base_url}/history/{prompt_id}", f"{self.base_url}/history"): try: r = await client.get(url) if r.status_code == 404: continue r.raise_for_status() data = r.json() if isinstance(data, dict): if prompt_id in data and isinstance(data[prompt_id], dict): return data[prompt_id] if url.endswith(f"/{prompt_id}"): return data return None except httpx.HTTPStatusError: raise except Exception: continue return None def _extract_output_files(self, history_item: dict[str, Any]) -> list[Path]: out: list[Path] = [] outputs = history_item.get("outputs") if not isinstance(outputs, dict): return out def walk(v: Any) -> None: if isinstance(v, dict): # ComfyUI tends to store files like {"filename":"x.mp4","subfolder":"","type":"output"} fn = v.get("filename") if isinstance(fn, str) and fn.strip(): out.append(self.output_dir / fn) for vv in v.values(): walk(vv) elif isinstance(v, list): for vv in v: walk(vv) walk(outputs) # De-dup while preserving order seen: set[str] = set() uniq: list[Path] = [] for p in out: s = str(p) if s not in seen: seen.add(s) uniq.append(p) return uniq async def run_workflow(self, workflow: dict[str, Any], *, poll_interval_s: float = 1.0, timeout_s: float = 300.0) -> ComfyResult: async with httpx.AsyncClient(timeout=30.0) as client: prompt_id = await self._post_prompt(client, workflow) deadline = asyncio.get_event_loop().time() + timeout_s last_files: list[Path] = [] while True: if asyncio.get_event_loop().time() > deadline: raise TimeoutError(f"ComfyUI job timeout: {prompt_id}") item = await self._get_history(client, prompt_id) if isinstance(item, dict): files = self._extract_output_files(item) if files: last_files = files # Heuristic: if any file exists on disk, treat as done. if any(p.exists() for p in files): return ComfyResult(prompt_id=prompt_id, output_files=files) await asyncio.sleep(poll_interval_s) # unreachable # return ComfyResult(prompt_id=prompt_id, output_files=last_files) # --------------------------------------------------------------------------- # Minimal "text->image" helpers (used by shot rendering) # --------------------------------------------------------------------------- def _build_simple_workflow( prompt_text: str, *, seed: int, ckpt_name: str, width: int, height: int, steps: int = 20, cfg: float = 8.0, sampler_name: str = "euler", scheduler: str = "normal", denoise: float = 1.0, filename_prefix: str = "shot", negative_text: str = "low quality, blurry", ) -> dict[str, Any]: # Best-effort workflow. If your ComfyUI nodes/models differ, generation must fallback. return { "3": { "class_type": "KSampler", "inputs": { "seed": int(seed), "steps": int(steps), "cfg": float(cfg), "sampler_name": sampler_name, "scheduler": scheduler, "denoise": float(denoise), "model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0], "latent_image": ["5", 0], }, }, "4": { "class_type": "CheckpointLoaderSimple", "inputs": { "ckpt_name": ckpt_name, }, }, "5": { "class_type": "EmptyLatentImage", "inputs": { "width": int(width), "height": int(height), "batch_size": 1, }, }, "6": { "class_type": "CLIPTextEncode", "inputs": { "text": prompt_text, "clip": ["4", 1], }, }, "7": { "class_type": "CLIPTextEncode", "inputs": { "text": negative_text, "clip": ["4", 1], }, }, "8": { "class_type": "VAEDecode", "inputs": { "samples": ["3", 0], "vae": ["4", 2], }, }, "9": { "class_type": "SaveImage", "inputs": { "images": ["8", 0], "filename_prefix": filename_prefix, }, }, } def _queue_prompt(base_url: str, workflow: dict[str, Any], client_id: str) -> str: r = httpx.post( base_url.rstrip("/") + "/prompt", json={"prompt": workflow, "client_id": client_id}, timeout=30.0, ) r.raise_for_status() data = r.json() pid = data.get("prompt_id") if not isinstance(pid, str) or not pid: raise RuntimeError(f"Unexpected /prompt response: {data}") return pid def _get_history_item(base_url: str, prompt_id: str) -> dict[str, Any] | None: for url in (f"{base_url.rstrip('/')}/history/{prompt_id}", f"{base_url.rstrip('/')}/history"): try: r = httpx.get(url, timeout=30.0) if r.status_code == 404: continue r.raise_for_status() data = r.json() if isinstance(data, dict): if prompt_id in data and isinstance(data[prompt_id], dict): return data[prompt_id] if url.endswith(f"/{prompt_id}") and isinstance(data, dict): return data return None except Exception: continue return None def _extract_first_image_view_target(history_item: dict[str, Any]) -> tuple[str, str] | None: outputs = history_item.get("outputs") if not isinstance(outputs, dict): return None def walk(v: Any) -> list[dict[str, Any]]: found: list[dict[str, Any]] = [] if isinstance(v, dict): if isinstance(v.get("filename"), str) and v.get("filename").strip(): found.append(v) for vv in v.values(): found.extend(walk(vv)) elif isinstance(v, list): for vv in v: found.extend(walk(vv)) return found candidates = walk(outputs) for c in candidates: fn = str(c.get("filename", "")).strip() sf = str(c.get("subfolder", "") or "").strip() if fn: return fn, sf return None def generate_image( prompt_text: str, output_dir: str | Path, *, cfg: AppConfig | None = None, timeout_s: int = 60, retry: int = 2, width: int | None = None, height: int | None = None, filename_prefix: str = "shot", ckpt_candidates: list[str] | None = None, negative_text: str | None = None, ) -> Path: cfg2 = cfg or AppConfig.load("./configs/config.yaml") base_url = str(cfg2.get("app.comfy_base_url", "http://comfyui:8188")).rstrip("/") out_dir = Path(output_dir) out_dir.mkdir(parents=True, exist_ok=True) if width is None or height is None: mock_size = cfg2.get("video.mock_size", [1024, 576]) width = int(width or mock_size[0]) height = int(height or mock_size[1]) if negative_text is None: negative_text = "low quality, blurry" if ckpt_candidates is None: ckpt_candidates = [ "v1-5-pruned-emaonly.ckpt", "v1-5-pruned-emaonly.safetensors", "sd-v1-5-tiny.safetensors", ] last_err: Exception | None = None for _attempt in range(max(1, retry)): for ckpt_name in ckpt_candidates: client_id = str(uuid.uuid4()) seed = int(uuid.uuid4().int % 2_147_483_647) workflow = _build_simple_workflow( prompt_text, seed=seed, ckpt_name=ckpt_name, width=width, height=height, filename_prefix=filename_prefix, negative_text=negative_text, ) try: prompt_id = _queue_prompt(base_url, workflow, client_id) start = time.time() while time.time() - start < timeout_s: item = _get_history_item(base_url, prompt_id) if isinstance(item, dict): img_target = _extract_first_image_view_target(item) if img_target: filename, subfolder = img_target view_url = f"{base_url}/view?filename={filename}&subfolder={subfolder}" img_resp = httpx.get(view_url, timeout=60.0) img_resp.raise_for_status() image_path = out_dir / filename image_path.write_bytes(img_resp.content) return image_path time.sleep(1.0) except Exception as e: last_err = e continue raise RuntimeError(f"ComfyUI image generation failed after retries: {last_err}")