Files
fn_registry/python/functions/ml/comfyui_stream_progress.py
T
egutierrez ff41f4f053 feat(ml): auto-commit con 7 cambios
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-24 02:52:51 +02:00

221 lines
8.3 KiB
Python

"""Sigue en vivo el progreso de un prompt ComfyUI por WebSocket (/ws).
Funcion impura: red (WebSocket o, en fallback, HTTP GET en bucle).
Alternativa en-vivo a comfyui_wait_result (que sondea /history). Aqui se escucha el
canal de eventos del servidor (ws://<server>/ws?clientId=<id>) y se siguen los
mensajes que ComfyUI emite mientras ejecuta:
- type="progress" -> un paso del sampler (data: {value, max, node}).
- type="executing" -> el grafo entra en un nodo (data: {node, prompt_id});
node=None con prompt_id propio marca el fin (senal legacy).
- type="execution_success" -> el prompt termino bien (data: {prompt_id}).
- type="execution_error" -> el prompt fallo (data: {prompt_id, ...}).
Defensa contra carreras: un prompt con nodos cacheados completa en <1s, antes de que
el WS reciba su execution_success. Por eso esta funcion (a) comprueba /history al
entrar — si el prompt ya termino, no hay nada que seguir — y (b) revisa /history en
cada ventana de recv sin eventos, para detectar el fin aunque el WS pierda el evento,
sin esperar al timeout completo.
Si websocket-client NO esta instalado en el interprete que ejecuta esta funcion
(el venv del registry no lo trae; el de ComfyUI si), cae limpiamente a polling de
/history reutilizando comfyui_wait_result y devuelve el mismo dict con
method="polling".
"""
import json
import time
def comfyui_stream_progress(
prompt_id: str,
*,
server: str = "127.0.0.1:8188",
client_id: str | None = None,
timeout: float = 300.0,
) -> dict:
"""Sigue el progreso en vivo de un prompt por WebSocket; cae a polling si falta ws.
Args:
prompt_id: id devuelto por comfyui_submit_workflow.
server: host:port del servidor ComfyUI sin esquema (default
"127.0.0.1:8188"). keyword-only.
client_id: clientId para registrar el socket en el servidor; si None se
genera un uuid4. keyword-only.
timeout: maximo de segundos a esperar a que el prompt complete.
keyword-only.
Returns:
dict con:
- ok (bool): True si el seguimiento concluyo sin error (incluido el
fallback a polling cuando completa).
- completed (bool): True si el prompt termino (success o senal de fin).
- steps_seen (int): numero de mensajes "progress" observados por WS
(0 en el fallback de polling y en trabajos cacheados, que no emiten
pasos intermedios).
- last_node (str): id del ultimo nodo en ejecucion visto.
- method (str): "websocket" o "polling" segun la via usada.
- error (str): mensaje de error; cadena vacia si todo OK.
"""
out = {
"ok": False,
"completed": False,
"steps_seen": 0,
"last_node": "",
"method": "websocket",
"error": "",
}
try:
from websocket import ( # type: ignore
WebSocketTimeoutException,
create_connection,
)
except ImportError:
return _fallback_polling(out, prompt_id, server, timeout)
import uuid
cid = client_id or uuid.uuid4().hex
ws_url = f"ws://{server}/ws?clientId={cid}"
deadline = time.time() + timeout
# Pre-check: el prompt pudo terminar antes de que conectemos (trabajos con nodos
# cacheados completan en <1s). Si ya esta en history, no hay nada que seguir.
done, last = _history_check(server, prompt_id)
if done:
out["completed"] = True
out["ok"] = True
if last:
out["last_node"] = last
return out
ws = None
try:
ws = create_connection(ws_url, timeout=min(timeout, 30.0))
except Exception as exc: # noqa: BLE001 — degradar a fallback, no romper
out["error"] = f"no se pudo abrir WS {ws_url}: {exc}"
return _fallback_polling(out, prompt_id, server, timeout)
try:
while time.time() < deadline:
ws.settimeout(min(2.0, max(0.1, deadline - time.time())))
try:
msg = ws.recv()
except WebSocketTimeoutException:
# Sin evento en la ventana: el fin pudo perderse (carrera con un
# trabajo rapido). Confirma por history antes de seguir esperando.
done, last = _history_check(server, prompt_id)
if done:
out["completed"] = True
out["ok"] = True
if last and not out["last_node"]:
out["last_node"] = last
break
continue
if isinstance(msg, (bytes, bytearray)):
continue # frames binarios = previews de imagen, se ignoran
try:
evt = json.loads(msg)
except (json.JSONDecodeError, TypeError):
continue
mtype = evt.get("type")
data = evt.get("data", {}) or {}
evt_pid = data.get("prompt_id")
if mtype == "progress":
out["steps_seen"] += 1
node = data.get("node")
if node is not None:
out["last_node"] = str(node)
elif mtype == "executing":
node = data.get("node")
if node is not None:
out["last_node"] = str(node)
elif evt_pid == prompt_id:
out["completed"] = True
out["ok"] = True
break
elif mtype == "execution_success" and evt_pid == prompt_id:
out["completed"] = True
out["ok"] = True
break
elif mtype == "execution_error" and evt_pid == prompt_id:
out["error"] = f"execution_error: {json.dumps(data)[:400]}"
break
else:
# deadline agotado sin fin por WS: ultimo check de history por si el
# evento de fin se perdio del todo.
done, last = _history_check(server, prompt_id)
if done:
out["completed"] = True
out["ok"] = True
if last and not out["last_node"]:
out["last_node"] = last
else:
out["error"] = f"timeout de {timeout}s sin fin para {prompt_id}"
finally:
try:
ws.close()
except Exception: # noqa: BLE001
pass
return out
def _history_check(server: str, prompt_id: str) -> tuple:
"""Consulta GET /history/{prompt_id} una vez (no bucle).
Devuelve (done, last_node): done=True solo si el prompt completo con exito;
last_node = id del ultimo nodo de output si lo hay. Errores de red se tratan
como "aun no" (False), no lanzan.
"""
import urllib.error
import urllib.request
url = f"http://{server}/history/{prompt_id}"
try:
with urllib.request.urlopen(url, timeout=10.0) as resp:
hist = json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, OSError, ValueError):
return (False, "")
entry = hist.get(prompt_id) if isinstance(hist, dict) else None
if not entry:
return (False, "")
status = entry.get("status", {})
if status.get("completed") or status.get("status_str") == "success":
outs = entry.get("outputs", {}) or {}
last = str(list(outs)[-1]) if outs else ""
return (True, last)
return (False, "")
def _fallback_polling(out: dict, prompt_id: str, server: str, timeout: float) -> dict:
"""Cae a polling de /history reutilizando comfyui_wait_result del registry."""
out["method"] = "polling"
try:
from comfyui_wait_result import comfyui_wait_result # hermano en ml/
except ImportError:
from ml.comfyui_wait_result import comfyui_wait_result # via python/functions
try:
outputs = comfyui_wait_result(prompt_id, server=server, timeout=timeout)
out["completed"] = True
out["ok"] = True
if outputs:
out["last_node"] = str(list(outputs)[-1])
except TimeoutError as exc:
out["error"] = f"polling: {exc}"
except RuntimeError as exc:
out["error"] = f"polling: {exc}"
return out
if __name__ == "__main__":
import sys
pid = sys.argv[1] if len(sys.argv) > 1 else ""
if not pid:
print("uso: comfyui_stream_progress.py <prompt_id>", file=sys.stderr)
sys.exit(2)
res = comfyui_stream_progress(pid)
print(json.dumps(res, indent=2))