"""Sigue en vivo el progreso de un prompt ComfyUI por WebSocket (/ws). Funcion impura: red (WebSocket o, en fallback, HTTP GET en bucle). Alternativa en-vivo a comfyui_wait_result (que sondea /history). Aqui se escucha el canal de eventos del servidor (ws:///ws?clientId=) y se siguen los mensajes que ComfyUI emite mientras ejecuta: - type="progress" -> un paso del sampler (data: {value, max, node}). - type="executing" -> el grafo entra en un nodo (data: {node, prompt_id}); node=None con prompt_id propio marca el fin (senal legacy). - type="execution_success" -> el prompt termino bien (data: {prompt_id}). - type="execution_error" -> el prompt fallo (data: {prompt_id, ...}). Defensa contra carreras: un prompt con nodos cacheados completa en <1s, antes de que el WS reciba su execution_success. Por eso esta funcion (a) comprueba /history al entrar — si el prompt ya termino, no hay nada que seguir — y (b) revisa /history en cada ventana de recv sin eventos, para detectar el fin aunque el WS pierda el evento, sin esperar al timeout completo. Si websocket-client NO esta instalado en el interprete que ejecuta esta funcion (el venv del registry no lo trae; el de ComfyUI si), cae limpiamente a polling de /history reutilizando comfyui_wait_result y devuelve el mismo dict con method="polling". """ import json import time def comfyui_stream_progress( prompt_id: str, *, server: str = "127.0.0.1:8188", client_id: str | None = None, timeout: float = 300.0, ) -> dict: """Sigue el progreso en vivo de un prompt por WebSocket; cae a polling si falta ws. Args: prompt_id: id devuelto por comfyui_submit_workflow. server: host:port del servidor ComfyUI sin esquema (default "127.0.0.1:8188"). keyword-only. client_id: clientId para registrar el socket en el servidor; si None se genera un uuid4. keyword-only. timeout: maximo de segundos a esperar a que el prompt complete. keyword-only. Returns: dict con: - ok (bool): True si el seguimiento concluyo sin error (incluido el fallback a polling cuando completa). - completed (bool): True si el prompt termino (success o senal de fin). - steps_seen (int): numero de mensajes "progress" observados por WS (0 en el fallback de polling y en trabajos cacheados, que no emiten pasos intermedios). - last_node (str): id del ultimo nodo en ejecucion visto. - method (str): "websocket" o "polling" segun la via usada. - error (str): mensaje de error; cadena vacia si todo OK. """ out = { "ok": False, "completed": False, "steps_seen": 0, "last_node": "", "method": "websocket", "error": "", } try: from websocket import ( # type: ignore WebSocketTimeoutException, create_connection, ) except ImportError: return _fallback_polling(out, prompt_id, server, timeout) import uuid cid = client_id or uuid.uuid4().hex ws_url = f"ws://{server}/ws?clientId={cid}" deadline = time.time() + timeout # Pre-check: el prompt pudo terminar antes de que conectemos (trabajos con nodos # cacheados completan en <1s). Si ya esta en history, no hay nada que seguir. done, last = _history_check(server, prompt_id) if done: out["completed"] = True out["ok"] = True if last: out["last_node"] = last return out ws = None try: ws = create_connection(ws_url, timeout=min(timeout, 30.0)) except Exception as exc: # noqa: BLE001 — degradar a fallback, no romper out["error"] = f"no se pudo abrir WS {ws_url}: {exc}" return _fallback_polling(out, prompt_id, server, timeout) try: while time.time() < deadline: ws.settimeout(min(2.0, max(0.1, deadline - time.time()))) try: msg = ws.recv() except WebSocketTimeoutException: # Sin evento en la ventana: el fin pudo perderse (carrera con un # trabajo rapido). Confirma por history antes de seguir esperando. done, last = _history_check(server, prompt_id) if done: out["completed"] = True out["ok"] = True if last and not out["last_node"]: out["last_node"] = last break continue if isinstance(msg, (bytes, bytearray)): continue # frames binarios = previews de imagen, se ignoran try: evt = json.loads(msg) except (json.JSONDecodeError, TypeError): continue mtype = evt.get("type") data = evt.get("data", {}) or {} evt_pid = data.get("prompt_id") if mtype == "progress": out["steps_seen"] += 1 node = data.get("node") if node is not None: out["last_node"] = str(node) elif mtype == "executing": node = data.get("node") if node is not None: out["last_node"] = str(node) elif evt_pid == prompt_id: out["completed"] = True out["ok"] = True break elif mtype == "execution_success" and evt_pid == prompt_id: out["completed"] = True out["ok"] = True break elif mtype == "execution_error" and evt_pid == prompt_id: out["error"] = f"execution_error: {json.dumps(data)[:400]}" break else: # deadline agotado sin fin por WS: ultimo check de history por si el # evento de fin se perdio del todo. done, last = _history_check(server, prompt_id) if done: out["completed"] = True out["ok"] = True if last and not out["last_node"]: out["last_node"] = last else: out["error"] = f"timeout de {timeout}s sin fin para {prompt_id}" finally: try: ws.close() except Exception: # noqa: BLE001 pass return out def _history_check(server: str, prompt_id: str) -> tuple: """Consulta GET /history/{prompt_id} una vez (no bucle). Devuelve (done, last_node): done=True solo si el prompt completo con exito; last_node = id del ultimo nodo de output si lo hay. Errores de red se tratan como "aun no" (False), no lanzan. """ import urllib.error import urllib.request url = f"http://{server}/history/{prompt_id}" try: with urllib.request.urlopen(url, timeout=10.0) as resp: hist = json.loads(resp.read()) except (urllib.error.URLError, json.JSONDecodeError, OSError, ValueError): return (False, "") entry = hist.get(prompt_id) if isinstance(hist, dict) else None if not entry: return (False, "") status = entry.get("status", {}) if status.get("completed") or status.get("status_str") == "success": outs = entry.get("outputs", {}) or {} last = str(list(outs)[-1]) if outs else "" return (True, last) return (False, "") def _fallback_polling(out: dict, prompt_id: str, server: str, timeout: float) -> dict: """Cae a polling de /history reutilizando comfyui_wait_result del registry.""" out["method"] = "polling" try: from comfyui_wait_result import comfyui_wait_result # hermano en ml/ except ImportError: from ml.comfyui_wait_result import comfyui_wait_result # via python/functions try: outputs = comfyui_wait_result(prompt_id, server=server, timeout=timeout) out["completed"] = True out["ok"] = True if outputs: out["last_node"] = str(list(outputs)[-1]) except TimeoutError as exc: out["error"] = f"polling: {exc}" except RuntimeError as exc: out["error"] = f"polling: {exc}" return out if __name__ == "__main__": import sys pid = sys.argv[1] if len(sys.argv) > 1 else "" if not pid: print("uso: comfyui_stream_progress.py ", file=sys.stderr) sys.exit(2) res = comfyui_stream_progress(pid) print(json.dumps(res, indent=2))