"""diffusers_load_pipeline — carga un pipeline diffusers con cache global por (model, dtype, device).""" from __future__ import annotations import sys import os import time from typing import Any sys.path.insert(0, os.path.dirname(__file__)) from model_ref import ModelRef from torch_device_select import torch_device_select # Cache global: (model_key, dtype, device) -> pipeline object _PIPELINE_CACHE: dict[tuple[str, str, str], Any] = {} def _get_model_key(model: ModelRef) -> str: """Retorna la clave de cache para un ModelRef.""" return model.path if model.path else model.name def diffusers_load_pipeline( model: ModelRef, device: str = "auto", dtype: str = "fp16", ) -> Any: """Carga un pipeline diffusers con cache global por (model_key, dtype, device). Usa AutoPipelineForText2Image.from_pretrained con torch_dtype=torch.float16 y variant="fp16" por defecto. Hace pipe.to(device) tras la carga. Los pipelines se cachean en memoria — segunda llamada con los mismos parametros retorna el objeto cacheado sin recargar el modelo del disco. Args: model: Referencia al modelo. model.path se usa si esta presente; si no, model.name se pasa directo a from_pretrained (HF hub). device: Preferencia de device. 'auto' delega a torch_device_select (CUDA > MPS > CPU). Ejemplos: 'auto', 'cuda', 'cuda:0', 'cpu'. dtype: Precision del modelo. 'fp16' usa torch.float16 + variant="fp16". 'fp32' usa torch.float32 sin variant. 'bf16' usa torch.bfloat16. Returns: Objeto pipeline diffusers cargado y movido al device seleccionado. El tipo concreto depende del modelo (StableDiffusionPipeline, StableDiffusionXLPipeline, etc.) pero siempre es callable via pipe(...). Raises: ImportError: Si torch o diffusers no estan instalados. OSError: Si el path del modelo no existe o el nombre del hub es invalido. """ try: import torch from diffusers import AutoPipelineForText2Image except ImportError as exc: raise ImportError( "diffusers_load_pipeline requiere torch y diffusers. " "Instalar con: pip install torch diffusers" ) from exc resolved_device = torch_device_select(device) model_key = _get_model_key(model) cache_key = (model_key, dtype, resolved_device) if cache_key in _PIPELINE_CACHE: return _PIPELINE_CACHE[cache_key] load_path = model.path if model.path else model.name if dtype == "fp16": torch_dtype = torch.float16 pipe = AutoPipelineForText2Image.from_pretrained( load_path, torch_dtype=torch_dtype, variant="fp16", ) elif dtype == "bf16": torch_dtype = torch.bfloat16 pipe = AutoPipelineForText2Image.from_pretrained( load_path, torch_dtype=torch_dtype, ) elif dtype == "fp32": torch_dtype = torch.float32 pipe = AutoPipelineForText2Image.from_pretrained( load_path, torch_dtype=torch_dtype, ) else: raise ValueError( f"dtype '{dtype}' no soportado. Usar 'fp16', 'bf16' o 'fp32'." ) pipe = pipe.to(resolved_device) _PIPELINE_CACHE[cache_key] = pipe return pipe def _clear_pipeline_cache() -> None: """Limpia el cache global de pipelines (uso interno y tests).""" _PIPELINE_CACHE.clear()