Files
AiVideo/scripts/inspect_comfy_node.py
2026-03-25 13:33:48 +08:00

332 lines
12 KiB
Python

from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Iterable
import httpx
import yaml
def fetch_object_info(base_url: str, timeout_s: float = 5.0) -> dict[str, Any]:
    """Download ComfyUI's ``/object_info`` node registry.

    Args:
        base_url: ComfyUI server root; a trailing slash is tolerated.
        timeout_s: httpx client timeout in seconds.

    Returns:
        The decoded JSON object.

    Raises:
        httpx.HTTPStatusError: the server answered with a non-success status.
        RuntimeError: the body decoded to something other than a JSON object.
    """
    endpoint = f"{base_url.rstrip('/')}/object_info"
    with httpx.Client(timeout=timeout_s) as http:
        response = http.get(endpoint)
        response.raise_for_status()
        payload = response.json()
    if isinstance(payload, dict):
        return payload
    raise RuntimeError(f"Unexpected object_info type: {type(payload)}")
def load_yaml(path: str | Path) -> dict[str, Any]:
    """Read a YAML mapping from *path* using ``yaml.safe_load``.

    Returns an empty dict when the file does not exist, or when its top-level
    value is not a mapping (empty file, list, scalar, ...).
    """
    cfg_path = Path(path)
    if cfg_path.exists():
        parsed = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        if isinstance(parsed, dict):
            return parsed
    return {}
def load_json(path: str | Path) -> Any:
    """Parse the JSON document at *path*, or return ``None`` if the file is absent.

    Decoding errors (``json.JSONDecodeError``) propagate to the caller.
    """
    json_path = Path(path)
    return json.loads(json_path.read_text(encoding="utf-8")) if json_path.exists() else None
def iter_node_class_types(object_info: dict[str, Any]) -> Iterable[str]:
    """Lazily yield every string key of *object_info*, in the dict's order."""
    yield from (name for name in object_info if isinstance(name, str))
def _combo_choices(entry: Any) -> list[str]:
    """Return the string options of a combo input spec, or [] if *entry* is not one."""
    if not isinstance(entry, list) or not entry:
        return []
    head = entry[0]
    if isinstance(head, list):
        # legacy combo: [[value, ...]] or [[value, ...], {meta}]
        options: Any = head
    elif head == "COMBO" and len(entry) > 1 and isinstance(entry[1], dict):
        # typed combo: ["COMBO", {"options": [value, ...], ...}]
        options = entry[1].get("options")
    else:
        return []
    if not isinstance(options, list):
        return []
    return [v for v in options if isinstance(v, str)]


def find_ckpt_values(object_info: dict[str, Any]) -> list[str]:
    """Collect checkpoint names offered by any node's selector input.

    Heuristic: for every node in ``/object_info`` whose ``input.required``
    mapping has a ``ckpt_name``, ``checkpoint`` or ``model_name`` entry, gather
    the string choices of that combo input. Recognised entry shapes:

    * ``[[value, ...]]`` / ``[[value, ...], {meta}]``
      (e.g. ``CheckpointLoaderSimple.input.required.ckpt_name``);
    * ``["COMBO", {"options": [value, ...], ...}]``.

    Non-dict node entries, missing sections and non-string choices are skipped.

    Returns:
        Unique names in first-seen order.
    """
    found: dict[str, None] = {}  # insertion-ordered set used for de-duplication
    for node_info in object_info.values():
        if not isinstance(node_info, dict):
            continue
        inputs = node_info.get("input")
        if not isinstance(inputs, dict):
            continue
        required = inputs.get("required")
        if not isinstance(required, dict):
            continue
        for key in ("ckpt_name", "checkpoint", "model_name"):
            for value in _combo_choices(required.get(key)):
                found.setdefault(value, None)
    return list(found)
def has_ksampler_seed(object_info: dict[str, Any], ks_classes: list[str], seed_key: str) -> bool:
    """Report whether any class in *ks_classes* declares *seed_key* as an input.

    Both the ``required`` and ``optional`` input sections of each class's
    ``/object_info`` entry are searched. Classes absent from *object_info*, or
    whose entries are malformed, are ignored.
    """
    def _declares_seed(cls_name: str) -> bool:
        spec = object_info.get(cls_name)
        sections = spec.get("input") if isinstance(spec, dict) else None
        if not isinstance(sections, dict):
            return False
        return any(
            isinstance(sections.get(part), dict) and seed_key in sections[part]
            for part in ("required", "optional")
        )

    return any(_declares_seed(cls_name) for cls_name in ks_classes)
def resolve_seed_target_from_workflow(workflow: Any, seed_class_types: list[str]) -> tuple[str | None, str | None]:
    """Find the first workflow node whose ``class_type`` is a seed-carrying class.

    ``workflow_api.json`` is expected to map ``node_id -> {"class_type": ..., "inputs": ...}``;
    nodes are scanned in the mapping's iteration order and non-dict nodes are skipped.

    Returns:
        ``(node_id, class_type)`` for the first match, with ``node_id`` coerced to
        ``str``; ``(None, None)`` when *workflow* is not a dict or nothing matches.
    """
    if isinstance(workflow, dict):
        wanted = frozenset(seed_class_types)
        matches = (
            (str(nid), node["class_type"])
            for nid, node in workflow.items()
            if isinstance(node, dict)
            and isinstance(node.get("class_type"), str)
            and node["class_type"] in wanted
        )
        return next(matches, (None, None))
    return (None, None)
def _workflow_nodes(workflow: Any) -> dict[str, Any]:
if not isinstance(workflow, dict):
raise ValueError("workflow_api.json root must be an object mapping node_id -> node")
return workflow
def _get_node(workflow: dict[str, Any], node_id: str) -> dict[str, Any]:
node = workflow.get(str(node_id))
if not isinstance(node, dict):
raise KeyError(f"workflow missing node_id={node_id}")
return node
def _validate_configured_node_id(
    *,
    workflow: dict[str, Any],
    node_id: Any,
    allowed_class_types: list[str],
    name: str,
) -> list[str]:
    """Check that a node id taken from config exists in *workflow* with an allowed class.

    Args:
        workflow: Parsed ``workflow_api.json`` mapping ``node_id -> node``.
        node_id: Raw config value; ``None`` or a blank string means "not
            configured" and is accepted without any checks.
        allowed_class_types: Permitted ``class_type`` values; an empty list
            disables the class check (existence is still verified).
        name: Config key, used as the prefix of every error message.

    Returns:
        Human-readable error strings; empty when the configured id is valid.
    """
    if node_id is None or not str(node_id).strip():
        return []
    nid = str(node_id).strip()
    try:
        node = _get_node(workflow, nid)
    except KeyError as e:  # _get_node signals a missing / non-dict node with KeyError
        return [f"{name}: configured node_id={nid} not found in workflow ({e})"]
    if not allowed_class_types:
        return []
    ct = node.get("class_type")
    if not isinstance(ct, str):
        # A node without a string class_type cannot be matched against the
        # allowed list; report it instead of letting it pass unchecked.
        return [f"{name}: node_id={nid} has no string class_type (allowed {allowed_class_types})"]
    if ct not in set(allowed_class_types):
        return [f"{name}: node_id={nid} class_type={ct} not in allowed {allowed_class_types}"]
    return []
def _workflow_has_ltx_node(workflow: dict[str, Any], keyword: str) -> bool:
kw = keyword.lower()
for _nid, node in workflow.items():
if not isinstance(node, dict):
continue
ct = node.get("class_type")
if isinstance(ct, str) and kw in ct.lower():
return True
return False
def _config_section(cfg: Any, key: str) -> dict[str, Any]:
    """Return ``cfg[key]`` when both are mappings, else an empty dict."""
    section = cfg.get(key) if isinstance(cfg, dict) else None
    return section if isinstance(section, dict) else {}


def _str_list(value: Any) -> list[str]:
    """Normalise a config value into a list of non-blank strings.

    A bare string becomes a one-element list (iterating it would otherwise
    yield single characters). Lists/tuples are stringified item by item, with
    ``None`` and blank entries dropped. Anything else yields ``[]``.
    """
    if isinstance(value, str):
        items: list[Any] = [value]
    elif isinstance(value, (list, tuple)):
        items = list(value)
    else:
        return []
    return [str(x) for x in items if x is not None and str(x).strip()]


def _resolve_base_url(cli_value: Any, cfg: dict[str, Any]) -> str:
    """Pick the ComfyUI URL in this order: CLI flag, ``app.comfy_base_url``, then the localhost default."""
    base_url = str(cli_value or "").strip()
    if not base_url:
        # ``or ""`` keeps a YAML null from turning into the literal URL "None".
        base_url = str(_config_section(cfg, "app").get("comfy_base_url") or "").strip()
    return base_url or "http://127.0.0.1:8188"


def main() -> int:
    """Run the readiness inspection and print a JSON report to stdout.

    Checks performed:
      1. ``/object_info`` exposes at least one node class whose name contains
         ``--ltx-keyword`` (case-insensitive).
      2. Every ``--expected-checkpoint`` appears among the checkpoint selector choices.
      3. A configured seed sampler class exposes ``seed_input_key``.
      4. The workflow file:
         - has a seed node (the first node whose class_type is in
           ``seed_node_class_types``) carrying that key in its ``inputs``;
         - contains an LTX node;
         - and every configured ``*_node_id`` exists with an allowed class.

    Each failed check is also written to stderr after the report.

    Returns:
        0 when every check passes.
        2 if ``/object_info`` cannot be fetched.
        3 for a missing or unreadable workflow, or any failed check.
    """
    p = argparse.ArgumentParser(description="Inspect ComfyUI /object_info for LTX + checkpoints + sampler override readiness")
    p.add_argument("--base-url", default="")
    p.add_argument("--timeout", type=float, default=8.0)
    p.add_argument("--config", default="./configs/config.yaml")
    p.add_argument("--workflow", default="./workflow_api.json")
    p.add_argument(
        "--expected-checkpoint",
        action="append",
        default=[],
        help="Expected checkpoint name (can repeat). Exact match against ckpt list.",
    )
    p.add_argument(
        "--ltx-keyword",
        default="LTX",
        help="Keyword to detect LTX-Video nodes in object_info keys (default: LTX)",
    )
    args = p.parse_args()

    cfg = load_yaml(args.config)
    base_url = _resolve_base_url(args.base_url, cfg)

    comfy_cfg = _config_section(cfg, "comfy_workflow")
    seed_key = str(comfy_cfg.get("seed_input_key") or "seed")
    seed_class_types = _str_list(comfy_cfg.get("seed_node_class_types")) or ["KSampler", "KSamplerAdvanced"]

    # Industrial hard requirement: workflow must exist for ID matching checks
    wf_path = Path(args.workflow)
    if not wf_path.exists():
        sys.stderr.write(f"[inspect] FAIL: workflow_api.json not found at {wf_path}\n")
        return 3

    try:
        object_info = fetch_object_info(base_url, timeout_s=args.timeout)
    except Exception as e:  # top-level boundary: any network/HTTP/decode error is reported, not raised
        sys.stderr.write(f"[inspect] ERROR fetch /object_info: {e}\n")
        return 2

    # 1) LTX-Video plugin activated? (heuristic: registered class names containing the keyword)
    keyword = str(args.ltx_keyword or "LTX")
    needle = keyword.lower()
    ltx_hits = sorted(k for k in iter_node_class_types(object_info) if needle in k.lower())
    ltx_ok = bool(ltx_hits)

    # 2) checkpoint list includes every expected name (vacuously true when none expected)
    ckpts = find_ckpt_values(object_info)
    available = set(ckpts)
    expected = list(args.expected_checkpoint or [])
    missing = [x for x in expected if x not in available]
    ckpt_ok = not missing

    # 3) KSampler override readiness. /object_info cannot prove the runtime
    # override happened; it can only show that a sampler class exposes the
    # input key our config intends to override.
    ks_ok = has_ksampler_seed(object_info, seed_class_types, seed_key)

    try:
        # ValueError covers json.JSONDecodeError / UnicodeDecodeError and a
        # non-object root; OSError covers unreadable paths (e.g. a directory).
        wf_nodes = _workflow_nodes(load_json(wf_path))
    except (ValueError, OSError) as e:
        sys.stderr.write(f"[inspect] FAIL: invalid workflow format: {e}\n")
        return 3

    seed_node_id, seed_node_class = resolve_seed_target_from_workflow(wf_nodes, seed_class_types)
    wf_ok = seed_node_id is not None

    # Hard validation: configured node IDs must exist and match expected class_type families
    # (an unconfigured class-type list disables only the class check).
    errs: list[str] = []
    for id_key, types_key in (
        ("prompt_node_id", "prompt_node_class_types"),
        ("seed_node_id", "seed_node_class_types"),
        ("save_node_id", "save_node_class_types"),
        ("motion_node_id", None),
    ):
        errs += _validate_configured_node_id(
            workflow=wf_nodes,
            node_id=comfy_cfg.get(id_key),
            allowed_class_types=_str_list(comfy_cfg.get(types_key)) if types_key else [],
            name=id_key,
        )

    # Hard validation: workflow must contain LTX node(s) if we're using LTX-Video pipeline
    wf_ltx_ok = _workflow_has_ltx_node(wf_nodes, keyword)

    # Hard validation: the detected seed node must expose the seed input key so it can be overridden
    wf_seed_key_ok = False
    if wf_ok:
        try:
            inputs = _get_node(wf_nodes, str(seed_node_id)).get("inputs")
        except KeyError:
            inputs = None
        wf_seed_key_ok = isinstance(inputs, dict) and seed_key in inputs

    report = {
        "base_url": base_url,
        "ltx": {
            "keyword": keyword,
            "activated": ltx_ok,
            "matching_nodes": ltx_hits[:50],
            "match_count": len(ltx_hits),
        },
        "checkpoints": {
            "expected": expected,
            "found_count": len(ckpts),
            "missing": missing,
            "ok": ckpt_ok,
            "sample": ckpts[:50],
        },
        "sampler_override_readiness": {
            "seed_input_key_from_config": seed_key,
            "seed_node_class_types_from_config": seed_class_types,
            "comfy_has_seed_input": ks_ok,
            "workflow_path": args.workflow,
            "workflow_seed_node_detected": wf_ok,
            "workflow_seed_node_id": seed_node_id,
            "workflow_seed_node_class_type": seed_node_class,
            "workflow_seed_node_has_seed_key": wf_seed_key_ok,
            "note": "object_info cannot prove runtime override; this enforces key alignment + workflow ID/class checks.",
        },
        "workflow_validation": {
            "ltx_node_in_workflow": wf_ltx_ok,
            "configured_node_id_errors": errs,
        },
        "ok": bool(ltx_ok and ckpt_ok and ks_ok and wf_ok and wf_ltx_ok and wf_seed_key_ok and not errs),
    }
    sys.stdout.write(json.dumps(report, ensure_ascii=False, indent=2) + "\n")

    if not ltx_ok:
        sys.stderr.write(f"[inspect] FAIL: no node matched keyword '{keyword}' (LTX plugin may be missing)\n")
    if not ckpt_ok:
        sys.stderr.write(f"[inspect] FAIL: missing checkpoints: {missing}\n")
    if not ks_ok:
        sys.stderr.write(f"[inspect] FAIL: ComfyUI sampler classes {seed_class_types} do not expose input '{seed_key}'\n")
    if not wf_ok:
        sys.stderr.write(f"[inspect] FAIL: workflow does not contain a seed node of class types {seed_class_types}\n")
    if not wf_ltx_ok:
        sys.stderr.write(f"[inspect] FAIL: workflow has no node with class_type containing '{keyword}'\n")
    if wf_ok and not wf_seed_key_ok:
        sys.stderr.write(f"[inspect] FAIL: workflow seed node {seed_node_id} does not expose inputs['{seed_key}']\n")
    for err in errs:
        sys.stderr.write(f"[inspect] FAIL: {err}\n")
    return 0 if report["ok"] else 3
# Script entry point: the integer returned by main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())