chore: auto-commit (95 archivos)
- cmd/fn/doctor.go - cmd/fn/main.go - cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt - cpp/apps/primitives_gallery/playground/tables/data_table.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.h - cpp/apps/primitives_gallery/playground/tables/self_test.cpp - cpp/apps/primitives_gallery/playground/tables/tql.cpp - cpp/apps/primitives_gallery/playground/tables/viz.cpp - cpp/apps/primitives_gallery/playground/tables/viz.h - ... Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
"""Tests para cuda_available."""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
# Asegurar que el modulo ml es importable desde el path del registry
|
||||
sys.path.insert(0, "python/functions")
|
||||
|
||||
from ml.cuda_available import cuda_available
|
||||
|
||||
|
||||
class TestCudaAvailable(unittest.TestCase):
|
||||
|
||||
def test_claves_del_dict_siempre_presentes(self):
|
||||
"""claves del dict siempre presentes"""
|
||||
result = cuda_available()
|
||||
for key in ("available", "device_count", "devices", "torch_version", "cuda_version"):
|
||||
self.assertIn(key, result, f"Falta clave: {key}")
|
||||
|
||||
def test_sin_torch_retorna_available_False_y_torch_version_not_installed(self):
|
||||
"""sin torch retorna available=False y torch_version=not_installed"""
|
||||
with patch.dict(sys.modules, {"torch": None}):
|
||||
result = cuda_available()
|
||||
self.assertFalse(result["available"])
|
||||
self.assertEqual(result["torch_version"], "not_installed")
|
||||
self.assertEqual(result["device_count"], 0)
|
||||
self.assertEqual(result["devices"], [])
|
||||
self.assertIsNone(result["cuda_version"])
|
||||
|
||||
def test_con_torch_sin_cuda_retorna_available_False_y_device_count_0(self):
|
||||
"""con torch sin cuda retorna available=False y device_count=0"""
|
||||
import types
|
||||
fake_torch = types.ModuleType("torch")
|
||||
fake_torch.__version__ = "2.3.0"
|
||||
fake_torch.cuda = types.SimpleNamespace(
|
||||
is_available=lambda: False,
|
||||
device_count=lambda: 0,
|
||||
)
|
||||
fake_torch.version = types.SimpleNamespace(cuda=None)
|
||||
|
||||
with patch.dict(sys.modules, {"torch": fake_torch}):
|
||||
result = cuda_available()
|
||||
|
||||
self.assertFalse(result["available"])
|
||||
self.assertEqual(result["device_count"], 0)
|
||||
self.assertEqual(result["devices"], [])
|
||||
self.assertEqual(result["torch_version"], "2.3.0")
|
||||
self.assertIsNone(result["cuda_version"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,212 @@
|
||||
"""Tests para el backend diffusers: load_pipeline, set_scheduler, generate, unload."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
# Ajustar path para importar desde python/functions/ml/
|
||||
_ML_PATH = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
)
|
||||
sys.path.insert(0, os.path.abspath(_ML_PATH))
|
||||
|
||||
# Importaciones lazy de torch y diffusers — las omitimos si no estan disponibles.
|
||||
torch = pytest.importorskip("torch", reason="torch no instalado — skip tests diffusers")
|
||||
pytest.importorskip("diffusers", reason="diffusers no instalado — skip tests diffusers")
|
||||
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.generation_config import GenerationConfig
|
||||
from ml.image_gen_result import ImageGenResult
|
||||
from ml.diffusers_load_pipeline import diffusers_load_pipeline, _clear_pipeline_cache
|
||||
from ml.diffusers_set_scheduler import diffusers_set_scheduler
|
||||
from ml.diffusers_unload import diffusers_unload
|
||||
|
||||
# diffusers_generate importa image_gen_result sin prefijo de paquete.
|
||||
# Para evitar el double-import problem (ml.image_gen_result != image_gen_result),
|
||||
# forzamos que sys.modules["image_gen_result"] apunte al modulo ya cargado
|
||||
# como ml.image_gen_result antes de importar diffusers_generate.
|
||||
import sys as _sys
|
||||
import ml.image_gen_result as _igr_module
|
||||
import ml.generation_config as _gcfg_module
|
||||
import ml.genconfig_to_diffusers_kwargs as _gkwargs_module
|
||||
for _alias, _mod in [
|
||||
("image_gen_result", _igr_module),
|
||||
("generation_config", _gcfg_module),
|
||||
("genconfig_to_diffusers_kwargs", _gkwargs_module),
|
||||
]:
|
||||
if _alias not in _sys.modules:
|
||||
_sys.modules[_alias] = _mod
|
||||
|
||||
from ml.diffusers_generate import diffusers_generate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constantes
|
||||
# ---------------------------------------------------------------------------
|
||||
SD_TURBO_PATH = "/home/lucas/vaults/imagegen_models/diffusers/sd-turbo"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sd_turbo_model() -> ModelRef:
|
||||
"""ModelRef apuntando a SD Turbo local."""
|
||||
if not os.path.isdir(SD_TURBO_PATH):
|
||||
pytest.skip(f"SD Turbo no encontrado en {SD_TURBO_PATH}")
|
||||
return ModelRef(
|
||||
name="sd-turbo",
|
||||
model_type="sd15",
|
||||
quantization="fp16",
|
||||
path=SD_TURBO_PATH,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loaded_pipe(sd_turbo_model: ModelRef):
|
||||
"""Pipeline SD Turbo cargado una sola vez para toda la sesion de tests."""
|
||||
# Intentar fp16 primero; si falla (no hay variante fp16) usar fp32
|
||||
try:
|
||||
pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
except Exception:
|
||||
_clear_pipeline_cache()
|
||||
pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp32")
|
||||
yield pipe
|
||||
# Teardown: liberar al final de la sesion
|
||||
diffusers_unload(None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sd_turbo_cfg(sd_turbo_model: ModelRef) -> GenerationConfig:
|
||||
"""GenerationConfig minimo para SD Turbo (1 step, 512x512)."""
|
||||
return GenerationConfig(
|
||||
prompt="a simple red circle on white background",
|
||||
negative_prompt=None,
|
||||
seed=42,
|
||||
steps=1,
|
||||
cfg_scale=0.0,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model=sd_turbo_model,
|
||||
loras=[],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: carga pipeline y retorna callable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_load_pipeline_returns_callable(sd_turbo_model: ModelRef) -> None:
|
||||
"""carga pipeline y retorna callable"""
|
||||
_clear_pipeline_cache()
|
||||
pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
assert callable(pipe), "El pipeline debe ser callable"
|
||||
assert hasattr(pipe, "scheduler"), "El pipeline debe tener atributo scheduler"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: segunda carga usa cache (< 100ms)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_load_pipeline_caches(sd_turbo_model: ModelRef) -> None:
|
||||
"""segunda carga usa cache (< 100ms)"""
|
||||
# Primera carga (puede tardar varios segundos)
|
||||
_clear_pipeline_cache()
|
||||
_ = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
|
||||
# Segunda carga debe ser cache hit
|
||||
t0 = time.perf_counter()
|
||||
pipe2 = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
assert elapsed_ms < 100, (
|
||||
f"Segunda carga tardo {elapsed_ms:.1f}ms (esperado < 100ms — debe ser cache hit)"
|
||||
)
|
||||
assert pipe2 is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: set_scheduler cambia la clase del scheduler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_set_scheduler_changes_scheduler_class(loaded_pipe) -> None:
|
||||
"""euler cambia scheduler a EulerDiscreteScheduler"""
|
||||
pipe = diffusers_set_scheduler(loaded_pipe, "euler")
|
||||
scheduler_name = type(pipe.scheduler).__name__
|
||||
assert scheduler_name == "EulerDiscreteScheduler", (
|
||||
f"Esperado EulerDiscreteScheduler, obtenido {scheduler_name}"
|
||||
)
|
||||
|
||||
|
||||
def test_set_scheduler_euler_a(loaded_pipe) -> None:
|
||||
"""euler_a cambia scheduler a EulerAncestralDiscreteScheduler"""
|
||||
pipe = diffusers_set_scheduler(loaded_pipe, "euler_a")
|
||||
scheduler_name = type(pipe.scheduler).__name__
|
||||
assert scheduler_name == "EulerAncestralDiscreteScheduler", (
|
||||
f"Esperado EulerAncestralDiscreteScheduler, obtenido {scheduler_name}"
|
||||
)
|
||||
# Restaurar euler para no afectar otros tests
|
||||
diffusers_set_scheduler(loaded_pipe, "euler")
|
||||
|
||||
|
||||
def test_set_scheduler_invalid_raises_value_error(loaded_pipe) -> None:
|
||||
"""sampler invalido lanza ValueError"""
|
||||
with pytest.raises(ValueError, match="no soportado"):
|
||||
diffusers_set_scheduler(loaded_pipe, "nonexistent_sampler_xyz")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: genera imagen retorna ImageGenResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_generate_returns_image_gen_result(
|
||||
loaded_pipe, sd_turbo_cfg: GenerationConfig
|
||||
) -> None:
|
||||
"""genera imagen retorna ImageGenResult"""
|
||||
result = diffusers_generate(loaded_pipe, sd_turbo_cfg)
|
||||
|
||||
assert isinstance(result, ImageGenResult), (
|
||||
f"Esperado ImageGenResult, obtenido {type(result)}"
|
||||
)
|
||||
assert result.image is not None, "result.image no debe ser None"
|
||||
assert result.duration_ms > 0, (
|
||||
f"duration_ms debe ser positivo, obtenido {result.duration_ms}"
|
||||
)
|
||||
assert "backend" in result.meta, "meta debe tener key 'backend'"
|
||||
assert result.meta["backend"] == "diffusers", (
|
||||
f"meta['backend'] debe ser 'diffusers', obtenido {result.meta['backend']}"
|
||||
)
|
||||
assert "model" in result.meta, "meta debe tener key 'model'"
|
||||
|
||||
# Verificar que la imagen tiene las dimensiones correctas
|
||||
w, h = result.image.size
|
||||
assert w == sd_turbo_cfg.width and h == sd_turbo_cfg.height, (
|
||||
f"Imagen esperada {sd_turbo_cfg.width}x{sd_turbo_cfg.height}, "
|
||||
f"obtenida {w}x{h}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: unload limpia cache cuda si disponible
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_unload_clears_cuda() -> None:
|
||||
"""unload None limpia cache cuda si disponible"""
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
# Limpiar cache — no debe lanzar excepcion independientemente de si hay CUDA
|
||||
diffusers_unload(None)
|
||||
|
||||
if cuda_available:
|
||||
# Despues de empty_cache, la memoria reservada por el allocator baja
|
||||
# No podemos asumir que sea 0 (otros tensores pueden estar vivos),
|
||||
# pero la llamada debe completarse sin error.
|
||||
reserved = torch.cuda.memory_reserved()
|
||||
# Solo verificamos que no lanza excepcion y que la llamada completo
|
||||
assert reserved >= 0, "memory_reserved debe ser >= 0"
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Tests de roundtrip JSON para genconfig_save_json y genconfig_load_json."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.genconfig_save_json import genconfig_save_json
|
||||
from ml.genconfig_load_json import genconfig_load_json
|
||||
from ml.generation_config import GenerationConfig
|
||||
|
||||
|
||||
def _make_cfg(**overrides):
|
||||
"""Construye un GenerationConfig sintetico usando model_validate para evitar
|
||||
problemas de identidad de clase entre modulos pydantic separados."""
|
||||
defaults = dict(
|
||||
prompt="a forest at dusk",
|
||||
negative_prompt="blurry, low quality",
|
||||
seed=123,
|
||||
steps=25,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model={"name": "runwayml/stable-diffusion-v1-5", "model_type": "sd15"},
|
||||
loras=[{"path": "/loras/detail.safetensors", "weight": 0.7}],
|
||||
)
|
||||
defaults.update(overrides)
|
||||
try:
|
||||
return GenerationConfig.model_validate(defaults)
|
||||
except AttributeError:
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.lora_ref import LoraRef
|
||||
m = defaults.pop("model")
|
||||
if isinstance(m, dict):
|
||||
m = ModelRef(**m)
|
||||
loras = defaults.pop("loras", [])
|
||||
built = [LoraRef(**lr) if isinstance(lr, dict) else lr for lr in loras]
|
||||
return GenerationConfig(model=m, loras=tuple(built), **defaults)
|
||||
|
||||
|
||||
class TestGenconfigJsonRoundtrip(unittest.TestCase):
|
||||
|
||||
def test_save_escribe_archivo_json_valido_en_la_ruta_indicada(self):
|
||||
"""save escribe archivo JSON valido en la ruta indicada"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
saved = genconfig_save_json(cfg, path)
|
||||
self.assertTrue(os.path.isfile(saved))
|
||||
with open(saved, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
self.assertIsInstance(data, dict)
|
||||
self.assertEqual(data["prompt"], "a forest at dusk")
|
||||
|
||||
def test_save_crea_directorios_padre_si_no_existen(self):
|
||||
"""save crea directorios padre si no existen"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
nested = os.path.join(tmpdir, "a", "b", "c", "config.json")
|
||||
saved = genconfig_save_json(cfg, nested)
|
||||
self.assertTrue(os.path.isfile(saved))
|
||||
|
||||
def test_json_contiene_claves_en_snake_case(self):
|
||||
"""json contiene claves en snake_case"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
# Claves deben ser snake_case (interoperabilidad con Go)
|
||||
expected_keys = {
|
||||
"prompt", "negative_prompt", "seed", "steps",
|
||||
"cfg_scale", "sampler", "width", "height", "model",
|
||||
}
|
||||
for key in expected_keys:
|
||||
self.assertIn(key, data, f"Clave snake_case faltante: {key}")
|
||||
# No debe haber camelCase
|
||||
self.assertNotIn("negativePrompt", data)
|
||||
self.assertNotIn("cfgScale", data)
|
||||
self.assertNotIn("numInferenceSteps", data)
|
||||
|
||||
def test_roundtrip_preserva_campos_escalares(self):
|
||||
"""roundtrip save→load preserva campos escalares"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
self.assertEqual(loaded.prompt, cfg.prompt)
|
||||
self.assertEqual(loaded.negative_prompt, cfg.negative_prompt)
|
||||
self.assertEqual(loaded.seed, cfg.seed)
|
||||
self.assertEqual(loaded.steps, cfg.steps)
|
||||
self.assertAlmostEqual(loaded.cfg_scale, cfg.cfg_scale)
|
||||
self.assertEqual(loaded.sampler, cfg.sampler)
|
||||
self.assertEqual(loaded.width, cfg.width)
|
||||
self.assertEqual(loaded.height, cfg.height)
|
||||
|
||||
def test_roundtrip_preserva_model_ref(self):
|
||||
"""roundtrip preserva ModelRef"""
|
||||
cfg = _make_cfg(
|
||||
model={
|
||||
"name": "stabilityai/sdxl-base-1.0",
|
||||
"model_type": "sdxl",
|
||||
"quantization": "fp16",
|
||||
"path": "/models/sdxl.safetensors",
|
||||
}
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
self.assertEqual(loaded.model.name, "stabilityai/sdxl-base-1.0")
|
||||
self.assertEqual(loaded.model.model_type, "sdxl")
|
||||
self.assertEqual(loaded.model.quantization, "fp16")
|
||||
self.assertEqual(loaded.model.path, "/models/sdxl.safetensors")
|
||||
|
||||
def test_roundtrip_preserva_loras(self):
|
||||
"""roundtrip preserva lista de LoraRef"""
|
||||
cfg = _make_cfg(
|
||||
loras=[
|
||||
{"path": "/loras/a.safetensors", "weight": 0.8},
|
||||
{"path": "/loras/b.safetensors", "weight": 0.5},
|
||||
]
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
loaded_loras = list(loaded.loras)
|
||||
self.assertEqual(len(loaded_loras), 2)
|
||||
paths = [lr.path for lr in loaded_loras]
|
||||
self.assertIn("/loras/a.safetensors", paths)
|
||||
self.assertIn("/loras/b.safetensors", paths)
|
||||
|
||||
def test_roundtrip_negative_prompt_none(self):
|
||||
"""roundtrip con negative_prompt=None"""
|
||||
cfg = _make_cfg(negative_prompt=None)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
self.assertIsNone(loaded.negative_prompt)
|
||||
|
||||
def test_load_falla_con_file_not_found(self):
|
||||
"""load lanza FileNotFoundError si el archivo no existe"""
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
genconfig_load_json("/tmp/nonexistent_fn_registry_test_12345.json")
|
||||
|
||||
def test_save_retorna_path_absoluto(self):
|
||||
"""save retorna path absoluto aunque se pase path relativo"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
abs_path = os.path.join(tmpdir, "cfg.json")
|
||||
result = genconfig_save_json(cfg, abs_path)
|
||||
self.assertTrue(os.path.isabs(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Tests para genconfig_to_diffusers_kwargs."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.genconfig_to_diffusers_kwargs import genconfig_to_diffusers_kwargs
|
||||
from ml.generation_config import GenerationConfig
|
||||
|
||||
|
||||
def _make_cfg(**overrides):
|
||||
"""Crea un GenerationConfig sintetico para tests via model_validate / constructor."""
|
||||
defaults = dict(
|
||||
prompt="a dog in the park",
|
||||
seed=42,
|
||||
steps=30,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler_a",
|
||||
width=512,
|
||||
height=768,
|
||||
model={"name": "runwayml/stable-diffusion-v1-5", "model_type": "sd15"},
|
||||
)
|
||||
defaults.update(overrides)
|
||||
try:
|
||||
return GenerationConfig.model_validate(defaults)
|
||||
except AttributeError:
|
||||
# dataclass fallback: model y loras ya son dicts, construir manualmente
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.lora_ref import LoraRef
|
||||
m = defaults.pop("model")
|
||||
if isinstance(m, dict):
|
||||
m = ModelRef(**m)
|
||||
loras = defaults.pop("loras", [])
|
||||
built_loras = []
|
||||
for lr in loras:
|
||||
if isinstance(lr, dict):
|
||||
built_loras.append(LoraRef(**lr))
|
||||
else:
|
||||
built_loras.append(lr)
|
||||
return GenerationConfig(model=m, loras=tuple(built_loras), **defaults)
|
||||
|
||||
|
||||
class TestGenconfigToDiffusersKwargs(unittest.TestCase):
|
||||
|
||||
def test_kwargs_contiene_todas_las_claves_requeridas(self):
|
||||
"""kwargs contiene todas las claves requeridas"""
|
||||
cfg = _make_cfg()
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
required_keys = {
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"width",
|
||||
"height",
|
||||
"generator",
|
||||
}
|
||||
self.assertEqual(set(kwargs.keys()), required_keys)
|
||||
|
||||
def test_negative_prompt_none_se_pasa_tal_cual(self):
|
||||
"""negative_prompt None se pasa tal cual"""
|
||||
cfg = _make_cfg(negative_prompt=None)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertIsNone(kwargs["negative_prompt"])
|
||||
|
||||
def test_steps_y_cfg_scale_se_mapean_a_num_inference_steps_y_guidance_scale(self):
|
||||
"""steps y cfg_scale se mapean a num_inference_steps y guidance_scale"""
|
||||
cfg = _make_cfg(steps=20, cfg_scale=8.0)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["num_inference_steps"], 20)
|
||||
self.assertAlmostEqual(kwargs["guidance_scale"], 8.0)
|
||||
|
||||
def test_generator_siempre_es_none(self):
|
||||
"""generator siempre es None"""
|
||||
cfg = _make_cfg()
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertIsNone(kwargs["generator"])
|
||||
|
||||
def test_prompt_se_copia_sin_modificar(self):
|
||||
"""prompt se copia sin modificar"""
|
||||
cfg = _make_cfg(prompt="a cat on a roof")
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["prompt"], "a cat on a roof")
|
||||
|
||||
def test_width_y_height_se_preservan(self):
|
||||
"""width y height se preservan"""
|
||||
cfg = _make_cfg(width=1024, height=768)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["width"], 1024)
|
||||
self.assertEqual(kwargs["height"], 768)
|
||||
|
||||
def test_negative_prompt_string_se_pasa_tal_cual(self):
|
||||
"""negative_prompt string se pasa tal cual"""
|
||||
cfg = _make_cfg(negative_prompt="blurry, low quality")
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["negative_prompt"], "blurry, low quality")
|
||||
|
||||
def test_no_incluye_seed_sampler_ni_loras(self):
|
||||
"""no incluye seed sampler ni loras en el dict"""
|
||||
cfg = _make_cfg(
|
||||
loras=[{"path": "/loras/detail.safetensors", "weight": 0.8}]
|
||||
)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertNotIn("seed", kwargs)
|
||||
self.assertNotIn("sampler", kwargs)
|
||||
self.assertNotIn("loras", kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Tests para genconfig_to_sdcpp_args."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.genconfig_to_sdcpp_args import genconfig_to_sdcpp_args, _SAMPLER_MAP
|
||||
from ml.generation_config import GenerationConfig
|
||||
|
||||
|
||||
def _make_cfg(**overrides):
|
||||
"""Crea un GenerationConfig sintetico para tests via model_validate / constructor."""
|
||||
defaults = dict(
|
||||
prompt="a cat",
|
||||
seed=1,
|
||||
steps=20,
|
||||
cfg_scale=7.0,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model={"name": "v1-5-pruned.ckpt", "model_type": "sd15", "path": "/models/v1-5.ckpt"},
|
||||
)
|
||||
defaults.update(overrides)
|
||||
# Normalizar loras a dicts si fueron pasados como LoraRef
|
||||
if "loras" in defaults:
|
||||
normalized = []
|
||||
for lr in defaults["loras"]:
|
||||
if hasattr(lr, "__dict__") and not isinstance(lr, dict):
|
||||
normalized.append({"path": lr.path, "weight": lr.weight, "scale": lr.scale})
|
||||
else:
|
||||
normalized.append(lr)
|
||||
defaults["loras"] = normalized
|
||||
try:
|
||||
return GenerationConfig.model_validate(defaults)
|
||||
except AttributeError:
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.lora_ref import LoraRef
|
||||
m = defaults.pop("model")
|
||||
if isinstance(m, dict):
|
||||
m = ModelRef(**m)
|
||||
loras = defaults.pop("loras", [])
|
||||
built = [LoraRef(**lr) if isinstance(lr, dict) else lr for lr in loras]
|
||||
return GenerationConfig(model=m, loras=tuple(built), **defaults)
|
||||
|
||||
|
||||
def _get_flag_value(args: list[str], flag: str) -> str | None:
|
||||
"""Extrae el valor de un flag en la lista de args."""
|
||||
try:
|
||||
idx = args.index(flag)
|
||||
return args[idx + 1]
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
def _get_all_flag_values(args: list[str], flag: str) -> list[str]:
|
||||
"""Extrae todos los valores de un flag repetido (ej. --lora)."""
|
||||
values = []
|
||||
for i, arg in enumerate(args):
|
||||
if arg == flag and i + 1 < len(args):
|
||||
values.append(args[i + 1])
|
||||
return values
|
||||
|
||||
|
||||
class TestGenconfigToSdcppArgs(unittest.TestCase):
|
||||
|
||||
def test_sampler_euler_a_se_mapea_a_euler_a_en_el_flag_sampling_method(self):
|
||||
"""sampler euler_a se mapea a euler_a en el flag --sampling-method"""
|
||||
cfg = _make_cfg(sampler="euler_a")
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--sampling-method"), "euler_a")
|
||||
|
||||
def test_sampler_dpm_pp_2m_se_mapea_a_dpmpp2m(self):
|
||||
"""sampler dpm++2m se mapea a dpmpp2m"""
|
||||
cfg = _make_cfg(sampler="dpm++2m")
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--sampling-method"), "dpmpp2m")
|
||||
|
||||
def test_lora_con_path_y_weight_se_agrega_como_lora_path_weight(self):
|
||||
"""lora con path y weight se agrega como --lora path:weight"""
|
||||
cfg = _make_cfg(
|
||||
loras=[{"path": "/loras/detail.safetensors", "weight": 0.8}]
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
lora_values = _get_all_flag_values(args, "--lora")
|
||||
self.assertEqual(len(lora_values), 1)
|
||||
self.assertEqual(lora_values[0], "/loras/detail.safetensors:0.8")
|
||||
|
||||
def test_multiples_loras_generan_multiples_pares_lora(self):
|
||||
"""multiples loras generan multiples pares --lora"""
|
||||
cfg = _make_cfg(
|
||||
loras=[
|
||||
{"path": "/loras/a.safetensors", "weight": 0.5},
|
||||
{"path": "/loras/b.safetensors", "weight": 1.0},
|
||||
]
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
lora_values = _get_all_flag_values(args, "--lora")
|
||||
self.assertEqual(len(lora_values), 2)
|
||||
self.assertIn("/loras/a.safetensors:0.5", lora_values)
|
||||
self.assertIn("/loras/b.safetensors:1.0", lora_values)
|
||||
|
||||
def test_negative_prompt_none_produce_string_vacio_en_negative_prompt(self):
|
||||
"""negative_prompt None produce string vacio en --negative-prompt"""
|
||||
cfg = _make_cfg(negative_prompt=None)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--negative-prompt"), "")
|
||||
|
||||
def test_model_path_tiene_prioridad_sobre_model_name_en_m(self):
|
||||
"""model.path tiene prioridad sobre model.name en -m"""
|
||||
cfg = _make_cfg(
|
||||
model={"name": "hub-name", "model_type": "sd15", "path": "/local/path/model.ckpt"}
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "-m"), "/local/path/model.ckpt")
|
||||
|
||||
def test_sin_path_usa_model_name_en_m(self):
|
||||
"""sin path usa model.name en -m"""
|
||||
cfg = _make_cfg(
|
||||
model={"name": "runwayml/sd-v1-5", "model_type": "sd15", "path": None}
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "-m"), "runwayml/sd-v1-5")
|
||||
|
||||
def test_args_contiene_flags_obligatorios(self):
|
||||
"""args contiene --prompt --seed --steps --cfg-scale --sampling-method -W -H -m"""
|
||||
cfg = _make_cfg()
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
for flag in ["--prompt", "--seed", "--steps", "--cfg-scale", "--sampling-method", "-W", "-H", "-m"]:
|
||||
self.assertIn(flag, args, f"Flag faltante: {flag}")
|
||||
|
||||
def test_sampler_map_cubre_todos_los_samplers_canonicos(self):
|
||||
"""_SAMPLER_MAP cubre todos los samplers canonicos del dominio ml"""
|
||||
canonical = {"euler", "euler_a", "dpm++2m", "dpm++2m_v2", "heun", "dpm2", "lcm"}
|
||||
self.assertEqual(set(_SAMPLER_MAP.keys()), canonical)
|
||||
|
||||
def test_seed_steps_width_height_se_convierten_a_string(self):
|
||||
"""seed steps width height se convierten a string en los args"""
|
||||
cfg = _make_cfg(seed=42, steps=25, width=768, height=512)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--seed"), "42")
|
||||
self.assertEqual(_get_flag_value(args, "--steps"), "25")
|
||||
self.assertEqual(_get_flag_value(args, "-W"), "768")
|
||||
self.assertEqual(_get_flag_value(args, "-H"), "512")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Tests para GenerationConfig — serialización, roundtrip y frozen."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
# Añadir python/functions/ml al path para que los imports internos del módulo
|
||||
# (from lora_ref import LoraRef, from model_ref import ModelRef) funcionen.
|
||||
# Los módulos se importan directamente desde el subdirectorio para evitar
|
||||
# colisiones de tipos entre ml.generation_config.* y generation_config.*.
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
from generation_config import GenerationConfig
|
||||
from lora_ref import LoraRef
|
||||
from model_ref import ModelRef
|
||||
|
||||
|
||||
def _make_model() -> ModelRef:
|
||||
return ModelRef(name="stabilityai/stable-diffusion-v1-5", model_type="sd15")
|
||||
|
||||
|
||||
def _make_config() -> GenerationConfig:
|
||||
return GenerationConfig(
|
||||
prompt="a cat in the moonlight",
|
||||
negative_prompt="blurry, low quality",
|
||||
seed=42,
|
||||
steps=30,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler_a",
|
||||
width=512,
|
||||
height=512,
|
||||
model=_make_model(),
|
||||
loras=[],
|
||||
clip_skip=1,
|
||||
)
|
||||
|
||||
|
||||
def test_instancia_ok():
|
||||
"""GenerationConfig crea instancia sin errores"""
|
||||
cfg = _make_config()
|
||||
assert cfg.prompt == "a cat in the moonlight"
|
||||
assert cfg.seed == 42
|
||||
assert cfg.steps == 30
|
||||
assert cfg.cfg_scale == 7.5
|
||||
assert cfg.sampler == "euler_a"
|
||||
assert cfg.width == 512
|
||||
assert cfg.height == 512
|
||||
assert cfg.clip_skip == 1
|
||||
|
||||
|
||||
def test_model_dump_keys_snake_case():
|
||||
"""model_dump devuelve dict con keys snake_case incluyendo negative_prompt, cfg_scale, clip_skip"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
cfg = _make_config()
|
||||
d = cfg.model_dump()
|
||||
assert isinstance(d, dict)
|
||||
assert "negative_prompt" in d
|
||||
assert "cfg_scale" in d
|
||||
assert "clip_skip" in d
|
||||
assert d["negative_prompt"] == "blurry, low quality"
|
||||
assert d["cfg_scale"] == 7.5
|
||||
assert d["clip_skip"] == 1
|
||||
|
||||
|
||||
def test_model_dump_json_parseable():
|
||||
"""model_dump_json retorna str JSON parseable"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
cfg = _make_config()
|
||||
raw = cfg.model_dump_json()
|
||||
assert isinstance(raw, str)
|
||||
parsed = json.loads(raw)
|
||||
assert isinstance(parsed, dict)
|
||||
assert parsed["prompt"] == "a cat in the moonlight"
|
||||
assert parsed["seed"] == 42
|
||||
|
||||
|
||||
def test_roundtrip_model_validate():
|
||||
"""GenerationConfig.model_validate(json.loads(...)) roundtrip ok"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
cfg = _make_config()
|
||||
raw_json = cfg.model_dump_json()
|
||||
parsed = json.loads(raw_json)
|
||||
cfg2 = GenerationConfig.model_validate(parsed)
|
||||
assert cfg2.prompt == cfg.prompt
|
||||
assert cfg2.seed == cfg.seed
|
||||
assert cfg2.cfg_scale == cfg.cfg_scale
|
||||
assert cfg2.sampler == cfg.sampler
|
||||
assert cfg2.clip_skip == cfg.clip_skip
|
||||
assert cfg2.model.name == cfg.model.name
|
||||
assert cfg2.model.model_type == cfg.model.model_type
|
||||
|
||||
|
||||
def test_frozen_levanta_error_al_mutar():
|
||||
"""frozen: intentar mutar levanta AttributeError, ValidationError o FrozenInstanceError"""
|
||||
cfg = _make_config()
|
||||
raised = False
|
||||
try:
|
||||
# dataclass frozen y pydantic frozen levantan distintas excepciones
|
||||
cfg.prompt = "mutated" # type: ignore[misc]
|
||||
except Exception:
|
||||
raised = True
|
||||
|
||||
assert raised, "Se esperaba que mutar un campo frozen lanzara una excepcion"
|
||||
|
||||
|
||||
def test_negative_prompt_opcional():
|
||||
"""negative_prompt es opcional (default None)"""
|
||||
cfg = GenerationConfig(
|
||||
prompt="mountains",
|
||||
seed=0,
|
||||
steps=20,
|
||||
cfg_scale=7.0,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model=_make_model(),
|
||||
)
|
||||
assert cfg.negative_prompt is None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Tests para gpu_info."""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, "python/functions")
|
||||
|
||||
from ml.gpu_info import gpu_info
|
||||
|
||||
|
||||
class TestGpuInfo(unittest.TestCase):
|
||||
|
||||
def test_sin_nvidia_smi_devuelve_lista_vacia(self):
|
||||
"""sin nvidia-smi devuelve lista vacia"""
|
||||
with patch("subprocess.run", side_effect=FileNotFoundError()):
|
||||
result = gpu_info()
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_formato_CSV_correcto_devuelve_lista_con_un_dict_por_GPU(self):
|
||||
"""formato CSV correcto devuelve lista con un dict por GPU"""
|
||||
csv_output = " 0, NVIDIA RTX 4090, 24564, 22000, 535.183.01, 8.9\n"
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = csv_output
|
||||
with patch("subprocess.run", return_value=mock_result):
|
||||
result = gpu_info()
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["index"], 0)
|
||||
self.assertEqual(result[0]["name"], "NVIDIA RTX 4090")
|
||||
self.assertEqual(result[0]["vram_total_mb"], 24564)
|
||||
self.assertEqual(result[0]["vram_free_mb"], 22000)
|
||||
self.assertEqual(result[0]["driver_version"], "535.183.01")
|
||||
self.assertEqual(result[0]["cuda_version"], "8.9")
|
||||
|
||||
def test_fila_malformada_en_CSV_se_ignora_sin_excepcion(self):
|
||||
"""fila malformada en CSV se ignora sin excepcion"""
|
||||
csv_output = " 0, RTX 4090, NONNUMERIC, 22000, 535.183.01, 8.9\n"
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = csv_output
|
||||
with patch("subprocess.run", return_value=mock_result):
|
||||
result = gpu_info()
|
||||
self.assertEqual(result, [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tests para hf_snapshot_download — mockear snapshot_download y verificar args."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
# Saltar si huggingface_hub no esta disponible Y no podemos mockearlo
|
||||
# Usamos un mock inline para no requerir la lib real.
|
||||
# Si la lib esta disponible, monkeypatch la reemplaza. Si no, la inyectamos manualmente.
|
||||
|
||||
|
||||
def _inject_fake_hf_hub(monkeypatch, capture: list):
|
||||
"""Inyecta un modulo huggingface_hub falso con snapshot_download que captura kwargs."""
|
||||
|
||||
def fake_snapshot_download(**kwargs):
|
||||
capture.append(kwargs)
|
||||
return "/tmp/fake_snapshot"
|
||||
|
||||
fake_module = types.ModuleType("huggingface_hub")
|
||||
fake_module.snapshot_download = fake_snapshot_download
|
||||
monkeypatch.setitem(sys.modules, "huggingface_hub", fake_module)
|
||||
|
||||
|
||||
def test_args_minimos_repo_id(monkeypatch):
|
||||
"""repo_id se pasa correctamente a snapshot_download"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
result = hf_snapshot_download("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
assert len(capture) == 1
|
||||
assert capture[0]["repo_id"] == "runwayml/stable-diffusion-v1-5"
|
||||
assert result == "/tmp/fake_snapshot"
|
||||
|
||||
|
||||
def test_retorna_string(monkeypatch):
|
||||
"""hf_snapshot_download retorna un string (la ruta local)"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
result = hf_snapshot_download("some/repo")
|
||||
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_allow_patterns_se_pasa(monkeypatch):
|
||||
"""allow_patterns se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo", allow_patterns=["*.safetensors", "*.json"])
|
||||
|
||||
assert "allow_patterns" in capture[0]
|
||||
assert capture[0]["allow_patterns"] == ["*.safetensors", "*.json"]
|
||||
|
||||
|
||||
def test_ignore_patterns_se_pasa(monkeypatch):
|
||||
"""ignore_patterns se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo", ignore_patterns=["*.bin", "flax_*"])
|
||||
|
||||
assert "ignore_patterns" in capture[0]
|
||||
assert capture[0]["ignore_patterns"] == ["*.bin", "flax_*"]
|
||||
|
||||
|
||||
def test_local_dir_se_pasa(monkeypatch):
|
||||
"""local_dir se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo", local_dir="/models/sd15")
|
||||
|
||||
assert "local_dir" in capture[0]
|
||||
assert capture[0]["local_dir"] == "/models/sd15"
|
||||
|
||||
|
||||
def test_token_se_pasa(monkeypatch):
|
||||
"""token se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("private/model", token="hf_mytoken123")
|
||||
|
||||
assert "token" in capture[0]
|
||||
assert capture[0]["token"] == "hf_mytoken123"
|
||||
|
||||
|
||||
def test_none_args_no_se_pasan(monkeypatch):
|
||||
"""args opcionales None no se incluyen en kwargs (no contaminar snapshot_download)"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo")
|
||||
|
||||
kwargs = capture[0]
|
||||
# Solo repo_id debe estar presente — los None no se incluyen
|
||||
assert "allow_patterns" not in kwargs
|
||||
assert "ignore_patterns" not in kwargs
|
||||
assert "local_dir" not in kwargs
|
||||
assert "token" not in kwargs
|
||||
|
||||
|
||||
def test_import_error_sin_huggingface_hub(monkeypatch):
|
||||
"""ImportError descriptivo si huggingface_hub no esta instalado"""
|
||||
import importlib
|
||||
|
||||
# Inyectar None en sys.modules para simular libreria no instalada
|
||||
monkeypatch.setitem(sys.modules, "huggingface_hub", None)
|
||||
|
||||
# Recargar el modulo para que el try/except del top-level vea el None
|
||||
import hf_snapshot_download as _mod
|
||||
importlib.reload(_mod)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
with pytest.raises(ImportError, match="huggingface_hub"):
|
||||
hf_snapshot_download("any/repo")
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tests para image_compare_side_by_side."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
|
||||
PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping")
|
||||
|
||||
from PIL import Image
|
||||
from image_compare_side_by_side import image_compare_side_by_side
|
||||
|
||||
|
||||
def _black(w=16, h=16):
|
||||
return Image.new("RGB", (w, h), color=(0, 0, 0))
|
||||
|
||||
|
||||
def _white(w=16, h=16):
|
||||
return Image.new("RGB", (w, h), color=(255, 255, 255))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grid shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_grid_es_pil_image_con_dimensiones_correctas_show_diff_True():
|
||||
"""grid es PIL.Image con dimensiones correctas show_diff=True"""
|
||||
w, h = 16, 16
|
||||
gap = 16
|
||||
result = image_compare_side_by_side(_black(w, h), _white(w, h), gap_px=gap, show_diff=True)
|
||||
|
||||
grid = result["grid"]
|
||||
assert isinstance(grid, Image.Image), "grid debe ser PIL.Image.Image"
|
||||
|
||||
expected_w = 3 * w + 4 * gap # A + diff + B + 4 gaps
|
||||
expected_h = h + 2 * gap
|
||||
assert grid.size == (expected_w, expected_h), (
|
||||
f"Esperado ({expected_w}, {expected_h}), got {grid.size}"
|
||||
)
|
||||
|
||||
|
||||
def test_grid_es_pil_image_sin_diff_show_diff_False():
|
||||
"""grid es PIL.Image sin diff show_diff=False"""
|
||||
w, h = 16, 16
|
||||
gap = 8
|
||||
result = image_compare_side_by_side(_black(w, h), _white(w, h), gap_px=gap, show_diff=False)
|
||||
|
||||
grid = result["grid"]
|
||||
assert isinstance(grid, Image.Image), "grid debe ser PIL.Image.Image"
|
||||
|
||||
expected_w = 2 * w + 3 * gap # A + B + 3 gaps
|
||||
expected_h = h + 2 * gap
|
||||
assert grid.size == (expected_w, expected_h), (
|
||||
f"Esperado ({expected_w}, {expected_h}), got {grid.size}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MSE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_pixel_mse_positivo_para_imagenes_distintas():
|
||||
"""pixel_mse positivo para imagenes distintas"""
|
||||
result = image_compare_side_by_side(_black(), _white())
|
||||
mse = result["pixel_mse"]
|
||||
assert isinstance(mse, float), f"pixel_mse debe ser float, got {type(mse)}"
|
||||
assert mse > 0.0, f"pixel_mse debe ser > 0 para imagenes distintas, got {mse}"
|
||||
|
||||
|
||||
def test_pixel_mse_cero_para_imagen_identica():
|
||||
"""pixel_mse cero para imagen identica"""
|
||||
img = _black()
|
||||
result = image_compare_side_by_side(img, img.copy())
|
||||
mse = result["pixel_mse"]
|
||||
assert mse == 0.0, f"pixel_mse debe ser 0.0 para imagenes identicas, got {mse}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pHash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_phash_none_si_imagehash_no_disponible():
|
||||
"""phash None si imagehash no disponible"""
|
||||
try:
|
||||
import imagehash # noqa: F401
|
||||
pytest.skip("imagehash esta instalado — test de fallback no aplica")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
result = image_compare_side_by_side(_black(), _white(), show_phash=True)
|
||||
assert result["phash_a"] is None, "phash_a debe ser None si imagehash no instalado"
|
||||
assert result["phash_b"] is None, "phash_b debe ser None si imagehash no instalado"
|
||||
assert result["phash_distance"] is None, "phash_distance debe ser None si imagehash no instalado"
|
||||
|
||||
|
||||
def test_phash_presente_si_imagehash_disponible():
|
||||
"""phash presente si imagehash disponible"""
|
||||
try:
|
||||
import imagehash # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("imagehash no instalado")
|
||||
|
||||
result = image_compare_side_by_side(_black(), _white(), show_phash=True)
|
||||
assert isinstance(result["phash_a"], str), "phash_a debe ser str"
|
||||
assert isinstance(result["phash_b"], str), "phash_b debe ser str"
|
||||
assert isinstance(result["phash_distance"], int), "phash_distance debe ser int"
|
||||
assert len(result["phash_a"]) == 16, f"phash_a debe tener 16 hex chars, got {len(result['phash_a'])}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Campos del resultado
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_resultado_tiene_todas_las_claves():
|
||||
"""resultado tiene todas las claves esperadas"""
|
||||
result = image_compare_side_by_side(_black(), _white())
|
||||
for key in ("grid", "phash_a", "phash_b", "phash_distance", "pixel_mse"):
|
||||
assert key in result, f"Clave '{key}' faltante en resultado"
|
||||
|
||||
|
||||
def test_show_phash_false_deja_phash_none():
|
||||
"""show_phash=False deja phash* en None sin intentar import"""
|
||||
result = image_compare_side_by_side(_black(), _white(), show_phash=False)
|
||||
assert result["phash_a"] is None
|
||||
assert result["phash_b"] is None
|
||||
assert result["phash_distance"] is None
|
||||
@@ -0,0 +1,99 @@
|
||||
"""Tests para ImageGenResult — dump excluye image, meta viaja correctamente."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
from image_gen_result import ImageGenResult
|
||||
|
||||
|
||||
def _make_result(image=None, duration_ms=1234, vram_peak_mb=None, meta=None):
|
||||
if meta is None:
|
||||
meta = {
|
||||
"model": "sd15",
|
||||
"seed_used": 42,
|
||||
"sampler": "euler_a",
|
||||
"prompt": "a cat",
|
||||
}
|
||||
return ImageGenResult(
|
||||
image=image,
|
||||
meta=meta,
|
||||
duration_ms=duration_ms,
|
||||
vram_peak_mb=vram_peak_mb,
|
||||
)
|
||||
|
||||
|
||||
def test_instancia_ok():
|
||||
"""ImageGenResult crea instancia sin errores"""
|
||||
r = _make_result(duration_ms=500)
|
||||
assert r.duration_ms == 500
|
||||
assert isinstance(r.meta, dict)
|
||||
|
||||
|
||||
def test_dump_excluye_image():
|
||||
"""model_dump excluye el campo image automaticamente"""
|
||||
|
||||
class FakeImage:
|
||||
"""Objeto imagen simulado (no PIL real)."""
|
||||
pass
|
||||
|
||||
r = _make_result(image=FakeImage(), duration_ms=800)
|
||||
d = r.model_dump()
|
||||
assert isinstance(d, dict)
|
||||
assert "image" not in d, "image no debe aparecer en model_dump()"
|
||||
|
||||
|
||||
def test_dump_incluye_meta_duration_vram():
|
||||
"""model_dump incluye meta, duration_ms y vram_peak_mb"""
|
||||
meta = {"model": "sdxl", "seed_used": 99, "sampler": "dpm++2m"}
|
||||
r = _make_result(duration_ms=2000, vram_peak_mb=6144, meta=meta)
|
||||
d = r.model_dump()
|
||||
assert "meta" in d
|
||||
assert "duration_ms" in d
|
||||
assert "vram_peak_mb" in d
|
||||
assert d["duration_ms"] == 2000
|
||||
assert d["vram_peak_mb"] == 6144
|
||||
|
||||
|
||||
def test_meta_dict_viaja_completo():
|
||||
"""meta dict se conserva completo en model_dump"""
|
||||
meta = {
|
||||
"model": "flux_dev",
|
||||
"seed_used": 777,
|
||||
"sampler": "euler",
|
||||
"custom_key": "custom_value",
|
||||
"nested": {"a": 1},
|
||||
}
|
||||
r = _make_result(meta=meta)
|
||||
d = r.model_dump()
|
||||
assert d["meta"] == meta
|
||||
assert d["meta"]["custom_key"] == "custom_value"
|
||||
assert d["meta"]["nested"] == {"a": 1}
|
||||
|
||||
|
||||
def test_dump_json_parseable():
|
||||
"""model_dump_json retorna string JSON parseable sin image"""
|
||||
meta = {"model": "sd15", "seed_used": 1}
|
||||
r = _make_result(duration_ms=100, meta=meta)
|
||||
raw = r.model_dump_json()
|
||||
assert isinstance(raw, str)
|
||||
parsed = json.loads(raw)
|
||||
assert "meta" in parsed
|
||||
assert "duration_ms" in parsed
|
||||
assert "image" not in parsed
|
||||
|
||||
|
||||
def test_vram_peak_mb_none_serializa():
|
||||
"""vram_peak_mb=None se serializa correctamente a null"""
|
||||
r = _make_result(vram_peak_mb=None)
|
||||
d = r.model_dump()
|
||||
assert d["vram_peak_mb"] is None
|
||||
|
||||
|
||||
def test_image_none_permitido():
|
||||
"""image puede ser None (generacion fallida)"""
|
||||
r = _make_result(image=None)
|
||||
assert r.image is None
|
||||
d = r.model_dump()
|
||||
assert "image" not in d
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Tests para ImageGenerator Protocol — runtime_checkable y structural subtyping."""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
from image_gen_result import ImageGenResult
|
||||
from image_generator import ImageGenerator
|
||||
|
||||
|
||||
class MockGenerator:
|
||||
"""Implementacion dummy que satisface ImageGenerator sin herencia explicita."""
|
||||
|
||||
def generate(self, config):
|
||||
"""Retorna un ImageGenResult sin imagen real."""
|
||||
return ImageGenResult(
|
||||
image=None,
|
||||
meta={"model": "mock", "seed_used": 0, "sampler": "euler"},
|
||||
duration_ms=1,
|
||||
vram_peak_mb=None,
|
||||
)
|
||||
|
||||
|
||||
class NotAGenerator:
|
||||
"""Clase que NO implementa generate — no satisface el Protocol."""
|
||||
|
||||
def predict(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def test_dummy_satisface_protocol():
|
||||
"""clase dummy que implementa generate satisface isinstance(x, ImageGenerator)"""
|
||||
gen = MockGenerator()
|
||||
assert isinstance(gen, ImageGenerator), (
|
||||
"MockGenerator debe satisfacer ImageGenerator Protocol (runtime_checkable)"
|
||||
)
|
||||
|
||||
|
||||
def test_resultado_es_image_gen_result():
|
||||
"""generate() retorna ImageGenResult"""
|
||||
gen = MockGenerator()
|
||||
result = gen.generate(config=None)
|
||||
assert isinstance(result, ImageGenResult)
|
||||
|
||||
|
||||
def test_clase_sin_generate_no_satisface_protocol():
|
||||
"""clase sin metodo generate NO satisface isinstance check"""
|
||||
not_gen = NotAGenerator()
|
||||
assert not isinstance(not_gen, ImageGenerator), (
|
||||
"NotAGenerator no debe satisfacer ImageGenerator Protocol"
|
||||
)
|
||||
|
||||
|
||||
def test_multiples_instancias_satisfacen_protocol():
|
||||
"""multiples instancias del mismo dummy satisfacen el Protocol"""
|
||||
for _ in range(3):
|
||||
gen = MockGenerator()
|
||||
assert isinstance(gen, ImageGenerator)
|
||||
|
||||
|
||||
def test_lambda_con_callable_no_satisface_protocol():
|
||||
"""un callable lambda no satisface el Protocol (no tiene metodo .generate)"""
|
||||
|
||||
class LambdaLike:
|
||||
def __call__(self, config):
|
||||
return None
|
||||
|
||||
obj = LambdaLike()
|
||||
# __call__ no es lo mismo que .generate — no debe satisfacer el protocol
|
||||
assert not isinstance(obj, ImageGenerator)
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Tests para image_grid — combina imagenes en grid NxM."""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import math
|
||||
import pytest
|
||||
|
||||
PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping")
|
||||
|
||||
from image_grid import image_grid
|
||||
|
||||
|
||||
def _make_images(n: int, w: int = 16, h: int = 16):
|
||||
from PIL import Image
|
||||
return [Image.new("RGB", (w, h), color=(i * 10, i * 10, i * 10)) for i in range(n)]
|
||||
|
||||
|
||||
def test_grid_4_imagenes_2_cols_dimensiones_correctas():
|
||||
"""grid de 4 imagenes 16x16 cols=2 produce ancho/alto correcto"""
|
||||
images = _make_images(4, w=16, h=16)
|
||||
result = image_grid(images, cols=2, gap_px=0)
|
||||
|
||||
# rows = ceil(4/2) = 2
|
||||
# canvas_w = 2*16 + 3*0 = 32 (con gap_px=0: cols*w + (cols+1)*0)
|
||||
# canvas_h = 2*16 + 3*0 = 32
|
||||
assert result.width == 32, f"Ancho esperado 32, got {result.width}"
|
||||
assert result.height == 32, f"Alto esperado 32, got {result.height}"
|
||||
|
||||
|
||||
def test_grid_4_imagenes_2_cols_con_gap():
|
||||
"""grid de 4 imagenes cols=2 gap_px=8 tiene dimensiones correctas con gap"""
|
||||
images = _make_images(4, w=16, h=16)
|
||||
gap = 8
|
||||
cols = 2
|
||||
rows = math.ceil(4 / cols)
|
||||
expected_w = cols * 16 + (cols + 1) * gap
|
||||
expected_h = rows * 16 + (rows + 1) * gap
|
||||
|
||||
result = image_grid(images, cols=cols, gap_px=gap)
|
||||
assert result.width == expected_w, f"Ancho: expected {expected_w}, got {result.width}"
|
||||
assert result.height == expected_h, f"Alto: expected {expected_h}, got {result.height}"
|
||||
|
||||
|
||||
def test_grid_1_imagen_1_col():
|
||||
"""grid de 1 imagen 1 col = imagen sola mas gaps"""
|
||||
images = _make_images(1, w=32, h=32)
|
||||
result = image_grid(images, cols=1, gap_px=4)
|
||||
# rows=1, cols=1 → w = 1*32 + 2*4 = 40, h = 1*32 + 2*4 = 40
|
||||
assert result.width == 40
|
||||
assert result.height == 40
|
||||
|
||||
|
||||
def test_grid_retorna_imagen_rgb():
|
||||
"""el resultado es una imagen RGB"""
|
||||
from PIL import Image
|
||||
images = _make_images(2, w=8, h=8)
|
||||
result = image_grid(images, cols=2)
|
||||
assert isinstance(result, Image.Image)
|
||||
assert result.mode == "RGB"
|
||||
|
||||
|
||||
def test_grid_con_labels_no_falla():
|
||||
"""labels opcionales — no lanza excepcion"""
|
||||
images = _make_images(4, w=16, h=16)
|
||||
labels = ["a", "b", "c", "d"]
|
||||
result = image_grid(images, cols=2, labels=labels, gap_px=0)
|
||||
# Debe devolver imagen válida
|
||||
assert result.width > 0
|
||||
assert result.height > 0
|
||||
|
||||
|
||||
def test_grid_sin_labels_no_falla():
|
||||
"""sin labels funciona correctamente"""
|
||||
images = _make_images(3, w=16, h=16)
|
||||
result = image_grid(images, cols=3, labels=None, gap_px=0)
|
||||
assert result.width == 3 * 16
|
||||
assert result.height == 16 # 1 row
|
||||
|
||||
|
||||
def test_grid_lista_vacia_levanta_value_error():
|
||||
"""lista vacia levanta ValueError"""
|
||||
with pytest.raises(ValueError):
|
||||
image_grid([], cols=2)
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Tests para image_save_png — guarda PNG con metadata tEXt embebida."""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping")
|
||||
|
||||
from image_save_png import image_save_png
|
||||
|
||||
|
||||
def test_guarda_archivo_y_retorna_ruta_absoluta(tmp_path):
|
||||
"""crea imagen 8x8, guarda y retorna ruta absoluta"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8), color=(255, 0, 0))
|
||||
dest = str(tmp_path / "test.png")
|
||||
result = image_save_png(img, dest)
|
||||
|
||||
import os
|
||||
assert os.path.isfile(result), f"El archivo no existe: {result}"
|
||||
assert os.path.isabs(result), f"La ruta no es absoluta: {result}"
|
||||
|
||||
|
||||
def test_metadata_embebida_en_chunks_text(tmp_path):
|
||||
"""metadata se embebe en chunks tEXt y se puede releer con Image.text"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8), color=(0, 128, 0))
|
||||
dest = str(tmp_path / "with_meta.png")
|
||||
meta = {"prompt": "hi", "seed": "42"}
|
||||
image_save_png(img, dest, metadata=meta)
|
||||
|
||||
reopened = Image.open(dest)
|
||||
text_data = reopened.text # dict de chunks tEXt del PNG
|
||||
assert "prompt" in text_data, f"Falta clave 'prompt' en PNG text chunks: {text_data}"
|
||||
assert "seed" in text_data, f"Falta clave 'seed' en PNG text chunks: {text_data}"
|
||||
assert text_data["prompt"] == "hi"
|
||||
assert text_data["seed"] == "42"
|
||||
|
||||
|
||||
def test_sin_metadata_no_falla(tmp_path):
|
||||
"""sin metadata el PNG se guarda igualmente"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8))
|
||||
dest = str(tmp_path / "no_meta.png")
|
||||
result = image_save_png(img, dest, metadata=None)
|
||||
|
||||
import os
|
||||
assert os.path.isfile(result)
|
||||
|
||||
|
||||
def test_crea_directorio_padre(tmp_path):
|
||||
"""crea directorio padre si no existe"""
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
img = Image.new("RGB", (8, 8))
|
||||
dest = str(tmp_path / "subdir" / "deep" / "image.png")
|
||||
result = image_save_png(img, dest)
|
||||
assert os.path.isfile(result)
|
||||
|
||||
|
||||
def test_metadata_valores_numericos_se_convierten_a_str(tmp_path):
|
||||
"""valores numericos en metadata se convierten a str automaticamente"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8))
|
||||
dest = str(tmp_path / "numeric.png")
|
||||
meta = {"steps": 30, "cfg_scale": 7.5}
|
||||
image_save_png(img, dest, metadata=meta)
|
||||
|
||||
reopened = Image.open(dest)
|
||||
text_data = reopened.text
|
||||
assert "steps" in text_data
|
||||
assert "cfg_scale" in text_data
|
||||
assert text_data["steps"] == "30"
|
||||
assert text_data["cfg_scale"] == "7.5"
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Tests para ModelRef y LoraRef — instanciación, dump y validación."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
# Importar desde el subdirectorio ml directamente para evitar colisiones de tipos
|
||||
# entre ml.model_ref.ModelRef y model_ref.ModelRef en pydantic.
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
from lora_ref import LoraRef
|
||||
from model_ref import ModelRef
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ModelRef
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_model_ref_instancia_ok():
|
||||
"""ModelRef instancia sin errores"""
|
||||
m = ModelRef(name="stabilityai/sdxl-base-1.0", model_type="sdxl")
|
||||
assert m.name == "stabilityai/sdxl-base-1.0"
|
||||
assert m.model_type == "sdxl"
|
||||
|
||||
|
||||
def test_model_ref_quantization_default_fp16():
|
||||
"""quantization default es fp16"""
|
||||
m = ModelRef(name="runwayml/stable-diffusion-v1-5", model_type="sd15")
|
||||
assert m.quantization == "fp16"
|
||||
|
||||
|
||||
def test_model_ref_quantization_override():
|
||||
"""quantization se puede cambiar a otro valor válido"""
|
||||
m = ModelRef(name="some/model", model_type="flux_dev", quantization="bf16")
|
||||
assert m.quantization == "bf16"
|
||||
|
||||
|
||||
def test_model_ref_path_default_none():
|
||||
"""path es None por defecto"""
|
||||
m = ModelRef(name="some/model", model_type="sd15")
|
||||
assert m.path is None
|
||||
|
||||
|
||||
def test_model_ref_path_set():
|
||||
"""path se puede especificar"""
|
||||
m = ModelRef(name="some/model", model_type="sd15", path="/models/sd15.safetensors")
|
||||
assert m.path == "/models/sd15.safetensors"
|
||||
|
||||
|
||||
def test_model_ref_dump():
|
||||
"""model_dump devuelve dict con las claves esperadas"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
m = ModelRef(name="some/model", model_type="sdxl", quantization="q8_0")
|
||||
d = m.model_dump()
|
||||
assert isinstance(d, dict)
|
||||
assert d["name"] == "some/model"
|
||||
assert d["model_type"] == "sdxl"
|
||||
assert d["quantization"] == "q8_0"
|
||||
|
||||
|
||||
def test_model_ref_validate_roundtrip():
|
||||
"""roundtrip model_dump_json / model_validate ok"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
m = ModelRef(name="some/model", model_type="sd3", quantization="fp32")
|
||||
raw = json.loads(m.model_dump_json())
|
||||
m2 = ModelRef.model_validate(raw)
|
||||
assert m2.name == m.name
|
||||
assert m2.model_type == m.model_type
|
||||
assert m2.quantization == m.quantization
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoraRef
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_lora_ref_instancia_ok():
|
||||
"""LoraRef instancia con path obligatorio"""
|
||||
lr = LoraRef(path="/loras/anime.safetensors")
|
||||
assert lr.path == "/loras/anime.safetensors"
|
||||
|
||||
|
||||
def test_lora_ref_weight_default_1():
|
||||
"""LoraRef weight default es 1.0"""
|
||||
lr = LoraRef(path="/loras/style.safetensors")
|
||||
assert lr.weight == 1.0
|
||||
|
||||
|
||||
def test_lora_ref_weight_override():
|
||||
"""LoraRef weight se puede cambiar"""
|
||||
lr = LoraRef(path="/loras/style.safetensors", weight=0.7)
|
||||
assert lr.weight == 0.7
|
||||
|
||||
|
||||
def test_lora_ref_scale_default_none():
|
||||
"""LoraRef scale default es None"""
|
||||
lr = LoraRef(path="/loras/x.safetensors")
|
||||
assert lr.scale is None
|
||||
|
||||
|
||||
def test_lora_ref_dump():
|
||||
"""LoraRef model_dump devuelve dict con las claves esperadas"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
lr = LoraRef(path="/loras/x.safetensors", weight=0.8, scale=0.9)
|
||||
d = lr.model_dump()
|
||||
assert d["path"] == "/loras/x.safetensors"
|
||||
assert d["weight"] == 0.8
|
||||
assert d["scale"] == 0.9
|
||||
|
||||
|
||||
def test_lora_ref_validate_roundtrip():
|
||||
"""roundtrip dump / validate ok"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
lr = LoraRef(path="/loras/x.safetensors", weight=0.5)
|
||||
raw = json.loads(lr.model_dump_json())
|
||||
lr2 = LoraRef.model_validate(raw)
|
||||
assert lr2.path == lr.path
|
||||
assert lr2.weight == lr.weight
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Tests para safetensors_inspect — parseo de header sin dependencias externas."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
from safetensors_inspect import safetensors_inspect
|
||||
|
||||
|
||||
def _write_safetensors(path: str, header: dict, data: bytes = b"") -> None:
|
||||
"""Escribe un archivo safetensors mínimo siguiendo la spec oficial.
|
||||
|
||||
Spec: https://github.com/huggingface/safetensors#format
|
||||
- 8 bytes: uint64 little-endian con la longitud N del header JSON
|
||||
- N bytes: JSON del header
|
||||
- (opcional) bytes de datos de tensores
|
||||
"""
|
||||
header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
|
||||
header_len = len(header_bytes)
|
||||
with open(path, "wb") as f:
|
||||
f.write(struct.pack("<Q", header_len)) # uint64 LE
|
||||
f.write(header_bytes)
|
||||
f.write(data)
|
||||
|
||||
|
||||
def _make_minimal_header(n_tensors: int = 2) -> dict:
|
||||
"""Genera un header con n_tensors tensores sintéticos."""
|
||||
header = {
|
||||
"__metadata__": {"format": "pt", "creator": "test"},
|
||||
}
|
||||
for i in range(n_tensors):
|
||||
header[f"tensor_{i}"] = {
|
||||
"dtype": "F32",
|
||||
"shape": [4, 4],
|
||||
"data_offsets": [i * 64, (i + 1) * 64],
|
||||
}
|
||||
return header
|
||||
|
||||
|
||||
def test_n_tensors_correcto(tmp_path):
|
||||
"""n_tensors refleja el numero de tensores en el header"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
_write_safetensors(path, _make_minimal_header(n_tensors=3))
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["n_tensors"] == 3
|
||||
|
||||
|
||||
def test_total_size_bytes_correcto(tmp_path):
|
||||
"""total_size_bytes refleja el tamaño real del archivo"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
data = b"\x00" * 128 # 128 bytes de datos de tensor
|
||||
_write_safetensors(path, _make_minimal_header(2), data=data)
|
||||
|
||||
file_size = os.path.getsize(path)
|
||||
result = safetensors_inspect(path)
|
||||
assert result["total_size_bytes"] == file_size
|
||||
|
||||
|
||||
def test_metadata_campo_dunder_presente(tmp_path):
|
||||
"""metadata devuelve el contenido de __metadata__"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
header = {
|
||||
"__metadata__": {"format": "pt", "model_name": "test_model"},
|
||||
"weight": {"dtype": "BF16", "shape": [8], "data_offsets": [0, 16]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["metadata"] == {"format": "pt", "model_name": "test_model"}
|
||||
|
||||
|
||||
def test_tensors_lista_correcta(tmp_path):
|
||||
"""tensors es lista con una entrada por tensor del header"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
header = {
|
||||
"__metadata__": {},
|
||||
"embed.weight": {"dtype": "F16", "shape": [128, 64], "data_offsets": [0, 16384]},
|
||||
"proj.bias": {"dtype": "F32", "shape": [64], "data_offsets": [16384, 16640]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["n_tensors"] == 2
|
||||
names = {t["name"] for t in result["tensors"]}
|
||||
assert "embed.weight" in names
|
||||
assert "proj.bias" in names
|
||||
|
||||
|
||||
def test_tensor_campos_dtype_shape_offsets(tmp_path):
|
||||
"""cada tensor tiene dtype, shape y data_offsets"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
header = {
|
||||
"__metadata__": {},
|
||||
"my_tensor": {"dtype": "I32", "shape": [2, 3], "data_offsets": [0, 24]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
t = result["tensors"][0]
|
||||
assert t["dtype"] == "I32"
|
||||
assert t["shape"] == [2, 3]
|
||||
assert t["data_offsets"] == [0, 24]
|
||||
|
||||
|
||||
def test_path_absoluto_en_resultado(tmp_path):
|
||||
"""result['path'] es la ruta absoluta del archivo"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
_write_safetensors(path, _make_minimal_header(1))
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert os.path.isabs(result["path"])
|
||||
assert result["path"].endswith("model.safetensors")
|
||||
|
||||
|
||||
def test_archivo_no_encontrado_levanta_file_not_found(tmp_path):
|
||||
"""FileNotFoundError si el archivo no existe"""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
safetensors_inspect(str(tmp_path / "nonexistent.safetensors"))
|
||||
|
||||
|
||||
def test_header_invalido_levanta_value_error(tmp_path):
|
||||
"""ValueError si el header no es JSON válido"""
|
||||
path = str(tmp_path / "bad.safetensors")
|
||||
with open(path, "wb") as f:
|
||||
bad_header = b"NOT JSON!!"
|
||||
f.write(struct.pack("<Q", len(bad_header)))
|
||||
f.write(bad_header)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
safetensors_inspect(path)
|
||||
|
||||
|
||||
def test_archivo_vacio_levanta_value_error(tmp_path):
|
||||
"""ValueError si el archivo está vacío (< 8 bytes)"""
|
||||
path = str(tmp_path / "empty.safetensors")
|
||||
with open(path, "wb") as f:
|
||||
f.write(b"")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
safetensors_inspect(path)
|
||||
|
||||
|
||||
def test_sin_metadata_dunder(tmp_path):
|
||||
"""si no hay __metadata__ en el header, metadata retorna dict vacio"""
|
||||
path = str(tmp_path / "no_meta.safetensors")
|
||||
header = {
|
||||
"weight": {"dtype": "F32", "shape": [4], "data_offsets": [0, 16]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["metadata"] == {}
|
||||
assert result["n_tensors"] == 1
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Tests para torch_device_select."""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
sys.path.insert(0, "python/functions")
|
||||
|
||||
import ml.torch_device_select as tds_module
|
||||
from ml.torch_device_select import torch_device_select
|
||||
|
||||
|
||||
class TestTorchDeviceSelect(unittest.TestCase):
|
||||
|
||||
def _patch(self, cuda=False, mps=False, cuda_count=0):
|
||||
"""Helper: parchea los helpers internos del modulo."""
|
||||
from unittest.mock import patch
|
||||
return [
|
||||
patch.object(tds_module, "_cuda_available", return_value=cuda),
|
||||
patch.object(tds_module, "_mps_available", return_value=mps),
|
||||
patch.object(tds_module, "_cuda_device_count", return_value=cuda_count),
|
||||
]
|
||||
|
||||
def test_preference_cpu_siempre_retorna_cpu(self):
|
||||
"""preference=cpu siempre retorna cpu"""
|
||||
self.assertEqual(torch_device_select("cpu"), "cpu")
|
||||
|
||||
def test_preference_auto_sin_cuda_ni_mps_retorna_cpu(self):
|
||||
"""preference=auto sin cuda ni mps retorna cpu"""
|
||||
patches = self._patch(cuda=False, mps=False)
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
self.assertEqual(torch_device_select("auto"), "cpu")
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
def test_preference_cuda_sin_cuda_disponible_retorna_cpu_con_warning(self):
|
||||
"""preference=cuda sin cuda disponible retorna cpu con warning"""
|
||||
patches = self._patch(cuda=False, mps=False)
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = torch_device_select("cuda")
|
||||
self.assertEqual(result, "cpu")
|
||||
self.assertTrue(any("CUDA" in str(warning.message) for warning in w))
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
def test_preference_cuda_N_con_solo_1_GPU_retorna_cpu_con_warning(self):
|
||||
"""preference=cuda:5 con solo 1 GPU retorna cpu con warning"""
|
||||
patches = self._patch(cuda=True, cuda_count=1)
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = torch_device_select("cuda:5")
|
||||
self.assertEqual(result, "cpu")
|
||||
self.assertTrue(len(w) > 0)
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
def test_preference_desconocida_retorna_cpu_con_warning(self):
|
||||
"""preference desconocida retorna cpu con warning"""
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = torch_device_select("vulkan")
|
||||
self.assertEqual(result, "cpu")
|
||||
self.assertTrue(len(w) > 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests para vram_budget."""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Ajustar path para importar desde python/functions/
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.vram_budget import vram_budget
|
||||
|
||||
|
||||
def test_sdxl_fp16_fits_24gb():
|
||||
"""SDXL fp16 en una GPU de 24 GB deberia caber con headroom positivo."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=24576, # 24 GB
|
||||
model_type="sdxl",
|
||||
quantization="fp16",
|
||||
n_loras=0,
|
||||
width=1024,
|
||||
height=1024,
|
||||
batch_size=1,
|
||||
)
|
||||
assert result["required_mb"] > 0, "required_mb debe ser positivo"
|
||||
assert result["fits"] is True, f"SDXL fp16 debe caber en 24 GB, required={result['required_mb']} MB"
|
||||
assert result["headroom_mb"] > 0, f"headroom debe ser positivo, got {result['headroom_mb']}"
|
||||
assert result["warning"] is None, f"no debe haber warning, got: {result['warning']}"
|
||||
|
||||
|
||||
def test_flux_fp16_no_fits_8gb():
|
||||
"""Flux fp16 (~23 GB) no debe caber en una GPU de 8 GB."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=8192, # 8 GB
|
||||
model_type="flux_dev",
|
||||
quantization="fp16",
|
||||
n_loras=0,
|
||||
width=1024,
|
||||
height=1024,
|
||||
batch_size=1,
|
||||
)
|
||||
assert result["required_mb"] > 8192, f"Flux fp16 debe requerir mas de 8 GB, got {result['required_mb']} MB"
|
||||
assert result["fits"] is False, "Flux fp16 no debe caber en 8 GB"
|
||||
assert result["headroom_mb"] < 0, f"headroom debe ser negativo, got {result['headroom_mb']}"
|
||||
assert result["warning"] is not None, "debe haber warning con informacion de deficit"
|
||||
assert "+N MB" in result["warning"] or "+" in result["warning"], \
|
||||
f"warning debe indicar cuantos MB extra se necesitan: {result['warning']}"
|
||||
|
||||
|
||||
def test_lora_plus_quant_warning():
|
||||
"""LoRA con quantization q8_0 debe emitir warning de incompatibilidad."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=24576,
|
||||
model_type="sdxl",
|
||||
quantization="q8_0",
|
||||
n_loras=2,
|
||||
width=1024,
|
||||
height=1024,
|
||||
batch_size=1,
|
||||
)
|
||||
assert result["warning"] is not None, "debe haber warning por lora+quantization incompatible"
|
||||
assert "incompatible" in result["warning"].lower(), \
|
||||
f"warning debe mencionar incompatibilidad: {result['warning']}"
|
||||
assert "fp16" in result["warning"], \
|
||||
f"warning debe sugerir fp16: {result['warning']}"
|
||||
|
||||
|
||||
def test_unknown_combo():
|
||||
"""Combinacion (model_type, quant) desconocida debe retornar required_mb=-1 y warning."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=24576,
|
||||
model_type="modelo_inventado",
|
||||
quantization="q99_k",
|
||||
n_loras=0,
|
||||
)
|
||||
assert result["required_mb"] == -1, \
|
||||
f"required_mb debe ser -1 para combo desconocido, got {result['required_mb']}"
|
||||
assert result["fits"] is False, "fits debe ser False para combo desconocido"
|
||||
assert result["warning"] is not None, "debe haber warning para combo desconocido"
|
||||
assert "unknown" in result["warning"].lower(), \
|
||||
f"warning debe mencionar 'unknown': {result['warning']}"
|
||||
Reference in New Issue
Block a user