"""Tests for vram_budget."""

import os
import sys

# Adjust the path so imports resolve from python/functions/.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from ml.vram_budget import vram_budget  # noqa: E402  (needs the path tweak above)


def test_sdxl_fp16_fits_24gb():
    """SDXL fp16 on a 24 GB GPU should fit with positive headroom."""
    result = vram_budget(
        gpu_vram_total_mb=24576,  # 24 GB
        model_type="sdxl",
        quantization="fp16",
        n_loras=0,
        width=1024,
        height=1024,
        batch_size=1,
    )

    assert result["required_mb"] > 0, "required_mb debe ser positivo"
    assert result["fits"] is True, (
        f"SDXL fp16 debe caber en 24 GB, required={result['required_mb']} MB"
    )
    assert result["headroom_mb"] > 0, (
        f"headroom debe ser positivo, got {result['headroom_mb']}"
    )
    assert result["warning"] is None, (
        f"no debe haber warning, got: {result['warning']}"
    )


def test_flux_fp16_no_fits_8gb():
    """Flux fp16 (~23 GB) must not fit on an 8 GB GPU."""
    result = vram_budget(
        gpu_vram_total_mb=8192,  # 8 GB
        model_type="flux_dev",
        quantization="fp16",
        n_loras=0,
        width=1024,
        height=1024,
        batch_size=1,
    )

    assert result["required_mb"] > 8192, (
        f"Flux fp16 debe requerir mas de 8 GB, got {result['required_mb']} MB"
    )
    assert result["fits"] is False, "Flux fp16 no debe caber en 8 GB"
    assert result["headroom_mb"] < 0, (
        f"headroom debe ser negativo, got {result['headroom_mb']}"
    )
    assert result["warning"] is not None, (
        "debe haber warning con informacion de deficit"
    )
    # The deficit is expected to be reported with a leading "+". A separate
    # check for the literal "+N MB" would be redundant: any string containing
    # it also contains "+", so a single membership test is equivalent.
    assert "+" in result["warning"], (
        f"warning debe indicar cuantos MB extra se necesitan: {result['warning']}"
    )


def test_lora_plus_quant_warning():
    """LoRA combined with q8_0 quantization must emit an incompatibility warning."""
    result = vram_budget(
        gpu_vram_total_mb=24576,
        model_type="sdxl",
        quantization="q8_0",
        n_loras=2,
        width=1024,
        height=1024,
        batch_size=1,
    )

    assert result["warning"] is not None, (
        "debe haber warning por lora+quantization incompatible"
    )
    # Case-insensitive match on the keyword; the fp16 suggestion is matched
    # case-sensitively, exactly as written.
    assert "incompatible" in result["warning"].lower(), (
        f"warning debe mencionar incompatibilidad: {result['warning']}"
    )
    assert "fp16" in result["warning"], (
        f"warning debe sugerir fp16: {result['warning']}"
    )


def test_unknown_combo():
    """An unknown (model_type, quant) combination must return required_mb=-1 and a warning."""
    # width / height / batch_size are intentionally omitted so the function's
    # own defaults are used.
    result = vram_budget(
        gpu_vram_total_mb=24576,
        model_type="modelo_inventado",
        quantization="q99_k",
        n_loras=0,
    )

    assert result["required_mb"] == -1, (
        f"required_mb debe ser -1 para combo desconocido, got {result['required_mb']}"
    )
    assert result["fits"] is False, "fits debe ser False para combo desconocido"
    assert result["warning"] is not None, (
        "debe haber warning para combo desconocido"
    )
    assert "unknown" in result["warning"].lower(), (
        f"warning debe mencionar 'unknown': {result['warning']}"
    )