Files
fn_registry/python/functions/ml/vram_budget.py
T
egutierrez e3c8979e8d chore: auto-commit (95 archivos)
- cmd/fn/doctor.go
- cmd/fn/main.go
- cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt
- cpp/apps/primitives_gallery/playground/tables/data_table.cpp
- cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp
- cpp/apps/primitives_gallery/playground/tables/data_table_logic.h
- cpp/apps/primitives_gallery/playground/tables/self_test.cpp
- cpp/apps/primitives_gallery/playground/tables/tql.cpp
- cpp/apps/primitives_gallery/playground/tables/viz.cpp
- cpp/apps/primitives_gallery/playground/tables/viz.h
- ...

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 00:50:34 +02:00

112 lines
4.1 KiB
Python

"""Estimate the VRAM needed to run image-generation models.

Pure lookup-and-arithmetic heuristics: no GPU or inference runtime is touched.
"""

# Baseline footprint in MB, keyed by (model_type, quantization). Each figure is
# meant to cover the model weights plus the typical overhead of the inference
# context. The table is written per model for readability and flattened into
# the tuple-keyed mapping that the rest of the module looks up.
_MODEL_VRAM_MB: dict[tuple[str, str], int] = {
    (model, quant): megabytes
    for model, per_quant in {
        "sd15": {"fp16": 2100, "q8_0": 1200, "q4_0": 700},
        "sdxl": {"fp16": 6800, "q8_0": 3800, "q4_0": 2200},
        "flux_dev": {"fp16": 23000, "q8_0": 13000, "q4_0": 7000},
        "flux_schnell": {"fp16": 23000, "q8_0": 12500, "q4_0": 6500},
        "sd3": {"fp16": 8500, "q8_0": 4800, "q4_0": 2800},
        "qwen_image": {"fp16": 8000, "q8_0": 4500, "q4_0": 2600},
    }.items()
    for quant, megabytes in per_quant.items()
}

# Extra MB charged per additional LoRA (a conservative fp16 estimate).
_LORA_MB = 300

# Model types charged a heavier latent/activation overhead (the original note
# attributes this to Flux's larger transformer blocks).
_FLUX_MODELS = {"flux_dev", "flux_schnell"}

# Quantization schemes treated as unable to load LoRAs ("in most runtimes",
# per the original note). q4_k_m, q5_k_m and q6_k have no row in
# _MODEL_VRAM_MB yet, so vram_budget rejects those combos as unknown before
# this set is ever consulted for them.
_QUANT_LORA_INCOMPATIBLE = {"q8_0", "q4_0", "q4_k_m", "q5_k_m", "q6_k"}
def _latent_overhead_mb(model_type: str, width: int, height: int, batch_size: int) -> int:
    """Return the heuristic activation/latent VRAM overhead in MB.

    The per-image charge is ``width * height`` floor-divided by a
    model-family divisor (32 for models in ``_FLUX_MODELS``, 64 for every
    other model type), then multiplied by ``batch_size``.

    NOTE(review): with these divisors a single 1024x1024 image is charged
    16384 MB (32768 MB for Flux), which dwarfs the per-model baselines. The
    ``/ 64`` looks like a latent *element* count (8x VAE downsampling per
    side) rather than megabytes. Behavior is kept as-is -- TODO confirm the
    intended calibration.
    """
    # Flux models are charged twice the per-pixel overhead of the SD family.
    divisor = 32 if model_type in _FLUX_MODELS else 64
    per_image = (width * height) // divisor
    return batch_size * per_image
def vram_budget(
    gpu_vram_total_mb: int,
    model_type: str,
    quantization: str,
    n_loras: int = 0,
    width: int = 1024,
    height: int = 1024,
    batch_size: int = 1,
) -> dict:
    """Estimate the VRAM needed to run an image-generation model.

    Uses tabulated per-(model_type, quantization) baselines plus a latent
    overhead and a per-LoRA charge. Needs no GPU or runtime -- only lookups
    and arithmetic.

    Args:
        gpu_vram_total_mb: Total GPU VRAM in MB.
        model_type: Model family. Tabulated values: sd15, sdxl, flux_dev,
            flux_schnell, sd3, qwen_image.
        quantization: Quantization scheme. Tabulated values: fp16, q8_0, q4_0.
        n_loras: Number of LoRAs loaded at the same time (default 0).
        width: Output image width in pixels (default 1024).
        height: Output image height in pixels (default 1024).
        batch_size: Number of images generated in parallel (default 1).

    Returns:
        dict with:
          - required_mb (int): estimated VRAM needed in MB; -1 for an
            unknown (model_type, quantization) combo.
          - fits (bool): True if required_mb <= gpu_vram_total_mb (always
            False for an unknown combo).
          - headroom_mb (int): gpu_vram_total_mb - required_mb (negative when
            it does not fit); 0 for an unknown combo.
          - warning (str | None): every applicable advisory joined with
            "; " -- the LoRA/quantization incompatibility first, then the
            VRAM deficit -- or None when nothing applies. For an unknown
            combo it names the unrecognized pair.
    """
    key = (model_type, quantization)
    # Table values are ints, never None, so None unambiguously means "unknown".
    base_mb = _MODEL_VRAM_MB.get(key)
    if base_mb is None:
        return {
            "required_mb": -1,
            "fits": False,
            "headroom_mb": 0,
            "warning": f"unknown model/quant combo: ({model_type!r}, {quantization!r})",
        }

    required_mb = (
        base_mb
        + _latent_overhead_mb(model_type, width, height, batch_size)
        + n_loras * _LORA_MB
    )
    fits = required_mb <= gpu_vram_total_mb
    headroom_mb = gpu_vram_total_mb - required_mb

    # Collect every advisory instead of stopping at the first. Previously, a
    # quantized model with LoRAs that also overflowed VRAM reported only the
    # LoRA issue and hid the deficit. The suggested switch to fp16 only makes
    # that deficit worse, since fp16 baselines are the largest in the table.
    advisories: list[str] = []
    if n_loras > 0 and quantization in _QUANT_LORA_INCOMPATIBLE:
        advisories.append(
            f"lora+quantization incompatible — usa fp16 para cargar LoRAs con {model_type}"
        )
    if not fits:
        deficit = required_mb - gpu_vram_total_mb
        advisories.append(
            f"needs +{deficit} MB (required {required_mb} MB, available {gpu_vram_total_mb} MB)"
        )

    return {
        "required_mb": required_mb,
        "fits": fits,
        "headroom_mb": headroom_mb,
        "warning": "; ".join(advisories) if advisories else None,
    }