Files
fn_registry/python/functions/ml/diffusers_load_pipeline.py
egutierrez a802f59f55 chore: auto-commit (95 archivos)
- cmd/fn/doctor.go
- cmd/fn/main.go
- cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt
- cpp/apps/primitives_gallery/playground/tables/data_table.cpp
- cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp
- cpp/apps/primitives_gallery/playground/tables/data_table_logic.h
- cpp/apps/primitives_gallery/playground/tables/self_test.cpp
- cpp/apps/primitives_gallery/playground/tables/tql.cpp
- cpp/apps/primitives_gallery/playground/tables/viz.cpp
- cpp/apps/primitives_gallery/playground/tables/viz.h
- ...

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 00:50:34 +02:00

103 lines
3.4 KiB
Python

"""diffusers_load_pipeline — carga un pipeline diffusers con cache global por (model, dtype, device)."""
from __future__ import annotations
import sys
import os
import time
from typing import Any
sys.path.insert(0, os.path.dirname(__file__))
from model_ref import ModelRef
from torch_device_select import torch_device_select
# Process-wide memo of loaded pipelines.
# Key: (model_key, dtype name, resolved device) -> loaded pipeline object.
_PIPELINE_CACHE: dict[tuple[str, str, str], Any] = dict()
def _get_model_key(model: ModelRef) -> str:
    """Return the cache key for *model*: ``model.path`` when truthy, else ``model.name``."""
    return model.path or model.name
def diffusers_load_pipeline(
    model: ModelRef,
    device: str = "auto",
    dtype: str = "fp16",
) -> Any:
    """Load a diffusers text-to-image pipeline, memoized per (model_key, dtype, device).

    Uses ``AutoPipelineForText2Image.from_pretrained`` with the torch dtype
    selected by ``dtype``, then moves the pipeline to the resolved device.
    Loaded pipelines are stored in the module-level ``_PIPELINE_CACHE``.
    A second call with the same model key, dtype and resolved device returns
    the cached object without reloading the model from disk.

    Args:
        model: Model reference. ``model.path`` is used when truthy; otherwise
            ``model.name`` is passed to ``from_pretrained`` (e.g. a HF hub id).
        device: Device preference, resolved through ``torch_device_select``.
            Examples: 'auto', 'cuda', 'cuda:0', 'cpu'.
        dtype: Model precision.
            'fp16' -> torch.float16. The ``variant="fp16"`` weight files are
            tried first; if that variant cannot be loaded, the default weight
            files are loaded and cast instead.
            'bf16' -> torch.bfloat16 (no variant).
            'fp32' -> torch.float32 (no variant).

    Returns:
        The loaded pipeline object, moved to the resolved device. Its concrete
        class depends on the model.

    Raises:
        ImportError: If torch or diffusers is not installed.
        ValueError: If ``dtype`` is not one of 'fp16', 'bf16', 'fp32'.
        OSError: Propagated from ``from_pretrained`` when the model cannot be
            loaded (e.g. missing local path or invalid hub name).
    """
    try:
        import torch
        from diffusers import AutoPipelineForText2Image
    except ImportError as exc:
        raise ImportError(
            "diffusers_load_pipeline requiere torch y diffusers. "
            "Instalar con: pip install torch diffusers"
        ) from exc

    # dtype name -> (torch dtype, weight-file variant to request or None).
    dtype_options = {
        "fp16": (torch.float16, "fp16"),
        "bf16": (torch.bfloat16, None),
        "fp32": (torch.float32, None),
    }
    if dtype not in dtype_options:
        raise ValueError(
            f"dtype '{dtype}' no soportado. Usar 'fp16', 'bf16' o 'fp32'."
        )
    torch_dtype, variant = dtype_options[dtype]

    resolved_device = torch_device_select(device)
    # The cache key and the from_pretrained source are the same string, so
    # derive it once through the shared helper.
    model_key = _get_model_key(model)
    cache_key = (model_key, dtype, resolved_device)
    if cache_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[cache_key]

    if variant is None:
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_key,
            torch_dtype=torch_dtype,
        )
    else:
        try:
            pipe = AutoPipelineForText2Image.from_pretrained(
                model_key,
                torch_dtype=torch_dtype,
                variant=variant,
            )
        except (OSError, ValueError):
            # Many checkpoints (notably local folders) ship only the default
            # weight files, with no "fp16" variant. Load those and let
            # torch_dtype do the cast. If the model itself is missing, this
            # second call raises too, so real load errors still surface.
            pipe = AutoPipelineForText2Image.from_pretrained(
                model_key,
                torch_dtype=torch_dtype,
            )

    pipe = pipe.to(resolved_device)
    _PIPELINE_CACHE[cache_key] = pipe
    return pipe
def _clear_pipeline_cache() -> None:
    """Empty the module-level pipeline cache in place (internal helper, used by tests).

    After this, ``diffusers_load_pipeline`` finds no cached entry and loads
    the pipeline again. Pipelines still referenced elsewhere are not affected.
    """
    cache = _PIPELINE_CACHE
    cache.clear()