"""Estimador de VRAM requerida para modelos de generacion de imagen.""" # Base weights por (model_type, quantization) en MB. # Incluye pesos del modelo + overhead tipico del contexto de inferencia. _MODEL_VRAM_MB: dict[tuple[str, str], int] = { ("sd15", "fp16"): 2100, ("sd15", "q8_0"): 1200, ("sd15", "q4_0"): 700, ("sdxl", "fp16"): 6800, ("sdxl", "q8_0"): 3800, ("sdxl", "q4_0"): 2200, ("flux_dev", "fp16"): 23000, ("flux_dev", "q8_0"): 13000, ("flux_dev", "q4_0"): 7000, ("flux_schnell", "fp16"): 23000, ("flux_schnell", "q8_0"): 12500, ("flux_schnell", "q4_0"): 6500, ("sd3", "fp16"): 8500, ("sd3", "q8_0"): 4800, ("sd3", "q4_0"): 2800, ("qwen_image", "fp16"): 8000, ("qwen_image", "q8_0"): 4500, ("qwen_image", "q4_0"): 2600, } # MB por LoRA adicional (estimacion conservadora en fp16). _LORA_MB = 300 # Modelos que requieren overhead de latente mas alto (Flux usa bloques transformer mas grandes). _FLUX_MODELS = {"flux_dev", "flux_schnell"} # Quantizaciones que son incompatibles con LoRA en la mayoria de runtimes. _QUANT_LORA_INCOMPATIBLE = {"q8_0", "q4_0", "q4_k_m", "q5_k_m", "q6_k"} def _latent_overhead_mb(model_type: str, width: int, height: int, batch_size: int) -> int: """Estima el overhead de VRAM para activaciones y latentes en MB.""" pixels = width * height if model_type in _FLUX_MODELS: # Flux usa un espacio latente 16x mas comprimido pero con mas canales. overhead = pixels // 32 else: # SD 1.5 / SDXL / SD3: overhead aprox w*h/64 MB. overhead = pixels // 64 return overhead * batch_size def vram_budget( gpu_vram_total_mb: int, model_type: str, quantization: str, n_loras: int = 0, width: int = 1024, height: int = 1024, batch_size: int = 1, ) -> dict: """Estima la VRAM requerida para ejecutar un modelo de generacion de imagen. Usa heuristicas tabuladas por (model_type, quantization) mas overhead de latentes y LoRAs. No requiere GPU ni runtime — solo lookup y aritmetica. Args: gpu_vram_total_mb: VRAM total de la GPU en MB. model_type: Tipo de modelo. Valores: sd15, sdxl, flux_dev, flux_schnell, sd3, qwen_image. quantization: Esquema de cuantizacion. Valores: fp16, q8_0, q4_0, etc. n_loras: Numero de LoRAs a cargar simultaneamente (default 0). width: Ancho de la imagen a generar en pixeles (default 1024). height: Alto de la imagen a generar en pixeles (default 1024). batch_size: Numero de imagenes en paralelo (default 1). Returns: dict con: - required_mb (int): VRAM estimada necesaria en MB. -1 si combo desconocido. - fits (bool): True si required_mb <= gpu_vram_total_mb. - headroom_mb (int): MB sobrantes (negativo si no cabe). 0 si combo desconocido. - warning (str | None): Aviso sobre incompatibilidades o ajustes necesarios. None si no hay advertencias. """ key = (model_type, quantization) if key not in _MODEL_VRAM_MB: return { "required_mb": -1, "fits": False, "headroom_mb": 0, "warning": f"unknown model/quant combo: ({model_type!r}, {quantization!r})", } base_mb = _MODEL_VRAM_MB[key] latent_mb = _latent_overhead_mb(model_type, width, height, batch_size) lora_mb = n_loras * _LORA_MB required_mb = base_mb + latent_mb + lora_mb fits = required_mb <= gpu_vram_total_mb headroom_mb = gpu_vram_total_mb - required_mb warning: str | None = None # LoRA + quantization incompatible en la mayoria de runtimes. if n_loras > 0 and quantization in _QUANT_LORA_INCOMPATIBLE: warning = f"lora+quantization incompatible — usa fp16 para cargar LoRAs con {model_type}" elif not fits: deficit = required_mb - gpu_vram_total_mb warning = f"needs +{deficit} MB (required {required_mb} MB, available {gpu_vram_total_mb} MB)" return { "required_mb": required_mb, "fits": fits, "headroom_mb": headroom_mb, "warning": warning, }