chore: auto-commit (95 archivos)
- cmd/fn/doctor.go - cmd/fn/main.go - cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt - cpp/apps/primitives_gallery/playground/tables/data_table.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.h - cpp/apps/primitives_gallery/playground/tables/self_test.cpp - cpp/apps/primitives_gallery/playground/tables/tql.cpp - cpp/apps/primitives_gallery/playground/tables/viz.cpp - cpp/apps/primitives_gallery/playground/tables/viz.h - ... Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,161 @@
|
||||
"""Tests para vault_csv_profile."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from vault_csv_profile import vault_csv_profile
|
||||
|
||||
|
||||
def _make_vault(tmp: Path) -> tuple[Path, Path]:
|
||||
"""Crea un vault mínimo con vault_index.db y tabla files + files_fts + csv_profiles."""
|
||||
db = tmp / "vault_index.db"
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
rowid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
rel_path TEXT UNIQUE NOT NULL,
|
||||
size_bytes INTEGER,
|
||||
ext TEXT
|
||||
);
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS files_fts
|
||||
USING fts5(rel_path, content_text, content='', contentless_delete=1);
|
||||
CREATE TABLE IF NOT EXISTS csv_profiles (
|
||||
rel_path TEXT PRIMARY KEY,
|
||||
cols_json TEXT,
|
||||
n_rows INTEGER,
|
||||
encoding TEXT,
|
||||
date_min TEXT,
|
||||
date_max TEXT,
|
||||
profiled_at INTEGER
|
||||
);
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return tmp, db
|
||||
|
||||
|
||||
def _insert_file_entry(db: Path, rel_path: str):
|
||||
"""Inserta entrada en files para que files_fts tenga rowid válido."""
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO files(rel_path, size_bytes, ext) VALUES (?, 0, '.csv')",
|
||||
(rel_path,),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_csv_basic(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "data/basic.csv"
|
||||
csv_file = vault / rel
|
||||
csv_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_file.write_text("nombre,edad,score\nAna,30,9.5\nBob,25,8.0\nCarla,35,7.5\n", encoding="utf-8")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_csv_profile(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["rel_path"] == rel
|
||||
assert result["n_rows"] == 3
|
||||
assert len(result["cols"]) == 3
|
||||
col_names = [c["name"] for c in result["cols"]]
|
||||
assert "nombre" in col_names
|
||||
assert "edad" in col_names
|
||||
assert "score" in col_names
|
||||
assert result["persisted"] is True
|
||||
|
||||
# Verificar persistencia en csv_profiles
|
||||
conn = sqlite3.connect(str(db))
|
||||
row = conn.execute("SELECT n_rows FROM csv_profiles WHERE rel_path = ?", (rel,)).fetchone()
|
||||
conn.close()
|
||||
assert row is not None
|
||||
assert row[0] == 3
|
||||
|
||||
|
||||
def test_csv_date_detection(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "data/fechas.csv"
|
||||
csv_file = vault / rel
|
||||
csv_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_file.write_text(
|
||||
"fecha,valor\n2023-01-01,100\n2023-06-15,200\n2023-12-31,300\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_csv_profile(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["date_min"] is not None
|
||||
assert result["date_max"] is not None
|
||||
assert result["date_min"] <= "2023-01-01"
|
||||
assert result["date_max"] >= "2023-12-31"
|
||||
|
||||
|
||||
def test_csv_encoding_latin1(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "data/tildes.csv"
|
||||
csv_file = vault / rel
|
||||
csv_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_file.write_bytes(
|
||||
"ciudad,poblacion\nMálaga,500000\nCórdoba,320000\n".encode("latin-1")
|
||||
)
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_csv_profile(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["n_rows"] == 2
|
||||
assert result["encoding"] != "utf-8?"
|
||||
# encoding detectado (algún valor no vacío)
|
||||
assert result["encoding"]
|
||||
assert result["persisted"] is True
|
||||
|
||||
|
||||
def test_csv_empty(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "data/empty.csv"
|
||||
csv_file = vault / rel
|
||||
csv_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_file.write_text("", encoding="utf-8")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_csv_profile(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["n_rows"] == 0
|
||||
assert result["cols"] == []
|
||||
assert result["date_min"] is None
|
||||
assert result["date_max"] is None
|
||||
|
||||
|
||||
def test_csv_persists_fts(tmp_path):
|
||||
"""FTS5 contentless: verifica que las columnas son buscables con MATCH."""
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "data/fts_test.csv"
|
||||
csv_file = vault / rel
|
||||
csv_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_file.write_text("producto,precio\nManzana,1.5\nPera,2.0\n", encoding="utf-8")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
vault_csv_profile(str(vault), rel, db_path=str(db))
|
||||
|
||||
conn = sqlite3.connect(str(db))
|
||||
# FTS5 contentless no permite SELECT directo — usar MATCH para verificar indexado
|
||||
row_prod = conn.execute(
|
||||
"SELECT rowid FROM files_fts WHERE files_fts MATCH 'producto'",
|
||||
).fetchone()
|
||||
row_prec = conn.execute(
|
||||
"SELECT rowid FROM files_fts WHERE files_fts MATCH 'precio'",
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
assert row_prod is not None, "FTS no encontró 'producto'"
|
||||
assert row_prec is not None, "FTS no encontró 'precio'"
|
||||
@@ -0,0 +1,147 @@
|
||||
"""Tests para vault_pdf_extract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from vault_pdf_extract import vault_pdf_extract
|
||||
|
||||
|
||||
def _make_vault(tmp: Path) -> tuple[Path, Path]:
|
||||
"""Crea un vault mínimo con vault_index.db."""
|
||||
db = tmp / "vault_index.db"
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
rowid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
rel_path TEXT UNIQUE NOT NULL,
|
||||
size_bytes INTEGER,
|
||||
ext TEXT
|
||||
);
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS files_fts
|
||||
USING fts5(rel_path, content_text, content='', contentless_delete=1);
|
||||
CREATE TABLE IF NOT EXISTS pdf_extracts (
|
||||
rel_path TEXT PRIMARY KEY,
|
||||
page_count INTEGER,
|
||||
text_len INTEGER,
|
||||
extracted_to TEXT,
|
||||
extracted_at INTEGER
|
||||
);
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return tmp, db
|
||||
|
||||
|
||||
def _insert_file_entry(db: Path, rel_path: str):
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO files(rel_path, size_bytes, ext) VALUES (?, 0, '.pdf')",
|
||||
(rel_path,),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def _make_pdf(path: Path, text: str = "Hello vault PDF.\nPage two content."):
|
||||
"""Crea un PDF mínimo con fitz para tests."""
|
||||
import fitz
|
||||
|
||||
doc = fitz.open()
|
||||
page = doc.new_page()
|
||||
page.insert_text((72, 72), text)
|
||||
doc.save(str(path))
|
||||
doc.close()
|
||||
|
||||
|
||||
def test_pdf_extract_basic(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/test.pdf"
|
||||
pdf = vault / rel
|
||||
pdf.parent.mkdir(parents=True, exist_ok=True)
|
||||
_make_pdf(pdf)
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_pdf_extract(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["rel_path"] == rel
|
||||
assert result["page_count"] >= 1
|
||||
assert result["text_len"] > 0
|
||||
assert result["persisted"] is True
|
||||
|
||||
conn = sqlite3.connect(str(db))
|
||||
row = conn.execute("SELECT page_count, text_len FROM pdf_extracts WHERE rel_path=?", (rel,)).fetchone()
|
||||
conn.close()
|
||||
assert row is not None
|
||||
assert row[0] >= 1
|
||||
assert row[1] > 0
|
||||
|
||||
|
||||
def test_pdf_dump_text_creates_file(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/dump.pdf"
|
||||
pdf = vault / rel
|
||||
pdf.parent.mkdir(parents=True, exist_ok=True)
|
||||
_make_pdf(pdf, "Contenido para dump a disco.")
|
||||
_insert_file_entry(db, rel)
|
||||
# Crear data/processed/ para que se use ese directorio
|
||||
(vault / "data" / "processed").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result = vault_pdf_extract(str(vault), rel, db_path=str(db), dump_text=True)
|
||||
|
||||
assert result["extracted_to"] is not None
|
||||
txt_path = vault / result["extracted_to"]
|
||||
assert txt_path.exists()
|
||||
assert txt_path.stat().st_size > 0
|
||||
|
||||
|
||||
def test_pdf_no_dump(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/nodump.pdf"
|
||||
pdf = vault / rel
|
||||
pdf.parent.mkdir(parents=True, exist_ok=True)
|
||||
_make_pdf(pdf, "No se debe volcar a disco.")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_pdf_extract(str(vault), rel, db_path=str(db), dump_text=False)
|
||||
|
||||
assert result["extracted_to"] is None
|
||||
|
||||
|
||||
def test_pdf_persists_to_fts(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/fts.pdf"
|
||||
pdf = vault / rel
|
||||
pdf.parent.mkdir(parents=True, exist_ok=True)
|
||||
_make_pdf(pdf, "Texto especial para FTS xyzpdftest.")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
vault_pdf_extract(str(vault), rel, db_path=str(db), dump_text=False)
|
||||
|
||||
conn = sqlite3.connect(str(db))
|
||||
# FTS5 contentless: no permite SELECT directo, usar MATCH
|
||||
row = conn.execute(
|
||||
"SELECT rowid FROM files_fts WHERE files_fts MATCH 'xyzpdftest'",
|
||||
).fetchone()
|
||||
conn.close()
|
||||
assert row is not None, "FTS no encontró el texto del PDF"
|
||||
|
||||
|
||||
def test_pdf_corrupt_errors(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/corrupt.pdf"
|
||||
pdf = vault / rel
|
||||
pdf.parent.mkdir(parents=True, exist_ok=True)
|
||||
pdf.write_bytes(b"%PDF-1.4 garbage bytes \x00\x01\x02 not a real pdf")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
with pytest.raises(RuntimeError, match="corrupto|inválido|PDF"):
|
||||
vault_pdf_extract(str(vault), rel, db_path=str(db))
|
||||
@@ -0,0 +1,61 @@
|
||||
---
|
||||
name: vault_csv_profile
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def vault_csv_profile(vault_path: str, rel_path: str, db_path: str | None = None) -> dict"
|
||||
description: "Perfila un CSV del vault: detecta encoding, lee schema con polars, extrae n_rows y columnas de fecha; persiste en csv_profiles y actualiza files_fts para búsqueda por contenido."
|
||||
tags: [vault, csv, profiling, polars, encoding, datascience, fts]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [sqlite3, time, pathlib, json, polars, chardet]
|
||||
params:
|
||||
- name: vault_path
|
||||
desc: "Ruta absoluta a la raiz del vault donde vive el CSV y vault_index.db."
|
||||
- name: rel_path
|
||||
desc: "Ruta relativa al CSV dentro del vault (ej. 'data/raw/ventas.csv')."
|
||||
- name: db_path
|
||||
desc: "Override opcional de la ruta a vault_index.db. Por defecto <vault_path>/vault_index.db."
|
||||
output: "Dict con: rel_path (str), cols (list de {name, dtype}), n_rows (int), encoding (str), date_min/date_max (ISO yyyy-mm-dd o None), persisted (bool)."
|
||||
tested: true
|
||||
tests:
|
||||
- "test_csv_basic"
|
||||
- "test_csv_date_detection"
|
||||
- "test_csv_encoding_latin1"
|
||||
- "test_csv_empty"
|
||||
- "test_csv_persists_fts"
|
||||
test_file_path: "python/functions/datascience/tests/test_vault_csv_profile.py"
|
||||
file_path: "python/functions/datascience/vault_csv_profile.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from vault_csv_profile import vault_csv_profile
|
||||
|
||||
result = vault_csv_profile("/vaults/mi_vault", "data/raw/ventas.csv")
|
||||
# {
|
||||
# "rel_path": "data/raw/ventas.csv",
|
||||
# "cols": [{"name": "fecha", "dtype": "String"}, {"name": "importe", "dtype": "Float64"}],
|
||||
# "n_rows": 1500,
|
||||
# "encoding": "utf-8",
|
||||
# "date_min": "2023-01-01",
|
||||
# "date_max": "2023-12-31",
|
||||
# "persisted": True
|
||||
# }
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Usa polars (lazy scan) como motor principal; pandas como fallback.
|
||||
- Detección de encoding: chardet con confianza >= 0.6, luego intentos utf-8-sig → utf-8 → latin-1 → cp1252.
|
||||
- Detección de fechas: columnas Date/Datetime nativas de polars, o columnas String con ≥80% de valores parseables como fecha.
|
||||
- El FTS text incluye nombres de columnas + primeras 5 filas concatenadas.
|
||||
- Upsert en csv_profiles por rel_path; el rowid de files_fts se ancla al rowid de la tabla files para que vault_search funcione correctamente.
|
||||
- Si vault_index.db no existe, la función retorna el dict sin intentar persistir (persisted=False).
|
||||
- Dependencias: polars, chardet (ambas instaladas en python/.venv con uv add).
|
||||
@@ -0,0 +1,216 @@
|
||||
"""vault_csv_profile — Perfila un CSV del vault y persiste metadata en vault_index.db."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _detect_encoding(path: Path) -> str:
|
||||
"""Detecta encoding del archivo con chardet o por intentos."""
|
||||
try:
|
||||
import chardet
|
||||
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read(min(65536, path.stat().st_size))
|
||||
result = chardet.detect(raw)
|
||||
if result and result.get("encoding") and result.get("confidence", 0) >= 0.6:
|
||||
return result["encoding"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for enc in ("utf-8-sig", "utf-8", "latin-1", "cp1252"):
|
||||
try:
|
||||
with open(path, encoding=enc) as f:
|
||||
f.read(4096)
|
||||
return enc
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
|
||||
return "utf-8?"
|
||||
|
||||
|
||||
def _read_with_polars(path: Path, encoding: str) -> tuple[list[dict], int]:
|
||||
"""Lee CSV con polars. Retorna (cols, n_rows)."""
|
||||
import polars as pl
|
||||
|
||||
enc = encoding.rstrip("?").replace("utf-8-sig", "utf8").replace("utf-8", "utf8")
|
||||
if enc not in ("utf8", "utf-8"):
|
||||
enc = "utf8"
|
||||
|
||||
lf = pl.scan_csv(path, encoding="utf8", ignore_errors=True, infer_schema_length=1000)
|
||||
schema = lf.collect_schema()
|
||||
cols = [{"name": name, "dtype": str(dtype)} for name, dtype in schema.items()]
|
||||
n_rows = lf.select(pl.len()).collect().item()
|
||||
return cols, n_rows
|
||||
|
||||
|
||||
def _read_with_pandas(path: Path, encoding: str) -> tuple[list[dict], int]:
|
||||
"""Fallback: lee CSV con pandas."""
|
||||
import pandas as pd
|
||||
|
||||
enc = encoding.rstrip("?") or "utf-8"
|
||||
df = pd.read_csv(path, encoding=enc, encoding_errors="replace", nrows=None)
|
||||
cols = [{"name": col, "dtype": str(df[col].dtype)} for col in df.columns]
|
||||
n_rows = len(df)
|
||||
return cols, n_rows
|
||||
|
||||
|
||||
def _detect_dates(path: Path, encoding: str) -> tuple[str | None, str | None]:
|
||||
"""Intenta detectar columna de fecha y retorna (date_min, date_max) en ISO."""
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
lf = pl.scan_csv(path, encoding="utf8", ignore_errors=True, infer_schema_length=0)
|
||||
schema = lf.collect_schema()
|
||||
df = lf.collect()
|
||||
|
||||
for col_name, dtype in schema.items():
|
||||
if "Date" in str(dtype) or "Datetime" in str(dtype):
|
||||
series = df[col_name].drop_nulls()
|
||||
if len(series) > 0:
|
||||
mn = series.min()
|
||||
mx = series.max()
|
||||
return str(mn)[:10], str(mx)[:10]
|
||||
|
||||
# Intenta parsear columnas string como fecha
|
||||
for col_name, dtype in schema.items():
|
||||
if "Utf8" not in str(dtype) and "String" not in str(dtype):
|
||||
continue
|
||||
series = df[col_name].drop_nulls()
|
||||
if len(series) == 0:
|
||||
continue
|
||||
try:
|
||||
parsed = series.str.to_date(strict=False)
|
||||
valid = parsed.drop_nulls()
|
||||
if len(valid) / max(len(series), 1) >= 0.8:
|
||||
mn = valid.min()
|
||||
mx = valid.max()
|
||||
return str(mn)[:10], str(mx)[:10]
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
return None, None
|
||||
|
||||
|
||||
def _build_fts_text(path: Path, cols: list[dict], encoding: str) -> str:
|
||||
"""Construye content_text para files_fts: nombres de cols + primeras 5 filas."""
|
||||
col_names = " ".join(c["name"] for c in cols)
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
lf = pl.scan_csv(path, encoding="utf8", ignore_errors=True)
|
||||
sample = lf.head(5).collect()
|
||||
rows_text = " ".join(
|
||||
" ".join(str(v) for v in row) for row in sample.iter_rows()
|
||||
)
|
||||
return f"{col_names} {rows_text}".strip()
|
||||
except Exception:
|
||||
pass
|
||||
return col_names
|
||||
|
||||
|
||||
def vault_csv_profile(
|
||||
vault_path: str,
|
||||
rel_path: str,
|
||||
db_path: str | None = None,
|
||||
) -> dict:
|
||||
"""Perfila un CSV del vault: schema, n_rows, encoding, fechas; persiste en vault_index.db.
|
||||
|
||||
Args:
|
||||
vault_path: Ruta absoluta a la raiz del vault.
|
||||
rel_path: Ruta relativa al CSV dentro del vault.
|
||||
db_path: Override de la ruta a vault_index.db. Por defecto <vault_path>/vault_index.db.
|
||||
|
||||
Returns:
|
||||
Dict con rel_path, cols, n_rows, encoding, date_min, date_max, persisted.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Si el archivo no existe o no se puede leer.
|
||||
"""
|
||||
vault = Path(vault_path)
|
||||
csv_file = vault / rel_path
|
||||
if not csv_file.exists():
|
||||
raise RuntimeError(f"vault_csv_profile: archivo no encontrado: {csv_file}")
|
||||
|
||||
db = Path(db_path) if db_path else vault / "vault_index.db"
|
||||
|
||||
# Resultado por defecto para CSV vacío
|
||||
result: dict = {
|
||||
"rel_path": rel_path,
|
||||
"cols": [],
|
||||
"n_rows": 0,
|
||||
"encoding": "utf-8",
|
||||
"date_min": None,
|
||||
"date_max": None,
|
||||
"persisted": False,
|
||||
}
|
||||
|
||||
# Detectar encoding
|
||||
encoding = _detect_encoding(csv_file)
|
||||
result["encoding"] = encoding
|
||||
|
||||
# Leer schema y n_rows — short-circuit para archivos vacíos
|
||||
if csv_file.stat().st_size == 0:
|
||||
cols, n_rows = [], 0
|
||||
else:
|
||||
try:
|
||||
cols, n_rows = _read_with_polars(csv_file, encoding)
|
||||
except Exception:
|
||||
try:
|
||||
cols, n_rows = _read_with_pandas(csv_file, encoding)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"vault_csv_profile: no se pudo leer {rel_path}: {exc}") from exc
|
||||
|
||||
result["cols"] = cols
|
||||
result["n_rows"] = n_rows
|
||||
|
||||
# Detección de fechas (solo si hay filas)
|
||||
if n_rows > 0 and cols:
|
||||
date_min, date_max = _detect_dates(csv_file, encoding)
|
||||
result["date_min"] = date_min
|
||||
result["date_max"] = date_max
|
||||
|
||||
# Construir texto para FTS
|
||||
fts_text = _build_fts_text(csv_file, cols, encoding) if cols else ""
|
||||
|
||||
# Persistir en vault_index.db
|
||||
if db.exists():
|
||||
conn = sqlite3.connect(str(db))
|
||||
try:
|
||||
cols_json = __import__("json").dumps(cols)
|
||||
now = int(time.time())
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO csv_profiles(rel_path, cols_json, n_rows, encoding, date_min, date_max, profiled_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(rel_path) DO UPDATE SET
|
||||
cols_json=excluded.cols_json,
|
||||
n_rows=excluded.n_rows,
|
||||
encoding=excluded.encoding,
|
||||
date_min=excluded.date_min,
|
||||
date_max=excluded.date_max,
|
||||
profiled_at=excluded.profiled_at
|
||||
""",
|
||||
(rel_path, cols_json, n_rows, encoding, result["date_min"], result["date_max"], now),
|
||||
)
|
||||
# Actualizar files_fts (rowid debe coincidir con files)
|
||||
conn.execute("DELETE FROM files_fts WHERE rel_path = ?", (rel_path,))
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO files_fts(rowid, rel_path, content_text)
|
||||
VALUES ((SELECT rowid FROM files WHERE rel_path = ?), ?, ?)
|
||||
""",
|
||||
(rel_path, rel_path, fts_text),
|
||||
)
|
||||
conn.commit()
|
||||
result["persisted"] = True
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,60 @@
|
||||
---
|
||||
name: vault_pdf_extract
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def vault_pdf_extract(vault_path: str, rel_path: str, db_path: str | None = None, dump_text: bool = True) -> dict"
|
||||
description: "Extrae texto de un PDF del vault con PyMuPDF; persiste page_count y text_len en pdf_extracts; vuelca texto a .txt en data/processed/ o .vault_extracts/; actualiza files_fts para búsqueda por contenido."
|
||||
tags: [vault, pdf, extract, pymupdf, fts, datascience]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [sqlite3, time, pathlib, fitz]
|
||||
params:
|
||||
- name: vault_path
|
||||
desc: "Ruta absoluta a la raiz del vault donde vive el PDF y vault_index.db."
|
||||
- name: rel_path
|
||||
desc: "Ruta relativa al PDF dentro del vault (ej. 'docs/informe.pdf')."
|
||||
- name: db_path
|
||||
desc: "Override opcional de la ruta a vault_index.db. Por defecto <vault_path>/vault_index.db."
|
||||
- name: dump_text
|
||||
desc: "Si True (default), escribe el texto extraído a un .txt. La carpeta destino es data/processed/ si existe, si no .vault_extracts/."
|
||||
output: "Dict con: rel_path (str), page_count (int), text_len (int), extracted_to (ruta relativa al .txt o None), persisted (bool)."
|
||||
tested: true
|
||||
tests:
|
||||
- "test_pdf_extract_basic"
|
||||
- "test_pdf_dump_text_creates_file"
|
||||
- "test_pdf_no_dump"
|
||||
- "test_pdf_persists_to_fts"
|
||||
- "test_pdf_corrupt_errors"
|
||||
test_file_path: "python/functions/datascience/tests/test_vault_pdf_extract.py"
|
||||
file_path: "python/functions/datascience/vault_pdf_extract.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from vault_pdf_extract import vault_pdf_extract
|
||||
|
||||
result = vault_pdf_extract("/vaults/mi_vault", "docs/informe_anual.pdf")
|
||||
# {
|
||||
# "rel_path": "docs/informe_anual.pdf",
|
||||
# "page_count": 24,
|
||||
# "text_len": 45210,
|
||||
# "extracted_to": "data/processed/informe_anual.txt",
|
||||
# "persisted": True
|
||||
# }
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Requiere PyMuPDF (paquete `pymupdf`, importado como `fitz`). Ya instalado en python/.venv.
|
||||
- El texto se trunca a 10 MB antes de insertarlo en files_fts para evitar tablas FTS5 masivas.
|
||||
- Layout de volcado: si `<vault_path>/data/processed/` existe, se usa; si no, se crea `<vault_path>/.vault_extracts/`.
|
||||
- PDFs corruptos levantan RuntimeError con mensaje descriptivo.
|
||||
- El rowid de files_fts se ancla al rowid de la tabla files (subquery) para que vault_search funcione correctamente.
|
||||
- Si vault_index.db no existe, retorna el dict sin intentar persistir (persisted=False).
|
||||
@@ -0,0 +1,121 @@
|
||||
"""vault_pdf_extract — Extrae texto de un PDF del vault y persiste en vault_index.db."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def vault_pdf_extract(
|
||||
vault_path: str,
|
||||
rel_path: str,
|
||||
db_path: str | None = None,
|
||||
dump_text: bool = True,
|
||||
) -> dict:
|
||||
"""Extrae texto de un PDF del vault; persiste page_count, text_len y actualiza files_fts.
|
||||
|
||||
Args:
|
||||
vault_path: Ruta absoluta a la raiz del vault.
|
||||
rel_path: Ruta relativa al PDF dentro del vault.
|
||||
db_path: Override opcional de la ruta a vault_index.db.
|
||||
dump_text: Si True, escribe el texto extraído a un .txt en data/processed/ o .vault_extracts/.
|
||||
|
||||
Returns:
|
||||
Dict con: rel_path, page_count, text_len, extracted_to (ruta relativa o None), persisted.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Si el PDF no existe, está corrupto o no se puede leer.
|
||||
"""
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"vault_pdf_extract requiere PyMuPDF. Instalar con: uv add pymupdf"
|
||||
) from exc
|
||||
|
||||
vault = Path(vault_path)
|
||||
pdf_file = vault / rel_path
|
||||
if not pdf_file.exists():
|
||||
raise RuntimeError(f"vault_pdf_extract: archivo no encontrado: {pdf_file}")
|
||||
|
||||
db = Path(db_path) if db_path else vault / "vault_index.db"
|
||||
|
||||
# Abrir PDF
|
||||
try:
|
||||
doc = fitz.open(str(pdf_file))
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"vault_pdf_extract: PDF corrupto o inválido ({rel_path}): {exc}") from exc
|
||||
|
||||
page_count = doc.page_count
|
||||
text_parts: list[str] = []
|
||||
for page in doc:
|
||||
try:
|
||||
text_parts.append(page.get_text())
|
||||
except Exception:
|
||||
text_parts.append("")
|
||||
doc.close()
|
||||
|
||||
full_text = "\n".join(text_parts)
|
||||
text_len = len(full_text)
|
||||
|
||||
# Truncar a 10 MB para FTS
|
||||
_MAX_FTS = 10 * 1024 * 1024
|
||||
fts_text = full_text[:_MAX_FTS]
|
||||
|
||||
# Dump text a disco
|
||||
extracted_to: str | None = None
|
||||
if dump_text and full_text.strip():
|
||||
basename = Path(rel_path).stem
|
||||
# Preferir data/processed/ si existe; si no, usar .vault_extracts/
|
||||
processed_dir = vault / "data" / "processed"
|
||||
if not processed_dir.exists():
|
||||
processed_dir = vault / ".vault_extracts"
|
||||
processed_dir.mkdir(parents=True, exist_ok=True)
|
||||
txt_path = processed_dir / f"{basename}.txt"
|
||||
txt_path.write_text(full_text, encoding="utf-8")
|
||||
extracted_to = str(txt_path.relative_to(vault))
|
||||
|
||||
# Persistir en vault_index.db
|
||||
persisted = False
|
||||
if db.exists():
|
||||
conn = sqlite3.connect(str(db))
|
||||
try:
|
||||
now = int(time.time())
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO pdf_extracts(rel_path, page_count, text_len, extracted_to, extracted_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(rel_path) DO UPDATE SET
|
||||
page_count=excluded.page_count,
|
||||
text_len=excluded.text_len,
|
||||
extracted_to=excluded.extracted_to,
|
||||
extracted_at=excluded.extracted_at
|
||||
""",
|
||||
(rel_path, page_count, text_len, extracted_to, now),
|
||||
)
|
||||
# Actualizar files_fts (rowid debe coincidir con files)
|
||||
conn.execute("DELETE FROM files_fts WHERE rel_path = ?", (rel_path,))
|
||||
if fts_text.strip():
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO files_fts(rowid, rel_path, content_text)
|
||||
VALUES ((SELECT rowid FROM files WHERE rel_path = ?), ?, ?)
|
||||
""",
|
||||
(rel_path, rel_path, fts_text),
|
||||
)
|
||||
conn.commit()
|
||||
persisted = True
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
"rel_path": rel_path,
|
||||
"page_count": page_count,
|
||||
"text_len": text_len,
|
||||
"extracted_to": extracted_to,
|
||||
"persisted": persisted,
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
"""Tests para vault_dedupe_report."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from vault_dedupe_report import vault_dedupe_report
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_db(tmp_path: Path, rows: list[tuple]) -> Path:
|
||||
"""Crea vault_index.db con la tabla files y las filas dadas.
|
||||
|
||||
rows: lista de (rel_path, size, sha256)
|
||||
"""
|
||||
db_path = tmp_path / "vault_index.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE files (
|
||||
rel_path TEXT PRIMARY KEY,
|
||||
size INTEGER,
|
||||
mtime REAL,
|
||||
sha256 TEXT,
|
||||
mime TEXT,
|
||||
ext TEXT,
|
||||
bucket TEXT,
|
||||
sub_bucket TEXT,
|
||||
indexed_at REAL
|
||||
);
|
||||
"""
|
||||
)
|
||||
conn.executemany(
|
||||
"INSERT INTO files (rel_path, size, sha256) VALUES (?, ?, ?);",
|
||||
rows,
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_no_duplicates(tmp_path):
|
||||
"""test_no_duplicates — 3 archivos con sha256 distintos -> groups=[]."""
|
||||
_make_db(tmp_path, [
|
||||
("a/file1.txt", 100, "aaa111"),
|
||||
("a/file2.txt", 200, "bbb222"),
|
||||
("a/file3.txt", 300, "ccc333"),
|
||||
])
|
||||
result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db"))
|
||||
|
||||
assert result["groups"] == []
|
||||
assert result["total_groups"] == 0
|
||||
assert result["total_duplicates"] == 0
|
||||
assert result["total_reclaimable_bytes"] == 0
|
||||
assert result["scanned_files"] == 3
|
||||
assert result["vault_path"] == str(tmp_path)
|
||||
|
||||
|
||||
def test_basic_duplicates(tmp_path):
|
||||
"""test_basic_duplicates — 2 archivos mismo sha256 -> 1 group, count=2, reclaimable=size."""
|
||||
_make_db(tmp_path, [
|
||||
("data/orig.jpg", 500, "deadbeef"),
|
||||
("backup/orig.jpg", 500, "deadbeef"),
|
||||
])
|
||||
result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db"))
|
||||
|
||||
assert result["total_groups"] == 1
|
||||
assert result["total_duplicates"] == 1
|
||||
assert result["total_reclaimable_bytes"] == 500
|
||||
|
||||
g = result["groups"][0]
|
||||
assert g["sha256"] == "deadbeef"
|
||||
assert g["size"] == 500
|
||||
assert g["count"] == 2
|
||||
assert g["reclaimable_bytes"] == 500
|
||||
assert sorted(g["files"]) == ["backup/orig.jpg", "data/orig.jpg"]
|
||||
|
||||
|
||||
def test_three_in_group(tmp_path):
|
||||
"""test_three_in_group — 3 archivos mismo sha256 -> count=3, reclaimable=size*2."""
|
||||
size = 1000
|
||||
_make_db(tmp_path, [
|
||||
("a/f1.bin", size, "cafebabe"),
|
||||
("b/f2.bin", size, "cafebabe"),
|
||||
("c/f3.bin", size, "cafebabe"),
|
||||
])
|
||||
result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db"))
|
||||
|
||||
assert result["total_groups"] == 1
|
||||
assert result["total_duplicates"] == 2
|
||||
assert result["total_reclaimable_bytes"] == size * 2
|
||||
|
||||
g = result["groups"][0]
|
||||
assert g["count"] == 3
|
||||
assert g["reclaimable_bytes"] == size * 2
|
||||
assert g["files"] == sorted(["a/f1.bin", "b/f2.bin", "c/f3.bin"])
|
||||
|
||||
|
||||
def test_min_size_filter(tmp_path):
|
||||
"""test_min_size_filter — duplicados de tamano 50, min_size=100 -> groups=[]."""
|
||||
_make_db(tmp_path, [
|
||||
("x/small1.txt", 50, "tiny123"),
|
||||
("y/small2.txt", 50, "tiny123"),
|
||||
])
|
||||
result = vault_dedupe_report(
|
||||
str(tmp_path),
|
||||
min_size=100,
|
||||
db_path=str(tmp_path / "vault_index.db"),
|
||||
)
|
||||
|
||||
assert result["groups"] == []
|
||||
assert result["total_groups"] == 0
|
||||
assert result["total_reclaimable_bytes"] == 0
|
||||
assert result["scanned_files"] == 0
|
||||
|
||||
|
||||
def test_multiple_groups_ordered(tmp_path):
|
||||
"""test_multiple_groups_ordered — 2 grupos con distinto ahorro -> orden DESC."""
|
||||
# grupo A: 2 copias de 200 bytes -> reclaimable=200
|
||||
# grupo B: 3 copias de 500 bytes -> reclaimable=1000
|
||||
# el grupo B debe salir primero
|
||||
_make_db(tmp_path, [
|
||||
("p/a1.dat", 200, "groupA"),
|
||||
("q/a2.dat", 200, "groupA"),
|
||||
("r/b1.dat", 500, "groupB"),
|
||||
("s/b2.dat", 500, "groupB"),
|
||||
("t/b3.dat", 500, "groupB"),
|
||||
("u/uniq.dat", 999, "unique1"),
|
||||
])
|
||||
result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db"))
|
||||
|
||||
assert result["total_groups"] == 2
|
||||
assert result["total_duplicates"] == 3 # (2-1) + (3-1)
|
||||
assert result["total_reclaimable_bytes"] == 1200 # 200 + 1000
|
||||
assert result["scanned_files"] == 6 # 6 filas con sha256 != '' (incluye el unico)
|
||||
|
||||
# Primer grupo debe ser el de mayor ahorro (B: 1000)
|
||||
assert result["groups"][0]["sha256"] == "groupB"
|
||||
assert result["groups"][0]["reclaimable_bytes"] == 1000
|
||||
assert result["groups"][1]["sha256"] == "groupA"
|
||||
assert result["groups"][1]["reclaimable_bytes"] == 200
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Tests para vault_knowledge_parse."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from vault_knowledge_parse import vault_knowledge_parse
|
||||
|
||||
|
||||
def _make_vault(tmp: Path) -> tuple[Path, Path]:
|
||||
"""Crea un vault mínimo con vault_index.db."""
|
||||
db = tmp / "vault_index.db"
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
rowid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
rel_path TEXT UNIQUE NOT NULL,
|
||||
size_bytes INTEGER,
|
||||
ext TEXT
|
||||
);
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS files_fts
|
||||
USING fts5(rel_path, content_text, content='', contentless_delete=1);
|
||||
CREATE TABLE IF NOT EXISTS knowledge_docs (
|
||||
rel_path TEXT PRIMARY KEY,
|
||||
title TEXT,
|
||||
frontmatter_json TEXT,
|
||||
headings_json TEXT,
|
||||
parsed_at INTEGER
|
||||
);
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return tmp, db
|
||||
|
||||
|
||||
def _insert_file_entry(db: Path, rel_path: str):
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO files(rel_path, size_bytes, ext) VALUES (?, 0, '.md')",
|
||||
(rel_path,),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_md_with_frontmatter(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/guia.md"
|
||||
md = vault / rel
|
||||
md.parent.mkdir(parents=True, exist_ok=True)
|
||||
md.write_text(
|
||||
"---\ntitle: Mi Guía\nauthor: Lucas\n---\n\n# Mi Guía\n\nContenido del documento.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_knowledge_parse(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["title"] == "Mi Guía"
|
||||
assert result["frontmatter"]["author"] == "Lucas"
|
||||
assert "Contenido del documento" in result["content_text"]
|
||||
assert result["persisted"] is True
|
||||
|
||||
|
||||
def test_md_no_frontmatter(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/sin_fm.md"
|
||||
md = vault / rel
|
||||
md.parent.mkdir(parents=True, exist_ok=True)
|
||||
md.write_text("# Título\n\nCuerpo sin frontmatter.\n", encoding="utf-8")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_knowledge_parse(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["frontmatter"] == {}
|
||||
assert result["title"] == "Título"
|
||||
assert "Cuerpo sin frontmatter" in result["content_text"]
|
||||
|
||||
|
||||
def test_md_title_from_h1(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/title_h1.md"
|
||||
md = vault / rel
|
||||
md.parent.mkdir(parents=True, exist_ok=True)
|
||||
md.write_text("# Primer H1\n\nAlgún texto.\n", encoding="utf-8")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_knowledge_parse(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["title"] == "Primer H1"
|
||||
|
||||
|
||||
def test_md_title_from_filename(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/nombre_archivo.md"
|
||||
md = vault / rel
|
||||
md.parent.mkdir(parents=True, exist_ok=True)
|
||||
md.write_text("Solo texto sin headings ni frontmatter.\n", encoding="utf-8")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_knowledge_parse(str(vault), rel, db_path=str(db))
|
||||
|
||||
assert result["title"] == "nombre_archivo"
|
||||
|
||||
|
||||
def test_md_headings_levels(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/headings.md"
|
||||
md = vault / rel
|
||||
md.parent.mkdir(parents=True, exist_ok=True)
|
||||
md.write_text(
|
||||
"# H1 Título\n\nTexto.\n\n## H2 Sección\n\n### H3 Subsección\n\n## H2 Otra\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
result = vault_knowledge_parse(str(vault), rel, db_path=str(db))
|
||||
|
||||
headings = result["headings"]
|
||||
assert len(headings) == 4
|
||||
levels = [h["level"] for h in headings]
|
||||
assert levels == [1, 2, 3, 2]
|
||||
texts = [h["text"] for h in headings]
|
||||
assert "H1 Título" in texts
|
||||
assert "H2 Sección" in texts
|
||||
assert "H3 Subsección" in texts
|
||||
|
||||
|
||||
def test_md_persists_to_fts(tmp_path):
|
||||
vault, db = _make_vault(tmp_path)
|
||||
rel = "docs/fts_md.md"
|
||||
md = vault / rel
|
||||
md.parent.mkdir(parents=True, exist_ok=True)
|
||||
md.write_text("# Documento FTS\n\nPalabra clave: xenolito.\n", encoding="utf-8")
|
||||
_insert_file_entry(db, rel)
|
||||
|
||||
vault_knowledge_parse(str(vault), rel, db_path=str(db))
|
||||
|
||||
conn = sqlite3.connect(str(db))
|
||||
# FTS5 contentless: no permite SELECT directo, usar MATCH
|
||||
row = conn.execute(
|
||||
"SELECT rowid FROM files_fts WHERE files_fts MATCH 'xenolito'",
|
||||
).fetchone()
|
||||
conn.close()
|
||||
assert row is not None, "FTS no encontró 'xenolito'"
|
||||
@@ -0,0 +1,57 @@
|
||||
---
|
||||
name: vault_dedupe_report
|
||||
kind: function
|
||||
lang: py
|
||||
domain: infra
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def vault_dedupe_report(vault_path: str, min_size: int = 0, db_path: str | None = None) -> dict"
|
||||
description: "Detecta archivos duplicados en un vault leyendo vault_index.db (agrupando por sha256) y calcula el espacio recuperable. Retorna grupos ordenados por bytes recuperables DESC."
|
||||
tags: [vault, dedupe, duplicates, disk, sha256, sqlite]
|
||||
params:
|
||||
- name: vault_path
|
||||
desc: "Ruta raiz del vault. Usada como clave en el resultado y para localizar vault_index.db cuando db_path es None."
|
||||
- name: min_size
|
||||
desc: "Tamanio minimo en bytes para incluir un archivo en el analisis. Default 0 = todos los archivos."
|
||||
- name: db_path
|
||||
desc: "Override opcional de la ruta a vault_index.db. Si es None se usa <vault_path>/vault_index.db."
|
||||
output: "dict con vault_path, groups (sha256/size/count/files/reclaimable_bytes), total_groups, total_duplicates, total_reclaimable_bytes, scanned_files. groups ordenados por reclaimable_bytes DESC."
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_py_core"
|
||||
imports: ["sqlite3", "pathlib"]
|
||||
tested: true
|
||||
tests:
|
||||
- "test_no_duplicates"
|
||||
- "test_basic_duplicates"
|
||||
- "test_three_in_group"
|
||||
- "test_min_size_filter"
|
||||
- "test_multiple_groups_ordered"
|
||||
test_file_path: "python/functions/infra/tests/test_vault_dedupe_report.py"
|
||||
file_path: "python/functions/infra/vault_dedupe_report.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from infra.vault_dedupe_report import vault_dedupe_report
|
||||
|
||||
report = vault_dedupe_report("/data/vaults/my_vault", min_size=1024)
|
||||
print(f"Grupos duplicados: {report['total_groups']}")
|
||||
print(f"Espacio recuperable: {report['total_reclaimable_bytes'] // (1024**2)} MB")
|
||||
|
||||
for g in report["groups"][:5]:
|
||||
print(f" sha256={g['sha256'][:12]}... size={g['size']} count={g['count']}")
|
||||
for f in g["files"]:
|
||||
print(f" {f}")
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Solo considera filas con `sha256 != ''` (archivos efectivamente hasheados por `vault_inventory_scan_go_infra`).
|
||||
- Abre la BD en modo read-only (`?mode=ro`) para no interferir con escrituras concurrentes.
|
||||
- `GROUP_CONCAT` de SQLite no garantiza orden — los `files` se reordenan lexicograficamente en Python.
|
||||
- Si la BD no existe o le falta la tabla `files`, lanza `RuntimeError` con mensaje orientativo.
|
||||
- Prerequisito: haber corrido `fn vault index <name>` (pipeline `vault_inventory_scan_go_infra` + `vault_index_write_go_infra`) sobre el vault.
|
||||
@@ -0,0 +1,122 @@
|
||||
"""vault_dedupe_report — Detecta duplicados en vault_index.db y calcula espacio recuperable."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def vault_dedupe_report(
|
||||
vault_path: str,
|
||||
min_size: int = 0,
|
||||
db_path: str | None = None,
|
||||
) -> dict:
|
||||
"""Detecta archivos duplicados en un vault a partir de su vault_index.db.
|
||||
|
||||
Lee la tabla ``files`` de ``vault_index.db`` agrupando por ``sha256`` y
|
||||
retorna todos los grupos con mas de un archivo, ordenados por bytes
|
||||
recuperables de mayor a menor.
|
||||
|
||||
Args:
|
||||
vault_path: Ruta raiz del vault. Usada como clave en el resultado y
|
||||
para localizar ``vault_index.db`` cuando ``db_path`` es None.
|
||||
min_size: Ignora archivos cuyo ``size`` (bytes) sea menor que este
|
||||
valor. Default 0 = incluir todos los archivos.
|
||||
db_path: Ruta absoluta o relativa a la BD SQLite. Si es None se
|
||||
usa ``<vault_path>/vault_index.db``.
|
||||
|
||||
Returns:
|
||||
dict con las claves:
|
||||
- ``vault_path``: str — mismo valor recibido.
|
||||
- ``groups``: list de dicts, cada uno con:
|
||||
- ``sha256``: str
|
||||
- ``size``: int — tamanio en bytes de cada copia
|
||||
- ``count``: int — numero de copias encontradas
|
||||
- ``files``: list[str] — rel_paths ordenados lexicograficamente
|
||||
- ``reclaimable_bytes``: int — ``size * (count - 1)``
|
||||
- ``total_groups``: int — numero de grupos con duplicados
|
||||
- ``total_duplicates``: int — suma de ``(count - 1)`` por grupo
|
||||
- ``total_reclaimable_bytes``: int — bytes totales recuperables
|
||||
- ``scanned_files``: int — total de filas consideradas en la query
|
||||
|
||||
Raises:
|
||||
RuntimeError: Si la BD no existe, no tiene tabla ``files``, o hay
|
||||
algun error de lectura.
|
||||
"""
|
||||
resolved_db = db_path if db_path is not None else str(Path(vault_path) / "vault_index.db")
|
||||
|
||||
db_file = Path(resolved_db)
|
||||
if not db_file.exists():
|
||||
raise RuntimeError(
|
||||
f"No se encontro vault_index.db en '{resolved_db}'. "
|
||||
"Corre 'fn vault index <name>' primero."
|
||||
)
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(f"file:{resolved_db}?mode=ro", uri=True)
|
||||
except sqlite3.OperationalError as exc:
|
||||
raise RuntimeError(f"No se pudo abrir '{resolved_db}': {exc}") from exc
|
||||
|
||||
try:
|
||||
# Verificar que existe la tabla files
|
||||
cur = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='files';"
|
||||
)
|
||||
if cur.fetchone() is None:
|
||||
raise RuntimeError(
|
||||
f"vault_index.db sin tabla 'files'. "
|
||||
"Corre 'fn vault index <name>' primero."
|
||||
)
|
||||
|
||||
# Contar filas totales consideradas (sha256 no vacio, size >= min_size)
|
||||
row = conn.execute(
|
||||
"SELECT COUNT(*) FROM files WHERE size >= ? AND sha256 != '';",
|
||||
(min_size,),
|
||||
).fetchone()
|
||||
scanned_files: int = row[0] if row else 0
|
||||
|
||||
# Query principal: grupos con mas de una copia
|
||||
query = """
|
||||
SELECT
|
||||
sha256,
|
||||
size,
|
||||
COUNT(*) AS cnt,
|
||||
GROUP_CONCAT(rel_path) AS paths
|
||||
FROM files
|
||||
WHERE size >= ? AND sha256 != ''
|
||||
GROUP BY sha256
|
||||
HAVING COUNT(*) > 1
|
||||
ORDER BY size * (COUNT(*) - 1) DESC;
|
||||
"""
|
||||
rows = conn.execute(query, (min_size,)).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
groups: list[dict] = []
|
||||
total_duplicates = 0
|
||||
total_reclaimable_bytes = 0
|
||||
|
||||
for sha256, size, cnt, paths_concat in rows:
|
||||
# GROUP_CONCAT no garantiza orden — ordenar lexicograficamente
|
||||
files = sorted(paths_concat.split(","))
|
||||
reclaimable = size * (cnt - 1)
|
||||
groups.append(
|
||||
{
|
||||
"sha256": sha256,
|
||||
"size": size,
|
||||
"count": cnt,
|
||||
"files": files,
|
||||
"reclaimable_bytes": reclaimable,
|
||||
}
|
||||
)
|
||||
total_duplicates += cnt - 1
|
||||
total_reclaimable_bytes += reclaimable
|
||||
|
||||
return {
|
||||
"vault_path": vault_path,
|
||||
"groups": groups,
|
||||
"total_groups": len(groups),
|
||||
"total_duplicates": total_duplicates,
|
||||
"total_reclaimable_bytes": total_reclaimable_bytes,
|
||||
"scanned_files": scanned_files,
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
---
|
||||
name: vault_knowledge_parse
|
||||
kind: function
|
||||
lang: py
|
||||
domain: infra
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def vault_knowledge_parse(vault_path: str, rel_path: str, db_path: str | None = None) -> dict"
|
||||
description: "Parsea un archivo Markdown del vault: extrae YAML frontmatter, título, headings y cuerpo; persiste en knowledge_docs y actualiza files_fts para búsqueda por contenido."
|
||||
tags: [vault, markdown, knowledge, frontmatter, headings, fts, infra]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [json, re, sqlite3, time, pathlib, yaml]
|
||||
params:
|
||||
- name: vault_path
|
||||
desc: "Ruta absoluta a la raiz del vault donde vive el Markdown y vault_index.db."
|
||||
- name: rel_path
|
||||
desc: "Ruta relativa al archivo .md dentro del vault (ej. 'docs/guia.md')."
|
||||
- name: db_path
|
||||
desc: "Override opcional de la ruta a vault_index.db. Por defecto <vault_path>/vault_index.db."
|
||||
output: "Dict con: rel_path (str), title (str), frontmatter (dict), headings (list de {level, text}), content_text (str cuerpo sin frontmatter), persisted (bool)."
|
||||
tested: true
|
||||
tests:
|
||||
- "test_md_with_frontmatter"
|
||||
- "test_md_no_frontmatter"
|
||||
- "test_md_title_from_h1"
|
||||
- "test_md_title_from_filename"
|
||||
- "test_md_headings_levels"
|
||||
- "test_md_persists_to_fts"
|
||||
test_file_path: "python/functions/infra/tests/test_vault_knowledge_parse.py"
|
||||
file_path: "python/functions/infra/vault_knowledge_parse.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from vault_knowledge_parse import vault_knowledge_parse
|
||||
|
||||
result = vault_knowledge_parse("/vaults/mi_vault", "docs/guia_operaciones.md")
|
||||
# {
|
||||
# "rel_path": "docs/guia_operaciones.md",
|
||||
# "title": "Guía de Operaciones",
|
||||
# "frontmatter": {"author": "Lucas", "tags": ["ops"]},
|
||||
# "headings": [{"level": 1, "text": "Guía de Operaciones"}, {"level": 2, "text": "Instalación"}],
|
||||
# "content_text": "# Guía de Operaciones\n\n## Instalación\n...",
|
||||
# "persisted": True
|
||||
# }
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Prioridad de título: frontmatter["title"] > primer H1 en el cuerpo > basename sin extensión.
|
||||
- Frontmatter YAML delimitado por `---\n` al inicio del archivo. Si no hay frontmatter, se retorna {}.
|
||||
- content_text es el cuerpo completo sin el bloque frontmatter (incluye los headings H1-H6).
|
||||
- El rowid de files_fts se ancla al rowid de la tabla files para que vault_search funcione correctamente.
|
||||
- Si vault_index.db no existe, retorna el dict sin intentar persistir (persisted=False).
|
||||
- Dependencias: pyyaml (ya instalado en python/.venv).
|
||||
@@ -0,0 +1,142 @@
|
||||
"""vault_knowledge_parse — Parsea un Markdown del vault y persiste en knowledge_docs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _parse_frontmatter(text: str) -> tuple[dict, str]:
|
||||
"""Separa YAML frontmatter del cuerpo. Retorna (frontmatter_dict, body)."""
|
||||
if not text.startswith("---\n") and not text.startswith("---\r\n"):
|
||||
return {}, text
|
||||
|
||||
# Buscar cierre del frontmatter
|
||||
end = text.find("\n---", 4)
|
||||
if end == -1:
|
||||
return {}, text
|
||||
|
||||
yaml_block = text[4:end].strip()
|
||||
body = text[end + 4:].lstrip("\n\r")
|
||||
|
||||
try:
|
||||
import yaml
|
||||
|
||||
fm = yaml.safe_load(yaml_block) or {}
|
||||
if not isinstance(fm, dict):
|
||||
fm = {}
|
||||
except Exception:
|
||||
fm = {}
|
||||
|
||||
return fm, body
|
||||
|
||||
|
||||
def _extract_headings(body: str) -> list[dict]:
|
||||
"""Extrae headings Markdown (# ... ### ...) del cuerpo."""
|
||||
headings = []
|
||||
for line in body.splitlines():
|
||||
m = re.match(r"^(#{1,6})\s+(.*)", line)
|
||||
if m:
|
||||
headings.append({"level": len(m.group(1)), "text": m.group(2).strip()})
|
||||
return headings
|
||||
|
||||
|
||||
def _extract_title(frontmatter: dict, body: str, basename: str) -> str:
|
||||
"""Extrae título: frontmatter['title'] > primer H1 > basename."""
|
||||
if frontmatter.get("title"):
|
||||
return str(frontmatter["title"])
|
||||
for line in body.splitlines():
|
||||
m = re.match(r"^#\s+(.*)", line)
|
||||
if m:
|
||||
return m.group(1).strip()
|
||||
return basename
|
||||
|
||||
|
||||
def vault_knowledge_parse(
|
||||
vault_path: str,
|
||||
rel_path: str,
|
||||
db_path: str | None = None,
|
||||
) -> dict:
|
||||
"""Parsea un archivo Markdown del vault: extrae frontmatter, título, headings y cuerpo.
|
||||
|
||||
Args:
|
||||
vault_path: Ruta absoluta a la raiz del vault.
|
||||
rel_path: Ruta relativa al archivo Markdown dentro del vault.
|
||||
db_path: Override opcional de la ruta a vault_index.db.
|
||||
|
||||
Returns:
|
||||
Dict con: rel_path, title, frontmatter, headings, content_text, persisted.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Si el archivo no existe o no se puede leer.
|
||||
"""
|
||||
vault = Path(vault_path)
|
||||
md_file = vault / rel_path
|
||||
if not md_file.exists():
|
||||
raise RuntimeError(f"vault_knowledge_parse: archivo no encontrado: {md_file}")
|
||||
|
||||
db = Path(db_path) if db_path else vault / "vault_index.db"
|
||||
|
||||
try:
|
||||
text = md_file.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
text = md_file.read_text(encoding="latin-1", errors="replace")
|
||||
|
||||
frontmatter, body = _parse_frontmatter(text)
|
||||
headings = _extract_headings(body)
|
||||
basename = md_file.stem
|
||||
title = _extract_title(frontmatter, body, basename)
|
||||
content_text = body
|
||||
|
||||
# Persistir en vault_index.db
|
||||
persisted = False
|
||||
if db.exists():
|
||||
conn = sqlite3.connect(str(db))
|
||||
try:
|
||||
now = int(time.time())
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO knowledge_docs(rel_path, title, frontmatter_json, headings_json, parsed_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(rel_path) DO UPDATE SET
|
||||
title=excluded.title,
|
||||
frontmatter_json=excluded.frontmatter_json,
|
||||
headings_json=excluded.headings_json,
|
||||
parsed_at=excluded.parsed_at
|
||||
""",
|
||||
(
|
||||
rel_path,
|
||||
title,
|
||||
json.dumps(frontmatter, ensure_ascii=False),
|
||||
json.dumps(headings, ensure_ascii=False),
|
||||
now,
|
||||
),
|
||||
)
|
||||
# Actualizar files_fts (rowid debe coincidir con files)
|
||||
conn.execute("DELETE FROM files_fts WHERE rel_path = ?", (rel_path,))
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO files_fts(rowid, rel_path, content_text)
|
||||
VALUES ((SELECT rowid FROM files WHERE rel_path = ?), ?, ?)
|
||||
""",
|
||||
(rel_path, rel_path, content_text),
|
||||
)
|
||||
conn.commit()
|
||||
persisted = True
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
"rel_path": rel_path,
|
||||
"title": title,
|
||||
"frontmatter": frontmatter,
|
||||
"headings": headings,
|
||||
"content_text": content_text,
|
||||
"persisted": persisted,
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
---
|
||||
name: vault_profile_dispatch
|
||||
kind: function
|
||||
lang: py
|
||||
domain: infra
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def vault_profile_dispatch(vault_path: str, rel_path: str, kind: str, db_path: str | None = None) -> dict"
|
||||
description: "CLI dispatcher que enruta un archivo del vault al profiler correcto segun su tipo (csv/pdf/md). Thin wrapper sobre vault_csv_profile, vault_pdf_extract y vault_knowledge_parse. Usable desde Go via os/exec para procesar archivos en bulk."
|
||||
tags: [vault, profile, dispatch, profiler, csv, pdf, md, infra]
|
||||
uses_functions:
|
||||
- vault_csv_profile_py_datascience
|
||||
- vault_pdf_extract_py_datascience
|
||||
- vault_knowledge_parse_py_infra
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: []
|
||||
params:
|
||||
- name: vault_path
|
||||
desc: "Ruta absoluta a la raiz del vault."
|
||||
- name: rel_path
|
||||
desc: "Ruta relativa del archivo dentro del vault."
|
||||
- name: kind
|
||||
desc: "Tipo de profiler: csv | pdf | md."
|
||||
- name: db_path
|
||||
desc: "Override de la ruta a vault_index.db. Default: <vault_path>/vault_index.db."
|
||||
output: "Dict con resultado del profiler correspondiente. Para csv: {rel_path, cols, n_rows, encoding, date_min, date_max, persisted}. Para pdf: {rel_path, page_count, text_len, extracted_to, persisted}. Para md: resultado de vault_knowledge_parse."
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/infra/vault_profile_dispatch.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```bash
|
||||
# Desde CLI
|
||||
python3 python/functions/infra/vault_profile_dispatch.py \
|
||||
--vault /home/lucas/vaults/turismo_spain \
|
||||
--rel-path data/raw/report.csv \
|
||||
--kind csv
|
||||
|
||||
# Desde Go via os/exec (patron usado en fn vault profile)
|
||||
python3 vault_profile_dispatch.py --vault <path> --rel-path <p> --kind csv
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Disenado para ser invocado desde Go via `os/exec`. Imprime resultado como JSON a stdout.
|
||||
Codigos de salida: 0=exito, 1=args faltantes, 2=kind desconocido, 3=error del profiler.
|
||||
|
||||
Detecta automaticamente el PYTHONPATH mirando `FN_REGISTRY_ROOT` o subiendo desde su propia ubicacion.
|
||||
@@ -0,0 +1,92 @@
|
||||
"""vault_profile_dispatch — CLI dispatcher that routes a single vault file to the right profiler.
|
||||
|
||||
Usage:
|
||||
python3 vault_profile_dispatch.py --vault <path> --rel-path <p> --kind csv|pdf|md [--db-path <p>]
|
||||
|
||||
Exit codes:
|
||||
0 success (result printed as JSON)
|
||||
1 missing required argument
|
||||
2 unknown kind
|
||||
3 profiler raised an error
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _python_path_setup() -> None:
|
||||
"""Ensure the registry python/functions directory is on sys.path."""
|
||||
# Try FN_REGISTRY_ROOT env first, then walk up from this file's location.
|
||||
registry_root = os.environ.get("FN_REGISTRY_ROOT", "")
|
||||
if not registry_root:
|
||||
# This file lives at python/functions/infra/vault_profile_dispatch.py
|
||||
# So the registry root is four levels up from __file__.
|
||||
candidate = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if (candidate / "go.mod").exists():
|
||||
registry_root = str(candidate)
|
||||
|
||||
if registry_root:
|
||||
fn_path = str(Path(registry_root) / "python" / "functions")
|
||||
if fn_path not in sys.path:
|
||||
sys.path.insert(0, fn_path)
|
||||
|
||||
|
||||
def dispatch(vault_path: str, rel_path: str, kind: str, db_path: str | None) -> dict:
|
||||
"""Call the appropriate profiler based on kind."""
|
||||
if kind == "csv":
|
||||
from datascience.vault_csv_profile import vault_csv_profile
|
||||
return vault_csv_profile(vault_path, rel_path, db_path)
|
||||
elif kind == "pdf":
|
||||
from datascience.vault_pdf_extract import vault_pdf_extract
|
||||
return vault_pdf_extract(vault_path, rel_path, db_path)
|
||||
elif kind == "md":
|
||||
from infra.vault_knowledge_parse import vault_knowledge_parse
|
||||
return vault_knowledge_parse(vault_path, rel_path, db_path)
|
||||
else:
|
||||
raise ValueError(f"unknown kind: {kind!r} (expected csv, pdf, or md)")
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
_python_path_setup()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="vault_profile_dispatch",
|
||||
description="Route a single vault file to the right profiler (csv/pdf/md).",
|
||||
)
|
||||
parser.add_argument("--vault", required=True, help="Absolute path to vault root")
|
||||
parser.add_argument("--rel-path", required=True, dest="rel_path", help="Relative path of file inside vault")
|
||||
parser.add_argument(
|
||||
"--kind",
|
||||
required=True,
|
||||
choices=["csv", "pdf", "md"],
|
||||
help="Profiler kind: csv | pdf | md",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db-path",
|
||||
dest="db_path",
|
||||
default=None,
|
||||
help="Override path to vault_index.db (default: <vault>/vault_index.db)",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
try:
|
||||
result = dispatch(args.vault, args.rel_path, args.kind, args.db_path)
|
||||
except ValueError as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 2
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 3
|
||||
|
||||
print(json.dumps(result, indent=2, default=str))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1 @@
|
||||
"""ml — tipos y funciones de generacion de imagenes con modelos de difusion."""
|
||||
@@ -0,0 +1,67 @@
|
||||
---
|
||||
name: cuda_available
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def cuda_available() -> dict"
|
||||
description: "Detecta si CUDA esta disponible via torch. Devuelve device_count, nombres de GPU y version de CUDA. Si torch no esta instalado, retorna available=False sin lanzar excepcion."
|
||||
tags: [cuda, gpu, torch, pytorch, hardware, probe, ml, device]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: []
|
||||
params: []
|
||||
output: "dict con claves: available (bool), device_count (int), devices (list[str] con nombres de GPU), torch_version (str o 'not_installed'), cuda_version (str | None)"
|
||||
tested: true
|
||||
tests:
|
||||
- "sin torch retorna available=False y torch_version=not_installed"
|
||||
- "con torch sin cuda retorna available=False y device_count=0"
|
||||
- "claves del dict siempre presentes"
|
||||
test_file_path: "python/functions/ml/tests/test_cuda_available.py"
|
||||
file_path: "python/functions/ml/cuda_available.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.cuda_available import cuda_available
|
||||
|
||||
info = cuda_available()
|
||||
# Sin GPU:
|
||||
# {
|
||||
# "available": False,
|
||||
# "device_count": 0,
|
||||
# "devices": [],
|
||||
# "torch_version": "2.3.0",
|
||||
# "cuda_version": None
|
||||
# }
|
||||
|
||||
# Con GPU:
|
||||
# {
|
||||
# "available": True,
|
||||
# "device_count": 1,
|
||||
# "devices": ["NVIDIA RTX 4090"],
|
||||
# "torch_version": "2.3.0",
|
||||
# "cuda_version": "12.1"
|
||||
# }
|
||||
|
||||
# Sin torch instalado:
|
||||
# {
|
||||
# "available": False,
|
||||
# "device_count": 0,
|
||||
# "devices": [],
|
||||
# "torch_version": "not_installed",
|
||||
# "cuda_version": None
|
||||
# }
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Nunca lanza ImportError aunque torch no este instalado.
|
||||
- `cuda_version` es la version de CUDA con la que fue compilado torch, no necesariamente la del sistema.
|
||||
- Usar junto a `torch_device_select` para elegir device y `gpu_info` para estadisticas de VRAM.
|
||||
- impure: depende del estado del hardware y de librerias del sistema en tiempo de ejecucion.
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Detecta disponibilidad de CUDA via torch sin lanzar excepcion si torch no esta instalado."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def cuda_available() -> dict:
|
||||
"""Detecta si CUDA esta disponible y devuelve info de los dispositivos GPU.
|
||||
|
||||
No requiere torch instalado: si no esta presente, devuelve
|
||||
`torch_version='not_installed'` y `available=False`.
|
||||
|
||||
Returns:
|
||||
dict con claves:
|
||||
available (bool): True si torch.cuda.is_available().
|
||||
device_count (int): numero de GPUs detectadas (0 si no hay CUDA).
|
||||
devices (list[str]): nombres de cada GPU (ej. "NVIDIA RTX 4090").
|
||||
torch_version (str): version de torch o "not_installed".
|
||||
cuda_version (str | None): version de CUDA usada por torch, o None.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
return {
|
||||
"available": False,
|
||||
"device_count": 0,
|
||||
"devices": [],
|
||||
"torch_version": "not_installed",
|
||||
"cuda_version": None,
|
||||
}
|
||||
|
||||
available = torch.cuda.is_available()
|
||||
device_count = torch.cuda.device_count() if available else 0
|
||||
devices = [torch.cuda.get_device_name(i) for i in range(device_count)]
|
||||
cuda_version = torch.version.cuda if available else None
|
||||
|
||||
return {
|
||||
"available": available,
|
||||
"device_count": device_count,
|
||||
"devices": devices,
|
||||
"torch_version": torch.__version__,
|
||||
"cuda_version": cuda_version,
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
---
|
||||
name: diffusers_generate
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def diffusers_generate(pipe: Any, cfg: GenerationConfig) -> ImageGenResult"
|
||||
description: "Ejecuta inferencia con un pipeline diffusers usando GenerationConfig. Mide duracion y pico de VRAM. Retorna ImageGenResult con imagen PIL, meta y metricas."
|
||||
tags: [diffusers, ml, image-generation, inference, vram, metrics]
|
||||
uses_functions: [genconfig_to_diffusers_kwargs_py_ml]
|
||||
uses_types: [generation_config_py_ml, image_gen_result_py_ml]
|
||||
returns: [image_gen_result_py_ml]
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [torch, diffusers]
|
||||
params:
|
||||
- name: pipe
|
||||
desc: "Pipeline diffusers cargado y listo para inferencia (resultado de diffusers_load_pipeline, opcionalmente con scheduler y LoRA configurados)."
|
||||
- name: cfg
|
||||
desc: "Parametros de generacion. cfg.seed >= 0 para semilla fija; -1 usa time-based. cfg.sampler se incluye en meta pero no se aplica aqui (usar diffusers_set_scheduler antes)."
|
||||
output: "ImageGenResult con image=PIL.Image.Image, meta={backend, model, sampler, actual_steps, seed, width, height, cfg_scale}, duration_ms en entero milisegundos, vram_peak_mb (None si no hay CUDA)."
|
||||
tested: true
|
||||
tests:
|
||||
- "genera imagen retorna ImageGenResult"
|
||||
test_file_path: "python/functions/ml/tests/test_diffusers_backend.py"
|
||||
file_path: "python/functions/ml/diffusers_generate.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from diffusers_load_pipeline import diffusers_load_pipeline
|
||||
from diffusers_generate import diffusers_generate
|
||||
from generation_config import GenerationConfig
|
||||
from model_ref import ModelRef
|
||||
|
||||
model = ModelRef(
|
||||
name="sd-turbo",
|
||||
model_type="sd15",
|
||||
path="/home/lucas/vaults/imagegen_models/diffusers/sd-turbo",
|
||||
)
|
||||
cfg = GenerationConfig(
|
||||
prompt="a photo of a cat",
|
||||
seed=42,
|
||||
steps=1,
|
||||
cfg_scale=0.0,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model=model,
|
||||
)
|
||||
pipe = diffusers_load_pipeline(model, device="cuda", dtype="fp16")
|
||||
result = diffusers_generate(pipe, cfg)
|
||||
# result.image -> PIL.Image.Image 512x512
|
||||
# result.duration_ms -> int > 0
|
||||
# result.meta["backend"] -> "diffusers"
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
`cfg.seed = -1` genera seed aleatorio basado en `time.time()` (reproducible si
|
||||
se guarda en `result.meta["seed"]`).
|
||||
|
||||
VRAM: `torch.cuda.reset_peak_memory_stats()` antes de inferencia,
|
||||
`torch.cuda.max_memory_allocated() // 1024 // 1024` despues.
|
||||
|
||||
`genconfig_to_diffusers_kwargs` omite generator=None; esta funcion lo reemplaza
|
||||
con `torch.Generator(device=device).manual_seed(seed)`.
|
||||
|
||||
Import lazy de torch — ImportError descriptivo si no instalado.
|
||||
@@ -0,0 +1,98 @@
|
||||
"""diffusers_generate — ejecuta inferencia con un pipeline diffusers y retorna ImageGenResult."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from generation_config import GenerationConfig
|
||||
from image_gen_result import ImageGenResult
|
||||
from genconfig_to_diffusers_kwargs import genconfig_to_diffusers_kwargs
|
||||
|
||||
|
||||
def diffusers_generate(pipe: Any, cfg: GenerationConfig) -> ImageGenResult:
|
||||
"""Ejecuta inferencia con un pipeline diffusers y retorna ImageGenResult.
|
||||
|
||||
Convierte el GenerationConfig a kwargs via genconfig_to_diffusers_kwargs,
|
||||
crea un torch.Generator con la semilla configurada, mide duracion y pico
|
||||
de VRAM (si CUDA disponible). El campo meta del resultado incluye backend,
|
||||
modelo, sampler, seed y steps usados.
|
||||
|
||||
Args:
|
||||
pipe: Pipeline diffusers cargado (resultado de diffusers_load_pipeline).
|
||||
Debe ser callable: pipe(prompt=..., ...) -> objeto con .images[0].
|
||||
cfg: Parametros de generacion. cfg.seed se usa para torch.Generator.
|
||||
cfg.model.name se incluye en meta. cfg.sampler se incluye en meta
|
||||
pero NO se aplica aqui — usar diffusers_set_scheduler antes si se
|
||||
quiere cambiar el sampler.
|
||||
|
||||
Returns:
|
||||
ImageGenResult con image=PIL.Image, meta con keys backend/model/sampler/
|
||||
actual_steps/seed, duration_ms y vram_peak_mb (None si no hay CUDA).
|
||||
|
||||
Raises:
|
||||
ImportError: Si torch o diffusers no estan instalados.
|
||||
RuntimeError: Si la inferencia falla (OOM, modelo incompatible, etc.).
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"diffusers_generate requiere torch. "
|
||||
"Instalar con: pip install torch"
|
||||
) from exc
|
||||
|
||||
# Determinar device del pipeline
|
||||
device = "cpu"
|
||||
if hasattr(pipe, "device"):
|
||||
device = str(pipe.device)
|
||||
|
||||
# Medir VRAM solo en CUDA
|
||||
cuda_available = torch.cuda.is_available()
|
||||
if cuda_available:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Construir kwargs desde GenerationConfig
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
|
||||
# Crear generator con semilla
|
||||
seed = cfg.seed if cfg.seed >= 0 else int(time.time()) % (2**32)
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
kwargs["generator"] = generator
|
||||
|
||||
# Inferencia
|
||||
t0 = time.perf_counter()
|
||||
result = pipe(**kwargs)
|
||||
t1 = time.perf_counter()
|
||||
|
||||
duration_ms = int((t1 - t0) * 1000)
|
||||
|
||||
# VRAM peak
|
||||
vram_peak_mb: int | None = None
|
||||
if cuda_available:
|
||||
vram_peak_mb = torch.cuda.max_memory_allocated() // 1024 // 1024
|
||||
|
||||
# Nombre del modelo
|
||||
model_name = cfg.model.name if cfg.model and hasattr(cfg.model, "name") else "unknown"
|
||||
|
||||
meta: dict[str, Any] = {
|
||||
"backend": "diffusers",
|
||||
"model": model_name,
|
||||
"sampler": cfg.sampler,
|
||||
"actual_steps": cfg.steps,
|
||||
"seed": seed,
|
||||
"width": cfg.width,
|
||||
"height": cfg.height,
|
||||
"cfg_scale": cfg.cfg_scale,
|
||||
}
|
||||
|
||||
return ImageGenResult(
|
||||
image=result.images[0],
|
||||
meta=meta,
|
||||
duration_ms=duration_ms,
|
||||
vram_peak_mb=vram_peak_mb,
|
||||
)
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
name: diffusers_load_lora
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def diffusers_load_lora(pipe: Any, lora: LoraRef) -> Any"
|
||||
description: "Carga un adaptador LoRA en un pipeline diffusers via pipe.load_lora_weights. Si lora.weight != 1.0, aplica set_adapters para escalar la contribucion del LoRA."
|
||||
tags: [diffusers, ml, lora, image-generation, fine-tuning]
|
||||
uses_functions: []
|
||||
uses_types: [lora_ref_py_ml]
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [diffusers]
|
||||
params:
|
||||
- name: pipe
|
||||
desc: "Pipeline diffusers que soporte load_lora_weights (SD1.5, SDXL, etc.)."
|
||||
- name: lora
|
||||
desc: "Referencia al adaptador LoRA. lora.path al .safetensors o directorio. lora.weight escala la fusion (1.0 = completo, 0.5 = mitad)."
|
||||
output: "El mismo pipe con el LoRA cargado y peso aplicado. Modificacion in-place, retorna pipe para composicion."
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/ml/diffusers_load_lora.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from diffusers_load_lora import diffusers_load_lora
|
||||
from lora_ref import LoraRef
|
||||
|
||||
lora = LoraRef(path="/path/to/my_lora.safetensors", weight=0.8)
|
||||
pipe = diffusers_load_lora(pipe, lora)
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Usa `pipe.load_lora_weights(path)` para cargar. Si `lora.weight != 1.0`:
|
||||
- Intenta `pipe.set_adapters(["default"], adapter_weights=[weight])` (diffusers >= 0.20).
|
||||
- Fallback a `pipe.fuse_lora(lora_scale=weight)` para versiones antiguas.
|
||||
|
||||
El campo `lora.scale` (override de alpha) no se aplica aqui — diffusers no expone
|
||||
un parametro directo equivalente en la API publica actual. Se puede setear via
|
||||
`pipe.load_lora_weights(path, weight_name=...)` si el archivo tiene nombre especifico.
|
||||
|
||||
Import lazy de diffusers — ImportError descriptivo si no instalado.
|
||||
@@ -0,0 +1,55 @@
|
||||
"""diffusers_load_lora — carga un adaptador LoRA en un pipeline diffusers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from lora_ref import LoraRef
|
||||
|
||||
|
||||
def diffusers_load_lora(pipe: Any, lora: LoraRef) -> Any:
|
||||
"""Carga un adaptador LoRA en un pipeline diffusers y ajusta su peso de fusion.
|
||||
|
||||
Usa pipe.load_lora_weights(lora.path) para cargar los pesos del adaptador.
|
||||
Si lora.weight != 1.0, aplica set_adapters(['default'], adapter_weights=[w])
|
||||
para escalar la contribucion del LoRA. Modifica el pipe in-place y retorna
|
||||
el mismo objeto para composicion.
|
||||
|
||||
Args:
|
||||
pipe: Pipeline diffusers cargado. Debe soportar load_lora_weights
|
||||
(StableDiffusionPipeline, StableDiffusionXLPipeline, etc.).
|
||||
lora: Referencia al adaptador LoRA. lora.path apunta al archivo
|
||||
.safetensors o directorio del adaptador. lora.weight controla
|
||||
la intensidad de fusion (1.0 = peso completo, 0.0 = sin efecto).
|
||||
|
||||
Returns:
|
||||
El mismo pipe con el LoRA cargado y el peso de fusion aplicado.
|
||||
|
||||
Raises:
|
||||
ImportError: Si diffusers no esta instalado.
|
||||
OSError: Si lora.path no existe o el formato del archivo es invalido.
|
||||
"""
|
||||
try:
|
||||
import diffusers # noqa: F401 — verificar disponibilidad
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"diffusers_load_lora requiere diffusers. "
|
||||
"Instalar con: pip install diffusers"
|
||||
) from exc
|
||||
|
||||
pipe.load_lora_weights(lora.path)
|
||||
|
||||
if lora.weight != 1.0:
|
||||
# set_adapters acepta lista de nombres y lista de pesos.
|
||||
# El nombre "default" es el que diffusers asigna al primer LoRA cargado.
|
||||
if hasattr(pipe, "set_adapters"):
|
||||
pipe.set_adapters(["default"], adapter_weights=[lora.weight])
|
||||
elif hasattr(pipe, "fuse_lora"):
|
||||
# Fallback para versiones antiguas de diffusers
|
||||
pipe.fuse_lora(lora_scale=lora.weight)
|
||||
|
||||
return pipe
|
||||
@@ -0,0 +1,60 @@
|
||||
---
|
||||
name: diffusers_load_pipeline
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def diffusers_load_pipeline(model: ModelRef, device: str = 'auto', dtype: str = 'fp16') -> Any"
|
||||
description: "Carga un pipeline diffusers (AutoPipelineForText2Image) con cache global por (model_key, dtype, device). Segunda llamada con mismos parametros retorna el objeto cacheado sin recargar disco."
|
||||
tags: [diffusers, ml, image-generation, pipeline, cache, torch]
|
||||
uses_functions: [torch_device_select_py_ml]
|
||||
uses_types: [model_ref_py_ml]
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [torch, diffusers]
|
||||
params:
|
||||
- name: model
|
||||
desc: "Referencia al modelo. model.path si disponible (ruta local), model.name si no (HuggingFace Hub o nombre corto)."
|
||||
- name: device
|
||||
desc: "Preferencia de device: 'auto' (CUDA>MPS>CPU), 'cuda', 'cuda:N', 'mps', 'cpu'. Default 'auto'."
|
||||
- name: dtype
|
||||
desc: "Precision del modelo: 'fp16' (torch.float16 + variant=fp16), 'bf16' (bfloat16), 'fp32' (float32). Default 'fp16'."
|
||||
output: "Pipeline diffusers cargado y movido al device. Callable via pipe(prompt=..., ...). Cacheado en _PIPELINE_CACHE."
|
||||
tested: true
|
||||
tests:
|
||||
- "carga pipeline y retorna callable"
|
||||
- "segunda carga usa cache (< 100ms)"
|
||||
test_file_path: "python/functions/ml/tests/test_diffusers_backend.py"
|
||||
file_path: "python/functions/ml/diffusers_load_pipeline.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
import sys
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
from diffusers_load_pipeline import diffusers_load_pipeline
|
||||
from model_ref import ModelRef
|
||||
|
||||
model = ModelRef(
|
||||
name="sd-turbo",
|
||||
model_type="sd15",
|
||||
quantization="fp16",
|
||||
path="/home/lucas/vaults/imagegen_models/diffusers/sd-turbo",
|
||||
)
|
||||
pipe = diffusers_load_pipeline(model, device="cuda", dtype="fp16")
|
||||
# Segunda llamada: cache hit, < 100ms
|
||||
pipe2 = diffusers_load_pipeline(model, device="cuda", dtype="fp16")
|
||||
assert pipe is pipe2
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Cache global `_PIPELINE_CACHE` indexado por `(model_key, dtype, resolved_device)`.
|
||||
`model_key` es `model.path` si no es None, sino `model.name`.
|
||||
|
||||
Para liberar memoria: usar `diffusers_unload(pipe=None)` que llama `_clear_pipeline_cache()`.
|
||||
|
||||
Imports lazy de torch y diffusers dentro de la funcion — ImportError descriptivo si no instalados.
|
||||
@@ -0,0 +1,102 @@
|
||||
"""diffusers_load_pipeline — carga un pipeline diffusers con cache global por (model, dtype, device)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from model_ref import ModelRef
|
||||
from torch_device_select import torch_device_select
|
||||
|
||||
# Cache global: (model_key, dtype, device) -> pipeline object
|
||||
_PIPELINE_CACHE: dict[tuple[str, str, str], Any] = {}
|
||||
|
||||
|
||||
def _get_model_key(model: ModelRef) -> str:
|
||||
"""Retorna la clave de cache para un ModelRef."""
|
||||
return model.path if model.path else model.name
|
||||
|
||||
|
||||
def diffusers_load_pipeline(
|
||||
model: ModelRef,
|
||||
device: str = "auto",
|
||||
dtype: str = "fp16",
|
||||
) -> Any:
|
||||
"""Carga un pipeline diffusers con cache global por (model_key, dtype, device).
|
||||
|
||||
Usa AutoPipelineForText2Image.from_pretrained con torch_dtype=torch.float16
|
||||
y variant="fp16" por defecto. Hace pipe.to(device) tras la carga. Los
|
||||
pipelines se cachean en memoria — segunda llamada con los mismos parametros
|
||||
retorna el objeto cacheado sin recargar el modelo del disco.
|
||||
|
||||
Args:
|
||||
model: Referencia al modelo. model.path se usa si esta presente;
|
||||
si no, model.name se pasa directo a from_pretrained (HF hub).
|
||||
device: Preferencia de device. 'auto' delega a torch_device_select
|
||||
(CUDA > MPS > CPU). Ejemplos: 'auto', 'cuda', 'cuda:0', 'cpu'.
|
||||
dtype: Precision del modelo. 'fp16' usa torch.float16 + variant="fp16".
|
||||
'fp32' usa torch.float32 sin variant. 'bf16' usa torch.bfloat16.
|
||||
|
||||
Returns:
|
||||
Objeto pipeline diffusers cargado y movido al device seleccionado.
|
||||
El tipo concreto depende del modelo (StableDiffusionPipeline,
|
||||
StableDiffusionXLPipeline, etc.) pero siempre es callable via pipe(...).
|
||||
|
||||
Raises:
|
||||
ImportError: Si torch o diffusers no estan instalados.
|
||||
OSError: Si el path del modelo no existe o el nombre del hub es invalido.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"diffusers_load_pipeline requiere torch y diffusers. "
|
||||
"Instalar con: pip install torch diffusers"
|
||||
) from exc
|
||||
|
||||
resolved_device = torch_device_select(device)
|
||||
model_key = _get_model_key(model)
|
||||
cache_key = (model_key, dtype, resolved_device)
|
||||
|
||||
if cache_key in _PIPELINE_CACHE:
|
||||
return _PIPELINE_CACHE[cache_key]
|
||||
|
||||
load_path = model.path if model.path else model.name
|
||||
|
||||
if dtype == "fp16":
|
||||
torch_dtype = torch.float16
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
load_path,
|
||||
torch_dtype=torch_dtype,
|
||||
variant="fp16",
|
||||
)
|
||||
elif dtype == "bf16":
|
||||
torch_dtype = torch.bfloat16
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
load_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
elif dtype == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
load_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"dtype '{dtype}' no soportado. Usar 'fp16', 'bf16' o 'fp32'."
|
||||
)
|
||||
|
||||
pipe = pipe.to(resolved_device)
|
||||
_PIPELINE_CACHE[cache_key] = pipe
|
||||
return pipe
|
||||
|
||||
|
||||
def _clear_pipeline_cache() -> None:
|
||||
"""Limpia el cache global de pipelines (uso interno y tests)."""
|
||||
_PIPELINE_CACHE.clear()
|
||||
@@ -0,0 +1,61 @@
|
||||
---
|
||||
name: diffusers_set_scheduler
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def diffusers_set_scheduler(pipe: Any, sampler: str) -> Any"
|
||||
description: "Reemplaza el scheduler de un pipeline diffusers por la clase correspondiente al sampler solicitado. Usa from_config para heredar configuracion base del modelo."
|
||||
tags: [diffusers, ml, scheduler, sampler, image-generation]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [diffusers]
|
||||
params:
|
||||
- name: pipe
|
||||
desc: "Pipeline diffusers cargado con atributo pipe.scheduler y pipe.scheduler.config."
|
||||
- name: sampler
|
||||
desc: "Nombre del sampler: euler, euler_a, dpm++2m, dpm++2m_v2, heun, dpm2, lcm."
|
||||
output: "El mismo pipe con pipe.scheduler reemplazado. Modificacion in-place, retorna pipe para composicion."
|
||||
tested: true
|
||||
tests:
|
||||
- "euler cambia scheduler a EulerDiscreteScheduler"
|
||||
- "sampler invalido lanza ValueError"
|
||||
test_file_path: "python/functions/ml/tests/test_diffusers_backend.py"
|
||||
file_path: "python/functions/ml/diffusers_set_scheduler.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from diffusers_load_pipeline import diffusers_load_pipeline
|
||||
from diffusers_set_scheduler import diffusers_set_scheduler
|
||||
from model_ref import ModelRef
|
||||
|
||||
model = ModelRef(name="sd-turbo", model_type="sd15", path="/path/to/model")
|
||||
pipe = diffusers_load_pipeline(model)
|
||||
pipe = diffusers_set_scheduler(pipe, "euler_a")
|
||||
# type(pipe.scheduler).__name__ == "EulerAncestralDiscreteScheduler"
|
||||
```
|
||||
|
||||
## Mapping de samplers
|
||||
|
||||
| sampler | clase diffusers | kwargs extra |
|
||||
|--------------|------------------------------------|-------------------------------------------|
|
||||
| euler | EulerDiscreteScheduler | — |
|
||||
| euler_a | EulerAncestralDiscreteScheduler | — |
|
||||
| dpm++2m | DPMSolverMultistepScheduler | algorithm_type="dpmsolver++" |
|
||||
| dpm++2m_v2 | DPMSolverMultistepScheduler | algorithm_type="dpmsolver++", solver_order=2 |
|
||||
| heun | HeunDiscreteScheduler | — |
|
||||
| dpm2 | KDPM2DiscreteScheduler | — |
|
||||
| lcm | LCMScheduler | — |
|
||||
|
||||
## Notas
|
||||
|
||||
Usa `SchedulerCls.from_config(pipe.scheduler.config, **extra_kwargs)` para
|
||||
heredar `beta_start`, `beta_end`, `clip_sample`, etc. del modelo base.
|
||||
|
||||
Import lazy de diffusers — ImportError descriptivo si no instalado.
|
||||
@@ -0,0 +1,70 @@
|
||||
"""diffusers_set_scheduler — cambia el scheduler de un pipeline diffusers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Mapping canónico sampler -> (scheduler_class_name, kwargs_extra)
|
||||
_SCHEDULER_MAP: dict[str, tuple[str, dict]] = {
|
||||
"euler": ("EulerDiscreteScheduler", {}),
|
||||
"euler_a": ("EulerAncestralDiscreteScheduler", {}),
|
||||
"dpm++2m": ("DPMSolverMultistepScheduler", {"algorithm_type": "dpmsolver++"}),
|
||||
"dpm++2m_v2": ("DPMSolverMultistepScheduler", {"algorithm_type": "dpmsolver++", "solver_order": 2}),
|
||||
"heun": ("HeunDiscreteScheduler", {}),
|
||||
"dpm2": ("KDPM2DiscreteScheduler", {}),
|
||||
"lcm": ("LCMScheduler", {}),
|
||||
}
|
||||
|
||||
|
||||
def diffusers_set_scheduler(pipe: Any, sampler: str) -> Any:
|
||||
"""Reemplaza el scheduler de un pipeline diffusers por el correspondiente al sampler.
|
||||
|
||||
Usa <SchedulerClass>.from_config(pipe.scheduler.config) para heredar la
|
||||
configuracion base del modelo (betas, clip_sample, etc.) y aplica encima
|
||||
los kwargs especificos del sampler. Modifica pipe.scheduler in-place y
|
||||
retorna el mismo pipe para composicion.
|
||||
|
||||
Args:
|
||||
pipe: Pipeline diffusers cargado (StableDiffusionPipeline,
|
||||
StableDiffusionXLPipeline, etc.). Debe tener atributo
|
||||
pipe.scheduler con .config.
|
||||
sampler: Nombre del sampler. Valores validos: euler, euler_a,
|
||||
dpm++2m, dpm++2m_v2, heun, dpm2, lcm.
|
||||
|
||||
Returns:
|
||||
El mismo pipe con pipe.scheduler reemplazado por la clase
|
||||
correspondiente al sampler solicitado.
|
||||
|
||||
Raises:
|
||||
ImportError: Si diffusers no esta instalado.
|
||||
ValueError: Si el sampler no esta en el mapping soportado.
|
||||
"""
|
||||
try:
|
||||
import diffusers
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"diffusers_set_scheduler requiere diffusers. "
|
||||
"Instalar con: pip install diffusers"
|
||||
) from exc
|
||||
|
||||
if sampler not in _SCHEDULER_MAP:
|
||||
supported = ", ".join(sorted(_SCHEDULER_MAP.keys()))
|
||||
raise ValueError(
|
||||
f"Sampler '{sampler}' no soportado. Valores validos: {supported}"
|
||||
)
|
||||
|
||||
class_name, extra_kwargs = _SCHEDULER_MAP[sampler]
|
||||
scheduler_cls = getattr(diffusers, class_name, None)
|
||||
|
||||
if scheduler_cls is None:
|
||||
raise ImportError(
|
||||
f"La clase '{class_name}' no esta disponible en la version de diffusers "
|
||||
f"instalada. Actualizar diffusers para usar el sampler '{sampler}'."
|
||||
)
|
||||
|
||||
pipe.scheduler = scheduler_cls.from_config(
|
||||
pipe.scheduler.config,
|
||||
**extra_kwargs,
|
||||
)
|
||||
return pipe
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
name: diffusers_unload
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def diffusers_unload(pipe: Any | None = None) -> None"
|
||||
description: "Libera la memoria de un pipeline diffusers. Si pipe=None limpia el cache global de diffusers_load_pipeline. Siempre llama gc.collect() y torch.cuda.empty_cache()."
|
||||
tags: [diffusers, ml, memory, cleanup, vram, cache]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [torch, gc]
|
||||
params:
|
||||
- name: pipe
|
||||
desc: "Pipeline a liberar con del. Si None, limpia el cache global _PIPELINE_CACHE de diffusers_load_pipeline (descarga todos los pipelines cacheados)."
|
||||
output: "None. Efecto secundario: del pipe si pasado, cache limpiado si None, gc.collect() y torch.cuda.empty_cache() siempre."
|
||||
tested: true
|
||||
tests:
|
||||
- "unload None limpia cache cuda si disponible"
|
||||
test_file_path: "python/functions/ml/tests/test_diffusers_backend.py"
|
||||
file_path: "python/functions/ml/diffusers_unload.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from diffusers_unload import diffusers_unload
|
||||
|
||||
# Liberar un pipeline especifico
|
||||
diffusers_unload(pipe)
|
||||
|
||||
# Limpiar TODO el cache (descarga todos los modelos en memoria)
|
||||
diffusers_unload()
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
`del pipe` no garantiza liberacion inmediata si hay otras referencias al objeto.
|
||||
Llamar `diffusers_unload(pipe)` + borrar la referencia local (`pipe = None`)
|
||||
para asegurar que el GC pueda recolectar.
|
||||
|
||||
`torch.cuda.empty_cache()` solo libera cache del allocator de PyTorch, no
|
||||
memoria que otros procesos ocupen. Para liberacion total, el proceso debe terminar.
|
||||
|
||||
Import lazy de torch — si no esta instalado, omite empty_cache silenciosamente.
|
||||
@@ -0,0 +1,47 @@
|
||||
"""diffusers_unload — libera memoria de un pipeline diffusers y limpia cache global."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
def diffusers_unload(pipe: Any | None = None) -> None:
|
||||
"""Libera la memoria ocupada por un pipeline diffusers.
|
||||
|
||||
Si se pasa pipe, lo elimina con del y llama gc.collect() + empty_cache().
|
||||
Si pipe es None, limpia ademas el cache global de diffusers_load_pipeline
|
||||
(descarga TODOS los pipelines cacheados). En ambos casos invoca
|
||||
torch.cuda.empty_cache() si CUDA esta disponible.
|
||||
|
||||
Args:
|
||||
pipe: Pipeline a liberar. Si None, limpia el cache global completo
|
||||
de diffusers_load_pipeline ademas de llamar gc + empty_cache.
|
||||
|
||||
Returns:
|
||||
None. Efecto secundario: memoria GPU/CPU liberada.
|
||||
"""
|
||||
if pipe is None:
|
||||
# Limpiar cache global de diffusers_load_pipeline si esta importado
|
||||
try:
|
||||
# Importar el modulo para acceder a su cache interno
|
||||
load_module_path = os.path.join(os.path.dirname(__file__))
|
||||
if load_module_path not in sys.path:
|
||||
sys.path.insert(0, load_module_path)
|
||||
from diffusers_load_pipeline import _clear_pipeline_cache
|
||||
_clear_pipeline_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
del pipe
|
||||
|
||||
gc.collect()
|
||||
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
name: genconfig_load_json
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def genconfig_load_json(path: str) -> GenerationConfig"
|
||||
description: "Carga y valida un GenerationConfig desde un archivo JSON en disco. Usa pydantic model_validate si disponible; fallback a construccion manual desde dataclass. Raises FileNotFoundError si el archivo no existe."
|
||||
tags: [ml, generation, json, io, deserialization]
|
||||
params:
|
||||
- name: path
|
||||
desc: "Ruta al archivo JSON generado por genconfig_save_json. Relativa o absoluta."
|
||||
output: "Instancia de GenerationConfig cargada y validada con todos sus campos."
|
||||
uses_functions:
|
||||
- genconfig_save_json_py_ml
|
||||
uses_types:
|
||||
- generation_config_py_ml
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "save escribe archivo JSON valido en la ruta indicada"
|
||||
- "save crea directorios padre si no existen"
|
||||
- "json contiene claves en snake_case"
|
||||
test_file_path: "python/functions/ml/tests/test_genconfig_json_roundtrip.py"
|
||||
file_path: "python/functions/ml/genconfig_load_json.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.genconfig_load_json import genconfig_load_json
|
||||
|
||||
cfg = genconfig_load_json("/tmp/gen_config.json")
|
||||
# cfg.prompt == "a forest at dusk"
|
||||
# cfg.seed == 123
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Usa pydantic model_validate cuando disponible (valida literales de model_type,
|
||||
quantization y sampler). Sin pydantic, construye la instancia directamente
|
||||
sin validar literales. Para el contrato Go-Python es importante que los nombres
|
||||
de clave sean snake_case (garantizado por pydantic model_dump_json).
|
||||
@@ -0,0 +1,77 @@
|
||||
"""genconfig_load_json — carga un GenerationConfig desde un archivo JSON."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from generation_config import GenerationConfig
|
||||
|
||||
|
||||
def genconfig_load_json(path: str) -> GenerationConfig:
|
||||
"""Carga y valida un GenerationConfig desde un archivo JSON en disco.
|
||||
|
||||
Usa GenerationConfig.model_validate(data) si pydantic esta disponible
|
||||
(version con validacion completa de tipos y literales). En caso de
|
||||
fallback a dataclass, construye la instancia manualmente mapeando
|
||||
los campos conocidos.
|
||||
|
||||
Args:
|
||||
path: Ruta al archivo JSON. Puede ser relativa o absoluta.
|
||||
|
||||
Returns:
|
||||
Instancia de GenerationConfig cargada y validada.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Si el archivo no existe.
|
||||
json.JSONDecodeError: Si el contenido no es JSON valido.
|
||||
pydantic.ValidationError: Si los datos no cumplen el schema (version pydantic).
|
||||
KeyError / TypeError: Si faltan campos obligatorios (version dataclass).
|
||||
"""
|
||||
abs_path = os.path.abspath(path)
|
||||
with open(abs_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Intentar deserializacion pydantic (version canonica con validacion)
|
||||
try:
|
||||
return GenerationConfig.model_validate(data)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Fallback: dataclass — construir manualmente desde el dict
|
||||
from lora_ref import LoraRef
|
||||
from model_ref import ModelRef
|
||||
|
||||
model_data = data["model"]
|
||||
model = ModelRef(
|
||||
name=model_data["name"],
|
||||
model_type=model_data["model_type"],
|
||||
quantization=model_data.get("quantization", "fp16"),
|
||||
path=model_data.get("path"),
|
||||
)
|
||||
|
||||
loras = [
|
||||
LoraRef(
|
||||
path=lr["path"],
|
||||
weight=lr.get("weight", 1.0),
|
||||
scale=lr.get("scale"),
|
||||
)
|
||||
for lr in data.get("loras", [])
|
||||
]
|
||||
|
||||
return GenerationConfig(
|
||||
prompt=data["prompt"],
|
||||
negative_prompt=data.get("negative_prompt"),
|
||||
seed=data["seed"],
|
||||
steps=data["steps"],
|
||||
cfg_scale=data["cfg_scale"],
|
||||
sampler=data["sampler"],
|
||||
width=data["width"],
|
||||
height=data["height"],
|
||||
model=model,
|
||||
loras=tuple(loras),
|
||||
clip_skip=data.get("clip_skip"),
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
---
|
||||
name: genconfig_save_json
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def genconfig_save_json(cfg: GenerationConfig, path: str) -> str"
|
||||
description: "Serializa un GenerationConfig a JSON (pydantic model_dump_json o dataclass fallback) y lo escribe en disco. Crea directorios padre si no existen. Retorna el path absoluto del archivo escrito."
|
||||
tags: [ml, generation, json, io, serialization]
|
||||
params:
|
||||
- name: cfg
|
||||
desc: "Instancia de GenerationConfig a serializar. Pydantic o dataclass."
|
||||
- name: path
|
||||
desc: "Ruta de destino del archivo JSON. Relativa o absoluta."
|
||||
output: "Path absoluto (str) del archivo JSON escrito en disco."
|
||||
uses_functions: []
|
||||
uses_types:
|
||||
- generation_config_py_ml
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "save escribe archivo JSON valido en la ruta indicada"
|
||||
- "save crea directorios padre si no existen"
|
||||
- "json contiene claves en snake_case"
|
||||
test_file_path: "python/functions/ml/tests/test_genconfig_json_roundtrip.py"
|
||||
file_path: "python/functions/ml/genconfig_save_json.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.genconfig_save_json import genconfig_save_json
|
||||
from ml.generation_config import GenerationConfig
|
||||
from ml.model_ref import ModelRef
|
||||
|
||||
cfg = GenerationConfig(
|
||||
prompt="a forest at dusk",
|
||||
seed=123,
|
||||
steps=25,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model=ModelRef(name="runwayml/stable-diffusion-v1-5", model_type="sd15"),
|
||||
)
|
||||
|
||||
saved_path = genconfig_save_json(cfg, "/tmp/gen_config.json")
|
||||
# saved_path == "/tmp/gen_config.json"
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Usa pydantic model_dump_json cuando disponible (JSON canonico con snake_case
|
||||
interoperable con Go). En entornos sin pydantic usa json.dumps + dataclasses.asdict.
|
||||
Los directorios padre se crean con os.makedirs(exist_ok=True).
|
||||
@@ -0,0 +1,58 @@
|
||||
"""genconfig_save_json — persiste un GenerationConfig como JSON en disco."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from generation_config import GenerationConfig
|
||||
|
||||
|
||||
def genconfig_save_json(cfg: GenerationConfig, path: str) -> str:
|
||||
"""Serializa un GenerationConfig a JSON y lo escribe en disco.
|
||||
|
||||
Usa model_dump_json(indent=2) si GenerationConfig es instancia de
|
||||
pydantic.BaseModel (version con validacion). En caso de fallback a
|
||||
dataclass, serializa con json.dumps usando un encoder que convierte
|
||||
dataclasses a dict recursivamente.
|
||||
|
||||
Crea los directorios padre si no existen (equivalente a mkdir -p).
|
||||
|
||||
Args:
|
||||
cfg: Instancia de GenerationConfig a serializar.
|
||||
path: Ruta de destino del archivo JSON. Puede ser relativa o absoluta.
|
||||
|
||||
Returns:
|
||||
Path absoluto del archivo escrito.
|
||||
|
||||
Raises:
|
||||
OSError: Si no se puede crear el directorio o escribir el archivo.
|
||||
"""
|
||||
abs_path = os.path.abspath(path)
|
||||
parent = os.path.dirname(abs_path)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
# Intentar serializacion pydantic (version canonica)
|
||||
try:
|
||||
json_str = cfg.model_dump_json(indent=2)
|
||||
except AttributeError:
|
||||
# Fallback: dataclass — serializar manualmente
|
||||
import dataclasses
|
||||
|
||||
def _to_dict(obj: object) -> object:
|
||||
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
|
||||
return {k: _to_dict(v) for k, v in dataclasses.asdict(obj).items()}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [_to_dict(i) for i in obj]
|
||||
return obj
|
||||
|
||||
json_str = json.dumps(_to_dict(cfg), indent=2)
|
||||
|
||||
with open(abs_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_str)
|
||||
|
||||
return abs_path
|
||||
@@ -0,0 +1,65 @@
|
||||
---
|
||||
name: genconfig_to_diffusers_kwargs
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def genconfig_to_diffusers_kwargs(cfg: GenerationConfig) -> dict"
|
||||
description: "Convierte un GenerationConfig al dict de kwargs listo para pipe(**kwargs) de diffusers. Mapea prompt, steps, cfg_scale, width, height. LoRAs y sampler se aplican antes de la llamada; generator=None para que el caller setee torch.Generator por separado."
|
||||
tags: [ml, diffusers, generation, converter, pure]
|
||||
params:
|
||||
- name: cfg
|
||||
desc: "Instancia de GenerationConfig con los parametros de generacion validados."
|
||||
output: "dict con claves prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, generator (None). Listo para desempaquetar con pipe(**kwargs)."
|
||||
uses_functions: []
|
||||
uses_types:
|
||||
- generation_config_py_ml
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "kwargs contiene todas las claves requeridas"
|
||||
- "negative_prompt None se pasa tal cual"
|
||||
- "steps y cfg_scale se mapean a num_inference_steps y guidance_scale"
|
||||
- "generator siempre es None"
|
||||
test_file_path: "python/functions/ml/tests/test_genconfig_to_diffusers_kwargs.py"
|
||||
file_path: "python/functions/ml/genconfig_to_diffusers_kwargs.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.genconfig_to_diffusers_kwargs import genconfig_to_diffusers_kwargs
|
||||
from ml.generation_config import GenerationConfig
|
||||
from ml.model_ref import ModelRef
|
||||
|
||||
cfg = GenerationConfig(
|
||||
prompt="a dog in the park",
|
||||
seed=42,
|
||||
steps=30,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler_a",
|
||||
width=512,
|
||||
height=512,
|
||||
model=ModelRef(name="runwayml/stable-diffusion-v1-5", model_type="sd15"),
|
||||
)
|
||||
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
# kwargs["num_inference_steps"] == 30
|
||||
# kwargs["guidance_scale"] == 7.5
|
||||
# kwargs["generator"] is None
|
||||
|
||||
# El caller asigna el generator:
|
||||
# kwargs["generator"] = torch.Generator(device=device).manual_seed(cfg.seed)
|
||||
# image = pipe(**kwargs).images[0]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura: sin I/O, sin torch, sin imports opcionales en tiempo de ejecucion.
|
||||
Los LoRAs se aplican via `pipe.load_lora_weights(lora.path, adapter_name=...)` antes
|
||||
de la llamada. El scheduler/sampler se configura via `pipe.scheduler = ...` tambien
|
||||
antes. Ambos no tienen mapping directo a kwargs de `__call__`.
|
||||
@@ -0,0 +1,41 @@
|
||||
"""genconfig_to_diffusers_kwargs — convierte GenerationConfig a kwargs para diffusers pipe()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from generation_config import GenerationConfig
|
||||
|
||||
|
||||
def genconfig_to_diffusers_kwargs(cfg: GenerationConfig) -> dict:
|
||||
"""Convierte un GenerationConfig al dict de kwargs listo para pipe(**kwargs) de diffusers.
|
||||
|
||||
Solo mapea los campos que diffusers StableDiffusionPipeline.__call__ acepta
|
||||
directamente. Los LoRAs y el sampler/scheduler se configuran antes de la
|
||||
llamada via load_lora_weights() y pipe.scheduler = ...; no tienen mapping
|
||||
1:1 con kwargs de __call__.
|
||||
|
||||
El campo "generator" se devuelve como None; el caller debe asignar
|
||||
torch.Generator(device=device).manual_seed(cfg.seed) por separado para
|
||||
poder reutilizar el GenerationConfig en distintos devices sin importar torch
|
||||
aqui (funcion pura).
|
||||
|
||||
Args:
|
||||
cfg: Parametros de generacion validados. Debe ser instancia de GenerationConfig.
|
||||
|
||||
Returns:
|
||||
dict con claves: prompt, negative_prompt, num_inference_steps,
|
||||
guidance_scale, width, height, generator (None).
|
||||
"""
|
||||
return {
|
||||
"prompt": cfg.prompt,
|
||||
"negative_prompt": cfg.negative_prompt,
|
||||
"num_inference_steps": cfg.steps,
|
||||
"guidance_scale": cfg.cfg_scale,
|
||||
"width": cfg.width,
|
||||
"height": cfg.height,
|
||||
"generator": None,
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
---
|
||||
name: genconfig_to_sdcpp_args
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def genconfig_to_sdcpp_args(cfg: GenerationConfig) -> list[str]"
|
||||
description: "Convierte un GenerationConfig a lista de args CLI para stable-diffusion.cpp (sd-cli). Mapea sampler via _SAMPLER_MAP, aplana LoRAs como pares --lora path:weight. Sin I/O ni dependencias externas."
|
||||
tags: [ml, sdcpp, stable-diffusion-cpp, cli, converter, pure]
|
||||
params:
|
||||
- name: cfg
|
||||
desc: "Instancia de GenerationConfig con los parametros de generacion validados."
|
||||
output: "Lista de strings con los argumentos CLI en orden. Listo para subprocess.run(['sd'] + args, ...)."
|
||||
uses_functions: []
|
||||
uses_types:
|
||||
- generation_config_py_ml
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "sampler euler_a se mapea a euler_a en el flag --sampling-method"
|
||||
- "sampler dpm++2m se mapea a dpmpp2m"
|
||||
- "lora con path y weight se agrega como --lora path:weight"
|
||||
- "multiples loras generan multiples pares --lora"
|
||||
- "negative_prompt None produce string vacio en --negative-prompt"
|
||||
- "model.path tiene prioridad sobre model.name en -m"
|
||||
- "args contiene --prompt --seed --steps --cfg-scale --sampling-method -W -H -m"
|
||||
test_file_path: "python/functions/ml/tests/test_genconfig_to_sdcpp_args.py"
|
||||
file_path: "python/functions/ml/genconfig_to_sdcpp_args.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.genconfig_to_sdcpp_args import genconfig_to_sdcpp_args
|
||||
from ml.generation_config import GenerationConfig
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.lora_ref import LoraRef
|
||||
|
||||
cfg = GenerationConfig(
|
||||
prompt="a cat",
|
||||
seed=1,
|
||||
steps=20,
|
||||
cfg_scale=7.0,
|
||||
sampler="dpm++2m",
|
||||
width=512,
|
||||
height=512,
|
||||
model=ModelRef(name="v1-5", model_type="sd15", path="/models/v1-5.ckpt"),
|
||||
loras=[LoraRef(path="/loras/detail.safetensors", weight=0.8)],
|
||||
)
|
||||
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
# ["--prompt", "a cat", "--negative-prompt", "", "--seed", "1",
|
||||
# "--steps", "20", "--cfg-scale", "7.0", "--sampling-method", "dpmpp2m",
|
||||
# "-W", "512", "-H", "512", "-m", "/models/v1-5.ckpt",
|
||||
# "--lora", "/loras/detail.safetensors:0.8"]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Mapa de samplers (_SAMPLER_MAP):
|
||||
- euler → euler
|
||||
- euler_a → euler_a
|
||||
- dpm++2m → dpmpp2m
|
||||
- dpm++2m_v2 → dpmpp2mv2
|
||||
- heun → heun
|
||||
- dpm2 → dpm2
|
||||
- lcm → lcm
|
||||
|
||||
Si cfg.model.path es None, se usa cfg.model.name (nombre de hub o path relativo
|
||||
segun configuracion del entorno sdcpp). Los LoRAs sin path se omiten silenciosamente.
|
||||
@@ -0,0 +1,65 @@
|
||||
"""genconfig_to_sdcpp_args — convierte GenerationConfig a args CLI para stable-diffusion.cpp."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from generation_config import GenerationConfig
|
||||
|
||||
# Mapa de SamplerName (dominio ml) a flags de sd-cli (stable-diffusion.cpp).
|
||||
# Referencia: https://github.com/leejet/stable-diffusion.cpp#usage
|
||||
_SAMPLER_MAP: dict[str, str] = {
|
||||
"euler": "euler",
|
||||
"euler_a": "euler_a",
|
||||
"dpm++2m": "dpmpp2m",
|
||||
"dpm++2m_v2": "dpmpp2mv2",
|
||||
"heun": "heun",
|
||||
"dpm2": "dpm2",
|
||||
"lcm": "lcm",
|
||||
}
|
||||
|
||||
|
||||
def genconfig_to_sdcpp_args(cfg: GenerationConfig) -> list[str]:
|
||||
"""Convierte un GenerationConfig a la lista de args CLI para stable-diffusion.cpp.
|
||||
|
||||
Genera los argumentos necesarios para invocar `sd` (sd-cli) de
|
||||
stable-diffusion.cpp. El mapa _SAMPLER_MAP traduce los SamplerName
|
||||
canonicos del dominio ml a los identificadores de sdcpp.
|
||||
|
||||
Los LoRAs se pasan como repeticiones del par --lora "path:weight".
|
||||
Si un LoRA no tiene path definido, se omite silenciosamente.
|
||||
|
||||
El modelo se resuelve priorizando cfg.model.path; si es None usa cfg.model.name
|
||||
(puede ser un nombre de hub o un path relativo segun la configuracion de sdcpp).
|
||||
|
||||
Args:
|
||||
cfg: Parametros de generacion validados. Debe ser instancia de GenerationConfig.
|
||||
|
||||
Returns:
|
||||
Lista de strings con los argumentos CLI en el orden esperado por sd-cli.
|
||||
Listo para: subprocess.run(["sd"] + args, ...) o similar.
|
||||
"""
|
||||
model_path = cfg.model.path if cfg.model.path else cfg.model.name
|
||||
sampler_flag = _SAMPLER_MAP.get(cfg.sampler, cfg.sampler)
|
||||
|
||||
args: list[str] = [
|
||||
"--prompt", cfg.prompt,
|
||||
"--negative-prompt", cfg.negative_prompt or "",
|
||||
"--seed", str(cfg.seed),
|
||||
"--steps", str(cfg.steps),
|
||||
"--cfg-scale", str(cfg.cfg_scale),
|
||||
"--sampling-method", sampler_flag,
|
||||
"-W", str(cfg.width),
|
||||
"-H", str(cfg.height),
|
||||
"-m", model_path,
|
||||
]
|
||||
|
||||
# Aplanar LoRAs: cada uno genera un par --lora "path:weight"
|
||||
for lora in cfg.loras:
|
||||
if lora.path:
|
||||
args += ["--lora", f"{lora.path}:{lora.weight}"]
|
||||
|
||||
return args
|
||||
@@ -0,0 +1,111 @@
|
||||
"""GenerationConfig — contrato de parametros para generacion de imagenes con difusion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
_SAMPLER_VALUES = (
|
||||
"euler",
|
||||
"euler_a",
|
||||
"dpm++2m",
|
||||
"dpm++2m_v2",
|
||||
"heun",
|
||||
"dpm2",
|
||||
"lcm",
|
||||
)
|
||||
|
||||
SamplerName = Literal[
|
||||
"euler",
|
||||
"euler_a",
|
||||
"dpm++2m",
|
||||
"dpm++2m_v2",
|
||||
"heun",
|
||||
"dpm2",
|
||||
"lcm",
|
||||
]
|
||||
|
||||
try:
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from lora_ref import LoraRef
|
||||
from model_ref import ModelRef
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
"""Contrato de parametros para generacion de imagenes con modelos de difusion.
|
||||
|
||||
Tipo producto central del dominio ml. Usado como contrato compartido entre
|
||||
Python (diffusers, sd.cpp wrapper) y Go (orquestador). Serializa a JSON
|
||||
canonico via model_dump_json() para intercambio entre servicios.
|
||||
|
||||
Attributes:
|
||||
prompt: Descripcion textual positiva de la imagen a generar.
|
||||
negative_prompt: Descripcion de lo que se quiere evitar. None omite
|
||||
el condicionamiento negativo (requiere soporte del modelo).
|
||||
seed: Semilla para reproducibilidad. -1 usa semilla aleatoria.
|
||||
steps: Numero de pasos de denoising. Rango tipico: 20-50.
|
||||
LCM: 4-8 pasos. Valores altos aumentan calidad y tiempo.
|
||||
cfg_scale: Classifier-Free Guidance scale. Controla cuanto el modelo
|
||||
sigue el prompt. Rango tipico: 5.0-12.0.
|
||||
7.5 es el valor clasico. LCM: 1.0-2.0.
|
||||
sampler: Algoritmo de denoising. Ver SamplerName para valores validos.
|
||||
width: Ancho de la imagen en pixeles. Debe ser multiplo de 8.
|
||||
SD1.5: 512. SDXL: 1024. Flux: 1024+.
|
||||
height: Alto de la imagen en pixeles. Mismas restricciones que width.
|
||||
model: Referencia al modelo base. Ver ModelRef.
|
||||
loras: Lista de adaptadores LoRA a aplicar. Lista vacia = sin LoRA.
|
||||
clip_skip: Numero de capas CLIP a saltar desde el final del encoder.
|
||||
None usa el valor por defecto del modelo. Tipico: 1-2 para anime.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
prompt: str
|
||||
negative_prompt: str | None = None
|
||||
seed: int
|
||||
steps: int
|
||||
cfg_scale: float
|
||||
sampler: SamplerName
|
||||
width: int
|
||||
height: int
|
||||
model: ModelRef
|
||||
loras: list[LoraRef] = []
|
||||
clip_skip: int | None = None
|
||||
|
||||
except ImportError:
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationConfig: # type: ignore[no-redef]
|
||||
"""Contrato de parametros para generacion de imagenes (fallback dataclass).
|
||||
|
||||
Usar la version pydantic cuando este disponible para validacion y
|
||||
serializacion JSON canonica compartida con Go.
|
||||
|
||||
Attributes:
|
||||
prompt: Descripcion textual positiva de la imagen.
|
||||
negative_prompt: Descripcion de lo que evitar. None = sin condicionamiento negativo.
|
||||
seed: Semilla. -1 = aleatoria.
|
||||
steps: Pasos de denoising (20-50 tipico, 4-8 para LCM).
|
||||
cfg_scale: CFG scale (5.0-12.0 tipico, 1.0-2.0 para LCM).
|
||||
sampler: Algoritmo de denoising (ver SamplerName).
|
||||
width: Ancho en pixeles, multiplo de 8.
|
||||
height: Alto en pixeles, multiplo de 8.
|
||||
model: Referencia al modelo base (ModelRef).
|
||||
loras: Lista de LoRAs a aplicar (LoraRef[]).
|
||||
clip_skip: Capas CLIP a saltar desde el final. None = default del modelo.
|
||||
"""
|
||||
|
||||
prompt: str
|
||||
seed: int
|
||||
steps: int
|
||||
cfg_scale: float
|
||||
sampler: str
|
||||
width: int
|
||||
height: int
|
||||
model: object # ModelRef
|
||||
negative_prompt: str | None = None
|
||||
loras: tuple = field(default_factory=tuple)
|
||||
clip_skip: int | None = None
|
||||
@@ -0,0 +1,58 @@
|
||||
---
|
||||
name: gpu_info
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def gpu_info() -> list[dict]"
|
||||
description: "Consulta nvidia-smi para obtener informacion de cada GPU NVIDIA: nombre, VRAM total y libre, version de driver y CUDA. Devuelve lista vacia si nvidia-smi no esta disponible, sin lanzar excepcion."
|
||||
tags: [gpu, nvidia, cuda, vram, hardware, probe, ml, nvidia-smi]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: []
|
||||
params: []
|
||||
output: "lista de dicts por GPU con claves: index (int), name (str), vram_total_mb (int), vram_free_mb (int), driver_version (str), cuda_version (str). Lista vacia si nvidia-smi no esta disponible."
|
||||
tested: true
|
||||
tests:
|
||||
- "sin nvidia-smi devuelve lista vacia"
|
||||
- "formato CSV correcto devuelve lista con un dict por GPU"
|
||||
- "fila malformada en CSV se ignora sin excepcion"
|
||||
test_file_path: "python/functions/ml/tests/test_gpu_info.py"
|
||||
file_path: "python/functions/ml/gpu_info.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.gpu_info import gpu_info
|
||||
|
||||
gpus = gpu_info()
|
||||
# Sin nvidia-smi: []
|
||||
|
||||
# Con una GPU:
|
||||
# [
|
||||
# {
|
||||
# "index": 0,
|
||||
# "name": "NVIDIA GeForce RTX 4090",
|
||||
# "vram_total_mb": 24564,
|
||||
# "vram_free_mb": 22000,
|
||||
# "driver_version": "535.183.01",
|
||||
# "cuda_version": "8.9"
|
||||
# }
|
||||
# ]
|
||||
|
||||
for gpu in gpus:
|
||||
pct = 100 * (1 - gpu["vram_free_mb"] / gpu["vram_total_mb"])
|
||||
print(f"GPU {gpu['index']}: {gpu['name']} — VRAM {pct:.1f}% usada")
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Usa `--query-gpu=compute_cap` como aproximacion de la version CUDA soportada. El campo `cuda_version` del output es la compute capability (ej. "8.9"), no la version CUDA del driver.
|
||||
- Robusto a `FileNotFoundError` (nvidia-smi no instalado), `TimeoutExpired` (driver colgado), y `OSError`.
|
||||
- Para datos de torch (no nvidia-smi), usar `cuda_available`.
|
||||
- impure: consulta hardware y estado del sistema en tiempo de ejecucion.
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Consulta informacion de GPUs NVIDIA via nvidia-smi."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import subprocess
|
||||
|
||||
|
||||
def gpu_info() -> list[dict]:
|
||||
"""Devuelve informacion de todas las GPUs NVIDIA detectadas por nvidia-smi.
|
||||
|
||||
Consulta nvidia-smi via subprocess. Si nvidia-smi no esta disponible o
|
||||
falla, devuelve lista vacia sin lanzar excepcion.
|
||||
|
||||
Returns:
|
||||
Lista de dicts, uno por GPU, con claves:
|
||||
index (int): indice de la GPU (0, 1, ...).
|
||||
name (str): nombre del modelo (ej. "NVIDIA GeForce RTX 4090").
|
||||
vram_total_mb (int): memoria total en MB.
|
||||
vram_free_mb (int): memoria libre en MB.
|
||||
driver_version (str): version del driver NVIDIA.
|
||||
cuda_version (str): version maxima de CUDA soportada por el driver.
|
||||
Lista vacia si nvidia-smi no esta disponible.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-gpu=index,name,memory.total,memory.free,driver_version,compute_cap",
|
||||
"--format=csv,noheader,nounits",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
except subprocess.TimeoutExpired:
|
||||
return []
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
|
||||
gpus = []
|
||||
reader = csv.reader(result.stdout.strip().splitlines())
|
||||
for row in reader:
|
||||
if len(row) < 5:
|
||||
continue
|
||||
try:
|
||||
index = int(row[0].strip())
|
||||
name = row[1].strip()
|
||||
vram_total_mb = int(row[2].strip())
|
||||
vram_free_mb = int(row[3].strip())
|
||||
driver_version = row[4].strip()
|
||||
# compute_cap (ej. "8.9") como aproximacion de cuda_version soportada
|
||||
cuda_version = row[5].strip() if len(row) > 5 else ""
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
gpus.append(
|
||||
{
|
||||
"index": index,
|
||||
"name": name,
|
||||
"vram_total_mb": vram_total_mb,
|
||||
"vram_free_mb": vram_free_mb,
|
||||
"driver_version": driver_version,
|
||||
"cuda_version": cuda_version,
|
||||
}
|
||||
)
|
||||
|
||||
return gpus
|
||||
@@ -0,0 +1,82 @@
|
||||
---
|
||||
name: hf_snapshot_download
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def hf_snapshot_download(repo_id: str, allow_patterns: list[str] | None = None, ignore_patterns: list[str] | None = None, local_dir: str | None = None, token: str | None = None) -> str"
|
||||
description: "Descarga un snapshot de un repo HuggingFace Hub (completo o filtrado por patrones glob). Wrapper de huggingface_hub.snapshot_download con ImportError descriptivo. Soporta repos privados/gated via token. Retorna path local del snapshot."
|
||||
tags: [huggingface, hf, download, snapshot, model, weights, safetensors, ml, hub]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [huggingface_hub]
|
||||
params:
|
||||
- name: repo_id
|
||||
desc: "identificador del repo en HuggingFace Hub en formato 'owner/name' (ej: 'runwayml/stable-diffusion-v1-5')"
|
||||
- name: allow_patterns
|
||||
desc: "lista opcional de patrones glob para incluir solo ciertos archivos (ej: ['*.safetensors', 'config.json']). None descarga todo."
|
||||
- name: ignore_patterns
|
||||
desc: "lista opcional de patrones glob para excluir archivos (ej: ['*.bin', 'flax_*', 'tf_*']). Util para descargar solo safetensors y evitar duplicados en otro formato."
|
||||
- name: local_dir
|
||||
desc: "directorio local de destino. Si None, usa el cache global de HuggingFace (~/.cache/huggingface/hub/)."
|
||||
- name: token
|
||||
desc: "token de acceso HuggingFace para repos privados o gated (Llama, Gemma, etc.). Si None, usa la variable de entorno HF_TOKEN."
|
||||
output: "string: path absoluto al directorio local donde quedo almacenado el snapshot"
|
||||
tested: true
|
||||
tests:
|
||||
- "repo_id se pasa correctamente a snapshot_download"
|
||||
- "retorna string (la ruta local)"
|
||||
- "allow_patterns se incluye en los kwargs si se especifica"
|
||||
- "ignore_patterns se incluye en los kwargs si se especifica"
|
||||
- "local_dir se incluye en los kwargs si se especifica"
|
||||
- "token se incluye en los kwargs si se especifica"
|
||||
- "args opcionales None no se incluyen en kwargs"
|
||||
- "ImportError descriptivo si huggingface_hub no esta instalado"
|
||||
test_file_path: "python/functions/ml/tests/test_hf_snapshot_download.py"
|
||||
file_path: "python/functions/ml/hf_snapshot_download.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.hf_snapshot_download import hf_snapshot_download
|
||||
|
||||
# Descargar solo safetensors y JSONs de SD v1.5 (evita el .bin de 4 GB)
|
||||
path = hf_snapshot_download(
|
||||
repo_id="runwayml/stable-diffusion-v1-5",
|
||||
allow_patterns=["*.safetensors", "*.json", "*.txt"],
|
||||
ignore_patterns=["*.bin"],
|
||||
local_dir=".local/models/sd-v1-5",
|
||||
)
|
||||
# path = "/home/lucas/fn_registry/.local/models/sd-v1-5"
|
||||
|
||||
# Descargar un modelo gated (Llama) con token
|
||||
path = hf_snapshot_download(
|
||||
repo_id="meta-llama/Llama-2-7b-hf",
|
||||
ignore_patterns=["*.bin"],
|
||||
local_dir=".local/models/llama-2-7b",
|
||||
token="hf_xxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
)
|
||||
|
||||
# Descargar al cache global (sin local_dir)
|
||||
path = hf_snapshot_download("BAAI/bge-m3")
|
||||
# path = "/home/lucas/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/..."
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- El wrapper es minimo: no reimplementa logica de descarga, solo asegura que
|
||||
`huggingface_hub` no sea requerido en tiempo de indexacion del registry.
|
||||
- `snapshot_download` es idempotente: si el snapshot ya existe en el cache/local_dir
|
||||
con los mismos hashes, no vuelve a descargar.
|
||||
- `allow_patterns` y `ignore_patterns` usan la semantica de `fnmatch`.
|
||||
Tienen precedencia: si un archivo coincide con ambos, `ignore_patterns` gana.
|
||||
- Para repos grandes (>10 GB), conviene usar `ignore_patterns=["*.bin"]` si el
|
||||
repo ofrece safetensors (formato mas seguro, sin pickle, y soporta mmap).
|
||||
- El token puede ponerse tambien en `~/.cache/huggingface/token` via
|
||||
`huggingface-cli login` para no pasarlo inline.
|
||||
- impure: hace I/O de red, escribe en disco, depende de disponibilidad del Hub.
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Wrapper de huggingface_hub.snapshot_download con manejo de ImportError descriptivo."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def hf_snapshot_download(
|
||||
repo_id: str,
|
||||
allow_patterns: list[str] | None = None,
|
||||
ignore_patterns: list[str] | None = None,
|
||||
local_dir: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> str:
|
||||
"""Descarga un snapshot completo (o filtrado) de un repo de HuggingFace Hub.
|
||||
|
||||
Wrapper sobre `huggingface_hub.snapshot_download` con manejo de ImportError
|
||||
descriptivo. Si `local_dir` se especifica, el snapshot se descarga alli en
|
||||
lugar del cache global de HuggingFace (~/.cache/huggingface/).
|
||||
|
||||
Args:
|
||||
repo_id: identificador del repositorio en HuggingFace Hub
|
||||
(ej: "runwayml/stable-diffusion-v1-5", "meta-llama/Llama-2-7b-hf").
|
||||
allow_patterns: lista opcional de patrones glob para incluir solo ciertos
|
||||
archivos (ej: ["*.safetensors", "*.json"]).
|
||||
Si None, se descargan todos los archivos.
|
||||
ignore_patterns: lista opcional de patrones glob para excluir archivos
|
||||
(ej: ["*.bin", "flax_*"]). Util para evitar descargar
|
||||
pesos en formato pytorch si ya se tienen en safetensors.
|
||||
local_dir: directorio local donde guardar el snapshot. Si None, usa
|
||||
el cache global de HuggingFace Hub.
|
||||
token: token de acceso a HuggingFace (para repos privados o con gated
|
||||
access como Llama). Si None, usa HF_TOKEN del entorno.
|
||||
|
||||
Returns:
|
||||
Path local (str) donde quedo almacenado el snapshot.
|
||||
|
||||
Raises:
|
||||
ImportError: si huggingface_hub no esta instalado, con sugerencia de instalacion.
|
||||
Exception: cualquier error de red o autenticacion propagado desde snapshot_download.
|
||||
"""
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"huggingface_hub no esta instalado. "
|
||||
"Instalar con: pip install huggingface_hub"
|
||||
) from exc
|
||||
|
||||
kwargs: dict = {"repo_id": repo_id}
|
||||
if allow_patterns is not None:
|
||||
kwargs["allow_patterns"] = allow_patterns
|
||||
if ignore_patterns is not None:
|
||||
kwargs["ignore_patterns"] = ignore_patterns
|
||||
if local_dir is not None:
|
||||
kwargs["local_dir"] = local_dir
|
||||
if token is not None:
|
||||
kwargs["token"] = token
|
||||
|
||||
result = snapshot_download(**kwargs)
|
||||
return str(result)
|
||||
@@ -0,0 +1,84 @@
|
||||
---
|
||||
name: image_compare_side_by_side
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def image_compare_side_by_side(a, b, label_a='A', label_b='B', gap_px=16, show_diff=True, show_phash=True) -> dict"
|
||||
description: "Compara dos PIL Images lado a lado generando una imagen compuesta A | diff | B con gap configurable. Calcula MSE pixel-wise y perceptual hash (imagehash si disponible). Util para inspeccionar diferencias entre generaciones de imagen."
|
||||
tags: [image, compare, diff, phash, mse, pil, pillow, visualization, ml, side-by-side]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [Pillow, numpy, imagehash]
|
||||
tested: true
|
||||
tests:
|
||||
- "grid es PIL.Image con dimensiones correctas show_diff=True"
|
||||
- "grid es PIL.Image sin diff show_diff=False"
|
||||
- "pixel_mse positivo para imagenes distintas"
|
||||
- "pixel_mse cero para imagen identica"
|
||||
- "phash None si imagehash no disponible"
|
||||
test_file_path: "python/functions/ml/tests/test_image_compare_side_by_side.py"
|
||||
file_path: "python/functions/ml/image_compare_side_by_side.py"
|
||||
params:
|
||||
- name: a
|
||||
desc: "Primera imagen PIL (referencia). Se convierte a RGB internamente si es RGBA/L/etc."
|
||||
- name: b
|
||||
desc: "Segunda imagen PIL (comparacion). Se redimensiona a size de a si difieren."
|
||||
- name: label_a
|
||||
desc: "Etiqueta de texto para el panel A (default 'A')."
|
||||
- name: label_b
|
||||
desc: "Etiqueta de texto para el panel B (default 'B')."
|
||||
- name: gap_px
|
||||
desc: "Espacio en pixeles entre paneles y en los bordes del canvas (default 16)."
|
||||
- name: show_diff
|
||||
desc: "Si True (default), inserta panel central con PIL.ImageChops.difference + autocontrast."
|
||||
- name: show_phash
|
||||
desc: "Si True (default), calcula perceptual hash con imagehash. Si el paquete no esta instalado, retorna None silenciosamente."
|
||||
output: "dict con: 'grid' (PIL.Image compuesta), 'phash_a' (str|None), 'phash_b' (str|None), 'phash_distance' (int|None, Hamming), 'pixel_mse' (float)."
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
from ml.image_compare_side_by_side import image_compare_side_by_side
|
||||
|
||||
img_a = Image.open("outputs/gen_v1.png")
|
||||
img_b = Image.open("outputs/gen_v2.png")
|
||||
|
||||
result = image_compare_side_by_side(img_a, img_b, label_a="v1", label_b="v2")
|
||||
|
||||
result["grid"].save("compare.png")
|
||||
print(f"MSE: {result['pixel_mse']:.2f}")
|
||||
print(f"pHash distance: {result['phash_distance']}")
|
||||
```
|
||||
|
||||
## Layout del grid
|
||||
|
||||
Con `show_diff=True` (default):
|
||||
|
||||
```
|
||||
[gap] [A] [gap] [diff] [gap] [B] [gap]
|
||||
```
|
||||
|
||||
Canvas width = 3*w + 4*gap
|
||||
Canvas height = h + 2*gap
|
||||
|
||||
Con `show_diff=False`:
|
||||
|
||||
```
|
||||
[gap] [A] [gap] [B] [gap]
|
||||
```
|
||||
|
||||
Canvas width = 2*w + 3*gap
|
||||
|
||||
## Notas
|
||||
|
||||
- `pixel_mse` usa numpy si disponible; fallback a loop puro stdlib (mas lento).
|
||||
- `phash_*` requiere `pip install imagehash`. Sin el paquete, los tres campos son `None`.
|
||||
- Las imagenes se convierten a RGB antes de cualquier operacion para consistencia.
|
||||
- Si `a` y `b` tienen distinto tamano, `b` se redimensiona con LANCZOS al tamano de `a`.
|
||||
@@ -0,0 +1,147 @@
|
||||
"""Compara dos PIL Images lado a lado con diff opcional y metricas de similitud."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL.Image
|
||||
|
||||
|
||||
def image_compare_side_by_side(
|
||||
a: "PIL.Image.Image",
|
||||
b: "PIL.Image.Image",
|
||||
label_a: str = "A",
|
||||
label_b: str = "B",
|
||||
gap_px: int = 16,
|
||||
show_diff: bool = True,
|
||||
show_phash: bool = True,
|
||||
) -> dict:
|
||||
"""Crea una imagen comparativa lado a lado con metricas opcionales.
|
||||
|
||||
Construye una imagen compuesta A | [diff] | B con gap configurable.
|
||||
Calcula MSE pixel-wise y opcionalmente perceptual hash (imagehash).
|
||||
|
||||
Args:
|
||||
a: Primera imagen PIL (imagen de referencia).
|
||||
b: Segunda imagen PIL (imagen a comparar).
|
||||
label_a: Etiqueta de texto para la imagen A (default "A").
|
||||
label_b: Etiqueta de texto para la imagen B (default "B").
|
||||
gap_px: Espacio en pixeles entre paneles (default 16).
|
||||
show_diff: Si True, inserta panel de diferencia autocontrastada entre A y B.
|
||||
show_phash: Si True, calcula perceptual hash con imagehash si disponible.
|
||||
|
||||
Returns:
|
||||
dict con:
|
||||
- "grid": PIL.Image.Image — imagen compuesta lado a lado.
|
||||
- "phash_a": str | None — 16 hex chars de perceptual hash de A (None si imagehash no disponible).
|
||||
- "phash_b": str | None — 16 hex chars de perceptual hash de B.
|
||||
- "phash_distance": int | None — Distancia de Hamming entre phash_a y phash_b.
|
||||
- "pixel_mse": float — MSE pixel-wise sobre canales RGB.
|
||||
|
||||
Raises:
|
||||
ImportError: si Pillow no esta instalado.
|
||||
"""
|
||||
try:
|
||||
from PIL import Image, ImageChops, ImageDraw, ImageFont, ImageOps
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Pillow no esta instalado. Instalar con: pip install Pillow"
|
||||
) from exc
|
||||
|
||||
# Normalizar a RGB
|
||||
img_a = a.convert("RGB")
|
||||
img_b = b.convert("RGB")
|
||||
|
||||
# Asegurar mismo tamano para comparacion (resize b a size de a si difieren)
|
||||
if img_a.size != img_b.size:
|
||||
img_b = img_b.resize(img_a.size, Image.LANCZOS)
|
||||
|
||||
w, h = img_a.size
|
||||
|
||||
# --- Construir panels ---
|
||||
panels = [img_a]
|
||||
if show_diff:
|
||||
diff = ImageChops.difference(img_a, img_b)
|
||||
diff_contrast = ImageOps.autocontrast(diff)
|
||||
panels.append(diff_contrast)
|
||||
panels.append(img_b)
|
||||
|
||||
n = len(panels)
|
||||
canvas_w = n * w + (n + 1) * gap_px
|
||||
canvas_h = h + 2 * gap_px
|
||||
canvas = Image.new("RGB", (canvas_w, canvas_h), color=(20, 20, 20))
|
||||
|
||||
# Pegar panels
|
||||
labels_map = {0: label_a, n - 1: label_b}
|
||||
if show_diff:
|
||||
labels_map[1] = "diff"
|
||||
|
||||
try:
|
||||
draw = ImageDraw.Draw(canvas)
|
||||
font = ImageFont.load_default()
|
||||
except Exception:
|
||||
draw = None
|
||||
font = None
|
||||
|
||||
for i, panel in enumerate(panels):
|
||||
x = gap_px + i * (w + gap_px)
|
||||
y = gap_px
|
||||
canvas.paste(panel, (x, y))
|
||||
if draw and i in labels_map:
|
||||
draw.text((x + 4, y + 4), labels_map[i], fill=(255, 255, 255), font=font)
|
||||
|
||||
# --- MSE pixel-wise ---
|
||||
pixel_mse = _compute_mse(img_a, img_b)
|
||||
|
||||
# --- Perceptual hash ---
|
||||
phash_a: str | None = None
|
||||
phash_b: str | None = None
|
||||
phash_distance: int | None = None
|
||||
|
||||
if show_phash:
|
||||
try:
|
||||
import imagehash # type: ignore[import]
|
||||
h_a = imagehash.phash(img_a)
|
||||
h_b = imagehash.phash(img_b)
|
||||
phash_a = str(h_a)
|
||||
phash_b = str(h_b)
|
||||
phash_distance = int(h_a - h_b)
|
||||
except ImportError:
|
||||
pass # imagehash not installed — leave None
|
||||
|
||||
return {
|
||||
"grid": canvas,
|
||||
"phash_a": phash_a,
|
||||
"phash_b": phash_b,
|
||||
"phash_distance": phash_distance,
|
||||
"pixel_mse": pixel_mse,
|
||||
}
|
||||
|
||||
|
||||
def _compute_mse(img_a: "PIL.Image.Image", img_b: "PIL.Image.Image") -> float:
|
||||
"""Calcula MSE pixel-wise sobre canales RGB."""
|
||||
try:
|
||||
import numpy as np
|
||||
arr_a = np.asarray(img_a, dtype=np.float64)
|
||||
arr_b = np.asarray(img_b, dtype=np.float64)
|
||||
return float(np.mean((arr_a - arr_b) ** 2))
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback puro stdlib (lento para imagenes grandes)
|
||||
pixels_a = list(img_a.getdata())
|
||||
pixels_b = list(img_b.getdata())
|
||||
n_pixels = len(pixels_a)
|
||||
if n_pixels == 0:
|
||||
return 0.0
|
||||
|
||||
total = 0.0
|
||||
for pa, pb in zip(pixels_a, pixels_b):
|
||||
# Each pixel is (R, G, B)
|
||||
for ca, cb in zip(pa, pb):
|
||||
diff = float(ca) - float(cb)
|
||||
total += diff * diff
|
||||
|
||||
channels = 3
|
||||
return total / (n_pixels * channels)
|
||||
@@ -0,0 +1,97 @@
|
||||
"""ImageGenResult — resultado de una operacion de generacion de imagen."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# PIL.Image.Image se importa solo para type-checking estatico (mypy/pyright).
|
||||
# En runtime NO se importa aqui — el consumidor ya tiene PIL instalado si
|
||||
# trabaja con imagenes reales. Esto evita ImportError cuando el modulo se
|
||||
# importa en contextos sin Pillow (ej. el orquestador Go via grpc/json).
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image as PILImage
|
||||
|
||||
try:
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
class ImageGenResult(BaseModel):
|
||||
"""Resultado de una operacion de generacion de imagen con modelo de difusion.
|
||||
|
||||
El campo `image` contiene el objeto PIL.Image.Image generado. No es
|
||||
serializable a JSON — se accede directamente para guardar a disco o
|
||||
pasar a pipelines de post-proceso. Para serializar el resultado,
|
||||
usar solo el campo `meta` (que incluye la config usada) y guardar
|
||||
la imagen por separado.
|
||||
|
||||
Attributes:
|
||||
image: Imagen generada. Tipo PIL.Image.Image en runtime.
|
||||
No incluido en model_dump() ni model_dump_json().
|
||||
None si la generacion fallo (ver meta["error"]).
|
||||
meta: Diccionario con metadata de la generacion. Debe incluir:
|
||||
- "config": GenerationConfig.model_dump() con los params usados.
|
||||
- "model": nombre del modelo.
|
||||
- "seed_used": semilla real usada (util cuando seed=-1).
|
||||
- "sampler": nombre del sampler.
|
||||
Puede incluir campos adicionales del backend.
|
||||
duration_ms: Tiempo total de generacion en milisegundos.
|
||||
vram_peak_mb: Pico de VRAM consumida durante la generacion en MiB.
|
||||
None si no se pudo medir (CPU inference o backend sin soporte).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
image: Any # PIL.Image.Image en runtime; Any para evitar dep dura
|
||||
meta: dict[str, Any]
|
||||
duration_ms: int
|
||||
vram_peak_mb: int | None = None
|
||||
|
||||
@field_validator("image", mode="before")
|
||||
@classmethod
|
||||
def _validate_image(cls, v: Any) -> Any:
|
||||
# Aceptar None (generacion fallida) o cualquier objeto imagen.
|
||||
# No forzamos importar PIL aqui — la validacion real la hace el backend.
|
||||
return v
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Serializa a dict excluyendo el campo image (no serializable a JSON).
|
||||
|
||||
Returns:
|
||||
dict con meta, duration_ms y vram_peak_mb. El campo image se omite.
|
||||
"""
|
||||
return {
|
||||
"meta": self.meta,
|
||||
"duration_ms": self.duration_ms,
|
||||
"vram_peak_mb": self.vram_peak_mb,
|
||||
}
|
||||
|
||||
def model_dump_json(self, **kwargs: Any) -> str:
|
||||
"""Serializa a JSON excluyendo el campo image.
|
||||
|
||||
Returns:
|
||||
String JSON con meta, duration_ms y vram_peak_mb.
|
||||
"""
|
||||
import json
|
||||
|
||||
return json.dumps(self.model_dump())
|
||||
|
||||
except ImportError:
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ImageGenResult: # type: ignore[no-redef]
|
||||
"""Resultado de generacion de imagen (fallback dataclass).
|
||||
|
||||
El campo `image` es PIL.Image.Image en runtime. No serializable a JSON.
|
||||
Usar `meta` + guardar imagen por separado para persistencia.
|
||||
|
||||
Attributes:
|
||||
image: PIL.Image.Image generada. None si fallo.
|
||||
meta: Metadata: config usada, modelo, seed_used, sampler, etc.
|
||||
duration_ms: Duracion total de generacion en milisegundos.
|
||||
vram_peak_mb: Pico de VRAM en MiB. None si no se pudo medir.
|
||||
"""
|
||||
|
||||
image: Any # PIL.Image.Image
|
||||
meta: dict
|
||||
duration_ms: int
|
||||
vram_peak_mb: int | None = None
|
||||
@@ -0,0 +1,46 @@
|
||||
"""ImageGenerator — Protocol para backends de generacion de imagenes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from generation_config import GenerationConfig
|
||||
from image_gen_result import ImageGenResult
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ImageGenerator(Protocol):
|
||||
"""Interfaz comun para backends de generacion de imagenes con difusion.
|
||||
|
||||
Cualquier clase que implemente `generate(config) -> ImageGenResult`
|
||||
satisface este Protocol sin herencia explicita (structural subtyping).
|
||||
|
||||
Backends de ejemplo que satisfacen esta interfaz:
|
||||
- DiffusersGenerator: usa HuggingFace diffusers + torch.
|
||||
- SdCppGenerator: wrapper sobre stable-diffusion.cpp via ctypes/subprocess.
|
||||
- ComfyUIGenerator: cliente HTTP a ComfyUI API.
|
||||
- MockGenerator: implementacion de prueba sin GPU.
|
||||
|
||||
El Protocol es `runtime_checkable`, por lo que se puede usar con isinstance():
|
||||
assert isinstance(my_backend, ImageGenerator)
|
||||
|
||||
Nota: `isinstance()` con Protocol runtime_checkable solo verifica la presencia
|
||||
del metodo `generate`, no la firma completa. Para verificacion estricta usar mypy.
|
||||
"""
|
||||
|
||||
def generate(self, config: "GenerationConfig") -> "ImageGenResult":
|
||||
"""Genera una imagen a partir de la configuracion de difusion.
|
||||
|
||||
Args:
|
||||
config: Parametros de generacion. Ver GenerationConfig.
|
||||
|
||||
Returns:
|
||||
Resultado con la imagen PIL, metadata de la generacion,
|
||||
duracion total y pico de VRAM. Ver ImageGenResult.
|
||||
|
||||
Raises:
|
||||
Exception: El tipo concreto de error depende del backend.
|
||||
Los backends deben documentar sus excepciones propias.
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,77 @@
|
||||
---
|
||||
name: image_grid
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def image_grid(images: list[PIL.Image.Image], cols: int = 4, labels: list[str] | None = None, gap_px: int = 8, bg_color: tuple = (20,20,20)) -> PIL.Image.Image"
|
||||
description: "Combina una lista de PIL Images en un grid NxM con gap configurable, fondo oscuro y labels opcionales sobre cada celda. rows se calcula como ceil(n/cols). Retorna una sola PIL.Image RGB."
|
||||
tags: [image, grid, pil, pillow, visualization, ml, montage, collage]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [Pillow]
|
||||
params:
|
||||
- name: images
|
||||
desc: "lista de PIL.Image.Image a colocar en el grid (todas deben tener el mismo tamano o se usa el maximo)"
|
||||
- name: cols
|
||||
desc: "numero de columnas del grid (default 4)"
|
||||
- name: labels
|
||||
desc: "lista opcional de strings para etiquetar cada celda en la esquina superior izquierda"
|
||||
- name: gap_px
|
||||
desc: "espacio en pixeles entre celdas y en los bordes del canvas (default 8)"
|
||||
- name: bg_color
|
||||
desc: "color RGB de fondo del canvas como tupla (R, G, B), default (20,20,20) casi negro"
|
||||
output: "PIL.Image.Image: imagen RGB con el grid montado. Lista con n imagenes en cols columnas y ceil(n/cols) filas."
|
||||
tested: true
|
||||
tests:
|
||||
- "grid de 4 imagenes 16x16 cols=2 produce ancho/alto correcto"
|
||||
- "grid de 4 imagenes cols=2 gap_px=8 tiene dimensiones correctas con gap"
|
||||
- "grid de 1 imagen 1 col"
|
||||
- "el resultado es una imagen RGB"
|
||||
- "labels opcionales no lanza excepcion"
|
||||
- "sin labels funciona correctamente"
|
||||
- "lista vacia levanta ValueError"
|
||||
test_file_path: "python/functions/ml/tests/test_image_grid.py"
|
||||
file_path: "python/functions/ml/image_grid.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
from ml.image_grid import image_grid
|
||||
from ml.image_save_png import image_save_png
|
||||
|
||||
# Generar 6 imagenes de prueba con colores distintos
|
||||
colors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255),(255,0,255)]
|
||||
imgs = [Image.new("RGB", (256, 256), c) for c in colors]
|
||||
|
||||
grid = image_grid(
|
||||
imgs,
|
||||
cols=3,
|
||||
labels=["rojo", "verde", "azul", "amarillo", "cyan", "magenta"],
|
||||
gap_px=10,
|
||||
bg_color=(30, 30, 30),
|
||||
)
|
||||
# grid.size == (3*256 + 4*10, 2*256 + 3*10) == (788, 542)
|
||||
|
||||
image_save_png(grid, "outputs/preview_grid.png")
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Impure: asigna memoria para el canvas y ejecuta `ImageDraw` (efectos en
|
||||
objetos PIL internos). Aunque no hace I/O de disco, las allocations PIL
|
||||
y el draw tienen side-effects sobre objetos mutables.
|
||||
- Las imagenes en modo no-RGB (RGBA, L, P, palette) se convierten a RGB
|
||||
automaticamente con `.convert("RGB")` antes de pegar.
|
||||
- Si la lista tiene menos imagenes que `cols * rows`, las celdas sobrantes
|
||||
quedan en blanco (solo el color de fondo).
|
||||
- El label usa `ImageFont.load_default()` (fuente bitmap monospace de PIL,
|
||||
sin dependencias externas). Para fuentes TTF customizadas usar
|
||||
`ImageFont.truetype(path, size)` externamente y pasar un `font` propio.
|
||||
- Pillow se importa lazy para no bloquear `fn index`.
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Combina una lista de PIL Images en un grid NxM con gap y labels opcionales."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def image_grid(
|
||||
images: list["PIL.Image.Image"],
|
||||
cols: int = 4,
|
||||
labels: list[str] | None = None,
|
||||
gap_px: int = 8,
|
||||
bg_color: tuple = (20, 20, 20),
|
||||
) -> "PIL.Image.Image":
|
||||
"""Combina una lista de imagenes en un grid NxM.
|
||||
|
||||
Asume que todas las imagenes tienen el mismo tamano (usa el maximo
|
||||
ancho/alto detectado). Calcula rows = ceil(n / cols) automaticamente.
|
||||
|
||||
Args:
|
||||
images: lista de PIL.Image.Image a colocar en el grid.
|
||||
cols: numero de columnas del grid (default 4).
|
||||
labels: lista opcional de strings. Si se proporciona, se escribe
|
||||
un label encima de cada celda usando la fuente default de PIL.
|
||||
gap_px: espacio en pixeles entre celdas y en los bordes (default 8).
|
||||
bg_color: color de fondo RGB del canvas (default casi negro (20,20,20)).
|
||||
|
||||
Returns:
|
||||
Una sola PIL.Image en modo RGB con el grid montado.
|
||||
|
||||
Raises:
|
||||
ImportError: si Pillow no esta instalado.
|
||||
ValueError: si images esta vacio.
|
||||
"""
|
||||
try:
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Pillow no esta instalado. Instalar con: pip install Pillow"
|
||||
) from exc
|
||||
|
||||
if not images:
|
||||
raise ValueError("image_grid: la lista de imagenes no puede estar vacia")
|
||||
|
||||
n = len(images)
|
||||
rows = math.ceil(n / cols)
|
||||
|
||||
# Tamano de celda: max ancho y alto de todas las imagenes
|
||||
cell_w = max(img.width for img in images)
|
||||
cell_h = max(img.height for img in images)
|
||||
|
||||
canvas_w = cols * cell_w + (cols + 1) * gap_px
|
||||
canvas_h = rows * cell_h + (rows + 1) * gap_px
|
||||
|
||||
canvas = Image.new("RGB", (canvas_w, canvas_h), color=bg_color)
|
||||
|
||||
draw = ImageDraw.Draw(canvas) if labels else None
|
||||
font = None
|
||||
if draw:
|
||||
try:
|
||||
font = ImageFont.load_default()
|
||||
except Exception:
|
||||
font = None
|
||||
|
||||
for idx, img in enumerate(images):
|
||||
row = idx // cols
|
||||
col = idx % cols
|
||||
x = gap_px + col * (cell_w + gap_px)
|
||||
y = gap_px + row * (cell_h + gap_px)
|
||||
|
||||
# Convertir a RGB si hace falta (RGBA, L, P, etc.)
|
||||
paste_img = img.convert("RGB") if img.mode != "RGB" else img
|
||||
canvas.paste(paste_img, (x, y))
|
||||
|
||||
if draw and labels and idx < len(labels):
|
||||
label = labels[idx]
|
||||
# Texto en la esquina superior izquierda de la celda
|
||||
draw.text((x + 2, y + 2), label, fill=(255, 255, 255), font=font)
|
||||
|
||||
return canvas
|
||||
@@ -0,0 +1,69 @@
|
||||
---
|
||||
name: image_save_png
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def image_save_png(img: PIL.Image.Image, path: str, metadata: dict | None = None) -> str"
|
||||
description: "Guarda una PIL Image como PNG con metadata embebida en chunks tEXt (prompt, seed, steps, sampler, model). Crea directorio padre si no existe. Retorna path absoluto escrito."
|
||||
tags: [image, png, pil, pillow, metadata, save, ml, reproducibility]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [Pillow]
|
||||
params:
|
||||
- name: img
|
||||
desc: "imagen PIL.Image.Image a guardar"
|
||||
- name: path
|
||||
desc: "ruta de destino del archivo PNG (absoluta o relativa)"
|
||||
- name: metadata
|
||||
desc: "dict opcional de pares clave/valor a embeber en chunks tEXt del PNG para reproducibilidad (prompt, seed, steps, etc.)"
|
||||
output: "string: ruta absoluta del archivo PNG escrito"
|
||||
tested: true
|
||||
tests:
|
||||
- "crea imagen 8x8, guarda y retorna ruta absoluta"
|
||||
- "metadata se embebe en chunks tEXt y se puede releer con Image.text"
|
||||
- "sin metadata el PNG se guarda igualmente"
|
||||
- "crea directorio padre si no existe"
|
||||
- "valores numericos en metadata se convierten a str automaticamente"
|
||||
test_file_path: "python/functions/ml/tests/test_image_save_png.py"
|
||||
file_path: "python/functions/ml/image_save_png.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
from ml.image_save_png import image_save_png
|
||||
|
||||
img = Image.new("RGB", (512, 512), color=(128, 64, 200))
|
||||
path = image_save_png(
|
||||
img,
|
||||
"outputs/gen_001.png",
|
||||
metadata={
|
||||
"prompt": "a cat on a purple sofa",
|
||||
"seed": 42,
|
||||
"steps": 20,
|
||||
"sampler": "euler_a",
|
||||
"model": "sd-v1-5",
|
||||
},
|
||||
)
|
||||
# path = "/home/lucas/.../outputs/gen_001.png"
|
||||
# Los metadatos quedan embebidos en el PNG y son legibles con exiftool o PIL.
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- Usa `PngImagePlugin.PngInfo` para chunks `tEXt` (texto plano, no comprimido).
|
||||
Para texto largo/comprimido existe `add_itxt`, pero `add_text` es compatible
|
||||
con la mayoria de lectores (exiftool, A1111, ComfyUI, etc.).
|
||||
- Los valores del dict se convierten a `str` automaticamente — se puede pasar
|
||||
int, float o bool sin castear.
|
||||
- Si `metadata` es `None` o `{}`, el PNG se guarda sin chunks extra (igual que
|
||||
`img.save(path)`).
|
||||
- Pillow no esta en los imports por defecto del registry para no bloquear
|
||||
`fn index`. Se importa lazy dentro de la funcion.
|
||||
- impure: escribe en disco y crea directorios.
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Guarda una PIL Image como PNG con metadata embebida en chunks tEXt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def image_save_png(img: "PIL.Image.Image", path: str, metadata: dict | None = None) -> str:
|
||||
"""Guarda una PIL Image como PNG en la ruta indicada.
|
||||
|
||||
Embebe metadata arbitraria en chunks tEXt del PNG (clave/valor string).
|
||||
Util para registrar prompt, seed, steps, sampler, model dentro del archivo
|
||||
para reproducibilidad.
|
||||
|
||||
Crea el directorio padre si no existe.
|
||||
|
||||
Args:
|
||||
img: imagen PIL a guardar.
|
||||
path: ruta de destino (absoluta o relativa). Debe terminar en .png.
|
||||
metadata: dict opcional de pares {clave: valor} a embeber en el PNG.
|
||||
Los valores se convierten a str automaticamente.
|
||||
|
||||
Returns:
|
||||
Ruta absoluta del archivo PNG escrito.
|
||||
|
||||
Raises:
|
||||
ImportError: si Pillow no esta instalado.
|
||||
OSError: si no se puede escribir en la ruta indicada.
|
||||
"""
|
||||
try:
|
||||
from PIL import PngImagePlugin
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Pillow no esta instalado. Instalar con: pip install Pillow"
|
||||
) from exc
|
||||
|
||||
abs_path = os.path.abspath(path)
|
||||
parent = os.path.dirname(abs_path)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
png_info = PngImagePlugin.PngInfo()
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
png_info.add_text(str(key), str(value))
|
||||
|
||||
img.save(abs_path, format="PNG", pnginfo=png_info)
|
||||
return abs_path
|
||||
@@ -0,0 +1,47 @@
|
||||
"""LoraRef — referencia a un adaptador LoRA para generacion de imagenes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
try:
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
class LoraRef(BaseModel):
|
||||
"""Referencia a un adaptador LoRA (Low-Rank Adaptation).
|
||||
|
||||
Un LoRA modifica el comportamiento de un modelo base sin cambiar
|
||||
sus pesos originales. Se aplica multiplicando matrices de rango bajo
|
||||
durante la inferencia.
|
||||
|
||||
Attributes:
|
||||
path: Ruta al archivo .safetensors o .bin del adaptador LoRA.
|
||||
Puede ser absoluta o relativa al directorio de modelos.
|
||||
weight: Factor de escala global del LoRA. 1.0 aplica el LoRA
|
||||
con su fuerza original. 0.0 lo desactiva completamente.
|
||||
Rango tipico: 0.0 a 1.5.
|
||||
scale: Override del alpha del LoRA (escala de rango). None usa
|
||||
el alpha del propio archivo. Util para ajuste fino sin
|
||||
reentrenar.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
path: str
|
||||
weight: float = 1.0
|
||||
scale: float | None = None
|
||||
|
||||
except ImportError:
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LoraRef: # type: ignore[no-redef]
|
||||
"""Referencia a un adaptador LoRA (fallback dataclass).
|
||||
|
||||
Attributes:
|
||||
path: Ruta al archivo del adaptador LoRA (.safetensors o .bin).
|
||||
weight: Factor de escala global. Rango tipico 0.0-1.5. Por defecto 1.0.
|
||||
scale: Override del alpha. None usa el alpha del archivo.
|
||||
"""
|
||||
|
||||
path: str
|
||||
weight: float = 1.0
|
||||
scale: float | None = None
|
||||
@@ -0,0 +1,67 @@
|
||||
"""ModelRef — referencia a un modelo de generacion de imagenes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
try:
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
class ModelRef(BaseModel):
|
||||
"""Referencia a un modelo de generacion de imagenes.
|
||||
|
||||
Identifica el modelo por nombre (HuggingFace hub o ruta local),
|
||||
tipo de arquitectura y cuantizacion. Serializable a JSON canonico
|
||||
con model_dump() / model_dump_json() para el contrato compartido con Go.
|
||||
|
||||
Attributes:
|
||||
name: Nombre del modelo en HuggingFace Hub o identificador local.
|
||||
Ejemplo: "stabilityai/stable-diffusion-xl-base-1.0".
|
||||
model_type: Arquitectura del modelo. Uno de los literales definidos.
|
||||
quantization: Precision numerica del checkpoint. Por defecto "fp16".
|
||||
path: Ruta local al checkpoint si ya fue descargado. None si
|
||||
se debe descargar del hub.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
name: str
|
||||
model_type: Literal[
|
||||
"sd15",
|
||||
"sd20",
|
||||
"sdxl",
|
||||
"sd3",
|
||||
"flux_dev",
|
||||
"flux_schnell",
|
||||
"flux_kontext",
|
||||
"qwen_image",
|
||||
"chroma",
|
||||
"z_image",
|
||||
]
|
||||
quantization: Literal[
|
||||
"fp32", "fp16", "bf16", "q8_0", "q5_1", "q5_0", "q4_1", "q4_0"
|
||||
] = "fp16"
|
||||
path: str | None = None
|
||||
|
||||
except ImportError:
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelRef: # type: ignore[no-redef]
|
||||
"""Referencia a un modelo de generacion de imagenes (fallback dataclass).
|
||||
|
||||
Usar la version pydantic cuando este disponible para validacion y
|
||||
serializacion JSON canonica. Esta version no valida los literales en
|
||||
tiempo de ejecucion.
|
||||
|
||||
Attributes:
|
||||
name: Nombre del modelo en HuggingFace Hub o ruta local.
|
||||
model_type: Arquitectura del modelo (sd15|sd20|sdxl|sd3|flux_dev|...).
|
||||
quantization: Precision numerica (fp32|fp16|bf16|q8_0|...). Por defecto "fp16".
|
||||
path: Ruta local al checkpoint. None si no esta descargado.
|
||||
"""
|
||||
|
||||
name: str
|
||||
model_type: str
|
||||
quantization: str = "fp16"
|
||||
path: str | None = None
|
||||
@@ -0,0 +1,99 @@
|
||||
---
|
||||
name: safetensors_inspect
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def safetensors_inspect(path: str) -> dict"
|
||||
description: "Lee SOLO el header de un archivo .safetensors sin cargar los tensores en RAM. Retorna metadata del modelo, lista de tensores con dtype/shape/offsets, tamano total y conteo. Util para inspeccionar checkpoints de varios GB sin agotarlos en memoria."
|
||||
tags: [safetensors, model, inspect, header, ml, huggingface, checkpoint, dtype, shape]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: []
|
||||
params:
|
||||
- name: path
|
||||
desc: "ruta al archivo .safetensors a inspeccionar (absoluta o relativa)"
|
||||
output: "dict con claves: path (str ruta absoluta), metadata (dict con __metadata__ del header), tensors (list[dict] con name/dtype/shape/data_offsets por tensor), total_size_bytes (int), n_tensors (int)"
|
||||
tested: true
|
||||
tests:
|
||||
- "n_tensors refleja el numero de tensores en el header"
|
||||
- "total_size_bytes refleja el tamano real del archivo"
|
||||
- "metadata devuelve el contenido de __metadata__"
|
||||
- "tensors es lista con una entrada por tensor del header"
|
||||
- "cada tensor tiene dtype, shape y data_offsets"
|
||||
- "result path es la ruta absoluta del archivo"
|
||||
- "FileNotFoundError si el archivo no existe"
|
||||
- "ValueError si el header no es JSON valido"
|
||||
- "ValueError si el archivo esta vacio"
|
||||
- "si no hay __metadata__ metadata retorna dict vacio"
|
||||
test_file_path: "python/functions/ml/tests/test_safetensors_inspect.py"
|
||||
file_path: "python/functions/ml/safetensors_inspect.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.safetensors_inspect import safetensors_inspect
|
||||
|
||||
info = safetensors_inspect("/models/sd-v1-5/model.safetensors")
|
||||
print(info["n_tensors"]) # 1344
|
||||
print(info["total_size_bytes"]) # 3_975_733_952 (~3.7 GB)
|
||||
print(info["metadata"]) # {"format": "pt", "model_type": "stable_diffusion"}
|
||||
|
||||
# Ver los 5 primeros tensores
|
||||
for t in info["tensors"][:5]:
|
||||
print(t["name"], t["dtype"], t["shape"])
|
||||
# model.diffusion_model.input_blocks.0.0.weight F16 [320, 4, 3, 3]
|
||||
# model.diffusion_model.input_blocks.0.0.bias F16 [320]
|
||||
# ...
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
### Formato safetensors
|
||||
|
||||
```
|
||||
[8 bytes: uint64 LE = N (longitud del header JSON)]
|
||||
[N bytes: JSON con metadata y descriptores]
|
||||
[datos binarios de los tensores (no se leen)]
|
||||
```
|
||||
|
||||
El JSON tiene esta estructura:
|
||||
```json
|
||||
{
|
||||
"__metadata__": {"format": "pt", ...},
|
||||
"tensor_name": {
|
||||
"dtype": "F32",
|
||||
"shape": [1024, 768],
|
||||
"data_offsets": [0, 3145728]
|
||||
},
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
`data_offsets` son relativos al inicio del bloque de datos (despues del header),
|
||||
no al inicio del archivo. Para acceso lazy a un tensor concreto:
|
||||
`offset_en_archivo = 8 + header_len + data_offsets[0]`.
|
||||
|
||||
### Por que no usar la libreria `safetensors`
|
||||
|
||||
Esta funcion solo usa stdlib (`struct`, `json`, `os`) para no requerir
|
||||
instalaciones adicionales y ser ejecutable durante `fn index`. La libreria
|
||||
oficial `safetensors` de HuggingFace cargaria los tensores en RAM al usar
|
||||
`safe_open` sin `framework=None`. Esta implementacion es read-only sobre
|
||||
el header y garantiza que no se carga ningun dato de tensor.
|
||||
|
||||
### Dtypes comunes
|
||||
|
||||
| dtype | descripcion |
|
||||
|-------|-------------|
|
||||
| F32 | float32 (full precision) |
|
||||
| BF16 | bfloat16 (training, ampere+) |
|
||||
| F16 | float16 (inference) |
|
||||
| I32 | int32 |
|
||||
| I64 | int64 |
|
||||
| U8 | uint8 |
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Lee solo el header de un archivo safetensors sin cargar los tensores en RAM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
|
||||
def safetensors_inspect(path: str) -> dict:
|
||||
"""Lee el header de un archivo safetensors sin cargar los tensores.
|
||||
|
||||
El formato safetensors almacena al inicio del archivo:
|
||||
- 8 bytes: uint64 little-endian con la longitud del header JSON (N).
|
||||
- N bytes: JSON con metadata y descriptores de tensores.
|
||||
|
||||
Este enfoque evita cargar gigabytes de pesos en RAM para inspeccionar
|
||||
un checkpoint: solo se leen los primeros 8 + N bytes.
|
||||
|
||||
Spec: https://github.com/huggingface/safetensors#format
|
||||
|
||||
Args:
|
||||
path: ruta al archivo .safetensors (absoluta o relativa).
|
||||
|
||||
Returns:
|
||||
dict con claves:
|
||||
path (str): ruta absoluta del archivo.
|
||||
metadata (dict): metadatos del modelo (campo "__metadata__" del header).
|
||||
tensors (list[dict]): lista de tensores, cada uno con:
|
||||
name (str): nombre del tensor.
|
||||
dtype (str): tipo de dato (F32, BF16, F16, I32, etc.).
|
||||
shape (list[int]): dimensiones del tensor.
|
||||
data_offsets (list[int]): [inicio, fin] en bytes dentro del
|
||||
bloque de datos (para acceso lazy si se necesita).
|
||||
total_size_bytes (int): tamano total del archivo en bytes.
|
||||
n_tensors (int): numero de tensores en el archivo.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: si el archivo no existe.
|
||||
ValueError: si el archivo no es un safetensors valido.
|
||||
OSError: si no se puede leer el archivo.
|
||||
"""
|
||||
import os
|
||||
|
||||
abs_path = os.path.abspath(path)
|
||||
|
||||
if not os.path.isfile(abs_path):
|
||||
raise FileNotFoundError(f"safetensors_inspect: archivo no encontrado: {abs_path}")
|
||||
|
||||
total_size = os.path.getsize(abs_path)
|
||||
|
||||
with open(abs_path, "rb") as f:
|
||||
# Leer los 8 bytes del tamano del header
|
||||
raw_len = f.read(8)
|
||||
if len(raw_len) < 8:
|
||||
raise ValueError(
|
||||
f"safetensors_inspect: archivo demasiado corto para ser safetensors: {abs_path}"
|
||||
)
|
||||
|
||||
header_len = struct.unpack("<Q", raw_len)[0] # uint64 little-endian
|
||||
|
||||
if header_len == 0 or header_len > total_size:
|
||||
raise ValueError(
|
||||
f"safetensors_inspect: header_len invalido ({header_len}) en {abs_path}"
|
||||
)
|
||||
|
||||
raw_header = f.read(header_len)
|
||||
if len(raw_header) < header_len:
|
||||
raise ValueError(
|
||||
f"safetensors_inspect: no se pudo leer el header completo en {abs_path}"
|
||||
)
|
||||
|
||||
try:
|
||||
header = json.loads(raw_header.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(
|
||||
f"safetensors_inspect: header JSON invalido en {abs_path}: {exc}"
|
||||
) from exc
|
||||
|
||||
metadata = header.get("__metadata__", {})
|
||||
|
||||
tensors = []
|
||||
for name, desc in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
tensors.append(
|
||||
{
|
||||
"name": name,
|
||||
"dtype": desc.get("dtype", ""),
|
||||
"shape": desc.get("shape", []),
|
||||
"data_offsets": desc.get("data_offsets", []),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"path": abs_path,
|
||||
"metadata": metadata,
|
||||
"tensors": tensors,
|
||||
"total_size_bytes": total_size,
|
||||
"n_tensors": len(tensors),
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
"""SamplerName — subset de samplers compartido entre diffusers y stable-diffusion.cpp."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
# Sum type: valores validos de sampler para GenerationConfig.
|
||||
# Subset estricto que tienen correspondencia directa en diffusers (schedulers)
|
||||
# y en stable-diffusion.cpp (--sampling-method).
|
||||
SamplerName = Literal[
|
||||
"euler",
|
||||
"euler_a",
|
||||
"dpm++2m",
|
||||
"dpm++2m_v2",
|
||||
"heun",
|
||||
"dpm2",
|
||||
"lcm",
|
||||
]
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Tests para cuda_available."""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
# Asegurar que el modulo ml es importable desde el path del registry
|
||||
sys.path.insert(0, "python/functions")
|
||||
|
||||
from ml.cuda_available import cuda_available
|
||||
|
||||
|
||||
class TestCudaAvailable(unittest.TestCase):
|
||||
|
||||
def test_claves_del_dict_siempre_presentes(self):
|
||||
"""claves del dict siempre presentes"""
|
||||
result = cuda_available()
|
||||
for key in ("available", "device_count", "devices", "torch_version", "cuda_version"):
|
||||
self.assertIn(key, result, f"Falta clave: {key}")
|
||||
|
||||
def test_sin_torch_retorna_available_False_y_torch_version_not_installed(self):
|
||||
"""sin torch retorna available=False y torch_version=not_installed"""
|
||||
with patch.dict(sys.modules, {"torch": None}):
|
||||
result = cuda_available()
|
||||
self.assertFalse(result["available"])
|
||||
self.assertEqual(result["torch_version"], "not_installed")
|
||||
self.assertEqual(result["device_count"], 0)
|
||||
self.assertEqual(result["devices"], [])
|
||||
self.assertIsNone(result["cuda_version"])
|
||||
|
||||
def test_con_torch_sin_cuda_retorna_available_False_y_device_count_0(self):
|
||||
"""con torch sin cuda retorna available=False y device_count=0"""
|
||||
import types
|
||||
fake_torch = types.ModuleType("torch")
|
||||
fake_torch.__version__ = "2.3.0"
|
||||
fake_torch.cuda = types.SimpleNamespace(
|
||||
is_available=lambda: False,
|
||||
device_count=lambda: 0,
|
||||
)
|
||||
fake_torch.version = types.SimpleNamespace(cuda=None)
|
||||
|
||||
with patch.dict(sys.modules, {"torch": fake_torch}):
|
||||
result = cuda_available()
|
||||
|
||||
self.assertFalse(result["available"])
|
||||
self.assertEqual(result["device_count"], 0)
|
||||
self.assertEqual(result["devices"], [])
|
||||
self.assertEqual(result["torch_version"], "2.3.0")
|
||||
self.assertIsNone(result["cuda_version"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,212 @@
|
||||
"""Tests para el backend diffusers: load_pipeline, set_scheduler, generate, unload."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
# Ajustar path para importar desde python/functions/ml/
|
||||
_ML_PATH = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
)
|
||||
sys.path.insert(0, os.path.abspath(_ML_PATH))
|
||||
|
||||
# Importaciones lazy de torch y diffusers — las omitimos si no estan disponibles.
|
||||
torch = pytest.importorskip("torch", reason="torch no instalado — skip tests diffusers")
|
||||
pytest.importorskip("diffusers", reason="diffusers no instalado — skip tests diffusers")
|
||||
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.generation_config import GenerationConfig
|
||||
from ml.image_gen_result import ImageGenResult
|
||||
from ml.diffusers_load_pipeline import diffusers_load_pipeline, _clear_pipeline_cache
|
||||
from ml.diffusers_set_scheduler import diffusers_set_scheduler
|
||||
from ml.diffusers_unload import diffusers_unload
|
||||
|
||||
# diffusers_generate importa image_gen_result sin prefijo de paquete.
|
||||
# Para evitar el double-import problem (ml.image_gen_result != image_gen_result),
|
||||
# forzamos que sys.modules["image_gen_result"] apunte al modulo ya cargado
|
||||
# como ml.image_gen_result antes de importar diffusers_generate.
|
||||
import sys as _sys
|
||||
import ml.image_gen_result as _igr_module
|
||||
import ml.generation_config as _gcfg_module
|
||||
import ml.genconfig_to_diffusers_kwargs as _gkwargs_module
|
||||
for _alias, _mod in [
|
||||
("image_gen_result", _igr_module),
|
||||
("generation_config", _gcfg_module),
|
||||
("genconfig_to_diffusers_kwargs", _gkwargs_module),
|
||||
]:
|
||||
if _alias not in _sys.modules:
|
||||
_sys.modules[_alias] = _mod
|
||||
|
||||
from ml.diffusers_generate import diffusers_generate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constantes
|
||||
# ---------------------------------------------------------------------------
|
||||
SD_TURBO_PATH = "/home/lucas/vaults/imagegen_models/diffusers/sd-turbo"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sd_turbo_model() -> ModelRef:
|
||||
"""ModelRef apuntando a SD Turbo local."""
|
||||
if not os.path.isdir(SD_TURBO_PATH):
|
||||
pytest.skip(f"SD Turbo no encontrado en {SD_TURBO_PATH}")
|
||||
return ModelRef(
|
||||
name="sd-turbo",
|
||||
model_type="sd15",
|
||||
quantization="fp16",
|
||||
path=SD_TURBO_PATH,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loaded_pipe(sd_turbo_model: ModelRef):
|
||||
"""Pipeline SD Turbo cargado una sola vez para toda la sesion de tests."""
|
||||
# Intentar fp16 primero; si falla (no hay variante fp16) usar fp32
|
||||
try:
|
||||
pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
except Exception:
|
||||
_clear_pipeline_cache()
|
||||
pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp32")
|
||||
yield pipe
|
||||
# Teardown: liberar al final de la sesion
|
||||
diffusers_unload(None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sd_turbo_cfg(sd_turbo_model: ModelRef) -> GenerationConfig:
|
||||
"""GenerationConfig minimo para SD Turbo (1 step, 512x512)."""
|
||||
return GenerationConfig(
|
||||
prompt="a simple red circle on white background",
|
||||
negative_prompt=None,
|
||||
seed=42,
|
||||
steps=1,
|
||||
cfg_scale=0.0,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model=sd_turbo_model,
|
||||
loras=[],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: carga pipeline y retorna callable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_load_pipeline_returns_callable(sd_turbo_model: ModelRef) -> None:
|
||||
"""carga pipeline y retorna callable"""
|
||||
_clear_pipeline_cache()
|
||||
pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
assert callable(pipe), "El pipeline debe ser callable"
|
||||
assert hasattr(pipe, "scheduler"), "El pipeline debe tener atributo scheduler"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: segunda carga usa cache (< 100ms)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_load_pipeline_caches(sd_turbo_model: ModelRef) -> None:
|
||||
"""segunda carga usa cache (< 100ms)"""
|
||||
# Primera carga (puede tardar varios segundos)
|
||||
_clear_pipeline_cache()
|
||||
_ = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
|
||||
# Segunda carga debe ser cache hit
|
||||
t0 = time.perf_counter()
|
||||
pipe2 = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16")
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
assert elapsed_ms < 100, (
|
||||
f"Segunda carga tardo {elapsed_ms:.1f}ms (esperado < 100ms — debe ser cache hit)"
|
||||
)
|
||||
assert pipe2 is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: set_scheduler cambia la clase del scheduler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_set_scheduler_changes_scheduler_class(loaded_pipe) -> None:
|
||||
"""euler cambia scheduler a EulerDiscreteScheduler"""
|
||||
pipe = diffusers_set_scheduler(loaded_pipe, "euler")
|
||||
scheduler_name = type(pipe.scheduler).__name__
|
||||
assert scheduler_name == "EulerDiscreteScheduler", (
|
||||
f"Esperado EulerDiscreteScheduler, obtenido {scheduler_name}"
|
||||
)
|
||||
|
||||
|
||||
def test_set_scheduler_euler_a(loaded_pipe) -> None:
|
||||
"""euler_a cambia scheduler a EulerAncestralDiscreteScheduler"""
|
||||
pipe = diffusers_set_scheduler(loaded_pipe, "euler_a")
|
||||
scheduler_name = type(pipe.scheduler).__name__
|
||||
assert scheduler_name == "EulerAncestralDiscreteScheduler", (
|
||||
f"Esperado EulerAncestralDiscreteScheduler, obtenido {scheduler_name}"
|
||||
)
|
||||
# Restaurar euler para no afectar otros tests
|
||||
diffusers_set_scheduler(loaded_pipe, "euler")
|
||||
|
||||
|
||||
def test_set_scheduler_invalid_raises_value_error(loaded_pipe) -> None:
|
||||
"""sampler invalido lanza ValueError"""
|
||||
with pytest.raises(ValueError, match="no soportado"):
|
||||
diffusers_set_scheduler(loaded_pipe, "nonexistent_sampler_xyz")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: genera imagen retorna ImageGenResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_generate_returns_image_gen_result(
|
||||
loaded_pipe, sd_turbo_cfg: GenerationConfig
|
||||
) -> None:
|
||||
"""genera imagen retorna ImageGenResult"""
|
||||
result = diffusers_generate(loaded_pipe, sd_turbo_cfg)
|
||||
|
||||
assert isinstance(result, ImageGenResult), (
|
||||
f"Esperado ImageGenResult, obtenido {type(result)}"
|
||||
)
|
||||
assert result.image is not None, "result.image no debe ser None"
|
||||
assert result.duration_ms > 0, (
|
||||
f"duration_ms debe ser positivo, obtenido {result.duration_ms}"
|
||||
)
|
||||
assert "backend" in result.meta, "meta debe tener key 'backend'"
|
||||
assert result.meta["backend"] == "diffusers", (
|
||||
f"meta['backend'] debe ser 'diffusers', obtenido {result.meta['backend']}"
|
||||
)
|
||||
assert "model" in result.meta, "meta debe tener key 'model'"
|
||||
|
||||
# Verificar que la imagen tiene las dimensiones correctas
|
||||
w, h = result.image.size
|
||||
assert w == sd_turbo_cfg.width and h == sd_turbo_cfg.height, (
|
||||
f"Imagen esperada {sd_turbo_cfg.width}x{sd_turbo_cfg.height}, "
|
||||
f"obtenida {w}x{h}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: unload limpia cache cuda si disponible
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_unload_clears_cuda() -> None:
|
||||
"""unload None limpia cache cuda si disponible"""
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
# Limpiar cache — no debe lanzar excepcion independientemente de si hay CUDA
|
||||
diffusers_unload(None)
|
||||
|
||||
if cuda_available:
|
||||
# Despues de empty_cache, la memoria reservada por el allocator baja
|
||||
# No podemos asumir que sea 0 (otros tensores pueden estar vivos),
|
||||
# pero la llamada debe completarse sin error.
|
||||
reserved = torch.cuda.memory_reserved()
|
||||
# Solo verificamos que no lanza excepcion y que la llamada completo
|
||||
assert reserved >= 0, "memory_reserved debe ser >= 0"
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Tests de roundtrip JSON para genconfig_save_json y genconfig_load_json."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.genconfig_save_json import genconfig_save_json
|
||||
from ml.genconfig_load_json import genconfig_load_json
|
||||
from ml.generation_config import GenerationConfig
|
||||
|
||||
|
||||
def _make_cfg(**overrides):
|
||||
"""Construye un GenerationConfig sintetico usando model_validate para evitar
|
||||
problemas de identidad de clase entre modulos pydantic separados."""
|
||||
defaults = dict(
|
||||
prompt="a forest at dusk",
|
||||
negative_prompt="blurry, low quality",
|
||||
seed=123,
|
||||
steps=25,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model={"name": "runwayml/stable-diffusion-v1-5", "model_type": "sd15"},
|
||||
loras=[{"path": "/loras/detail.safetensors", "weight": 0.7}],
|
||||
)
|
||||
defaults.update(overrides)
|
||||
try:
|
||||
return GenerationConfig.model_validate(defaults)
|
||||
except AttributeError:
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.lora_ref import LoraRef
|
||||
m = defaults.pop("model")
|
||||
if isinstance(m, dict):
|
||||
m = ModelRef(**m)
|
||||
loras = defaults.pop("loras", [])
|
||||
built = [LoraRef(**lr) if isinstance(lr, dict) else lr for lr in loras]
|
||||
return GenerationConfig(model=m, loras=tuple(built), **defaults)
|
||||
|
||||
|
||||
class TestGenconfigJsonRoundtrip(unittest.TestCase):
|
||||
|
||||
def test_save_escribe_archivo_json_valido_en_la_ruta_indicada(self):
|
||||
"""save escribe archivo JSON valido en la ruta indicada"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
saved = genconfig_save_json(cfg, path)
|
||||
self.assertTrue(os.path.isfile(saved))
|
||||
with open(saved, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
self.assertIsInstance(data, dict)
|
||||
self.assertEqual(data["prompt"], "a forest at dusk")
|
||||
|
||||
def test_save_crea_directorios_padre_si_no_existen(self):
|
||||
"""save crea directorios padre si no existen"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
nested = os.path.join(tmpdir, "a", "b", "c", "config.json")
|
||||
saved = genconfig_save_json(cfg, nested)
|
||||
self.assertTrue(os.path.isfile(saved))
|
||||
|
||||
def test_json_contiene_claves_en_snake_case(self):
|
||||
"""json contiene claves en snake_case"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
# Claves deben ser snake_case (interoperabilidad con Go)
|
||||
expected_keys = {
|
||||
"prompt", "negative_prompt", "seed", "steps",
|
||||
"cfg_scale", "sampler", "width", "height", "model",
|
||||
}
|
||||
for key in expected_keys:
|
||||
self.assertIn(key, data, f"Clave snake_case faltante: {key}")
|
||||
# No debe haber camelCase
|
||||
self.assertNotIn("negativePrompt", data)
|
||||
self.assertNotIn("cfgScale", data)
|
||||
self.assertNotIn("numInferenceSteps", data)
|
||||
|
||||
def test_roundtrip_preserva_campos_escalares(self):
|
||||
"""roundtrip save→load preserva campos escalares"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
self.assertEqual(loaded.prompt, cfg.prompt)
|
||||
self.assertEqual(loaded.negative_prompt, cfg.negative_prompt)
|
||||
self.assertEqual(loaded.seed, cfg.seed)
|
||||
self.assertEqual(loaded.steps, cfg.steps)
|
||||
self.assertAlmostEqual(loaded.cfg_scale, cfg.cfg_scale)
|
||||
self.assertEqual(loaded.sampler, cfg.sampler)
|
||||
self.assertEqual(loaded.width, cfg.width)
|
||||
self.assertEqual(loaded.height, cfg.height)
|
||||
|
||||
def test_roundtrip_preserva_model_ref(self):
|
||||
"""roundtrip preserva ModelRef"""
|
||||
cfg = _make_cfg(
|
||||
model={
|
||||
"name": "stabilityai/sdxl-base-1.0",
|
||||
"model_type": "sdxl",
|
||||
"quantization": "fp16",
|
||||
"path": "/models/sdxl.safetensors",
|
||||
}
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
self.assertEqual(loaded.model.name, "stabilityai/sdxl-base-1.0")
|
||||
self.assertEqual(loaded.model.model_type, "sdxl")
|
||||
self.assertEqual(loaded.model.quantization, "fp16")
|
||||
self.assertEqual(loaded.model.path, "/models/sdxl.safetensors")
|
||||
|
||||
def test_roundtrip_preserva_loras(self):
|
||||
"""roundtrip preserva lista de LoraRef"""
|
||||
cfg = _make_cfg(
|
||||
loras=[
|
||||
{"path": "/loras/a.safetensors", "weight": 0.8},
|
||||
{"path": "/loras/b.safetensors", "weight": 0.5},
|
||||
]
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
loaded_loras = list(loaded.loras)
|
||||
self.assertEqual(len(loaded_loras), 2)
|
||||
paths = [lr.path for lr in loaded_loras]
|
||||
self.assertIn("/loras/a.safetensors", paths)
|
||||
self.assertIn("/loras/b.safetensors", paths)
|
||||
|
||||
def test_roundtrip_negative_prompt_none(self):
|
||||
"""roundtrip con negative_prompt=None"""
|
||||
cfg = _make_cfg(negative_prompt=None)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "config.json")
|
||||
genconfig_save_json(cfg, path)
|
||||
loaded = genconfig_load_json(path)
|
||||
self.assertIsNone(loaded.negative_prompt)
|
||||
|
||||
def test_load_falla_con_file_not_found(self):
|
||||
"""load lanza FileNotFoundError si el archivo no existe"""
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
genconfig_load_json("/tmp/nonexistent_fn_registry_test_12345.json")
|
||||
|
||||
def test_save_retorna_path_absoluto(self):
|
||||
"""save retorna path absoluto aunque se pase path relativo"""
|
||||
cfg = _make_cfg()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
abs_path = os.path.join(tmpdir, "cfg.json")
|
||||
result = genconfig_save_json(cfg, abs_path)
|
||||
self.assertTrue(os.path.isabs(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Tests para genconfig_to_diffusers_kwargs."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.genconfig_to_diffusers_kwargs import genconfig_to_diffusers_kwargs
|
||||
from ml.generation_config import GenerationConfig
|
||||
|
||||
|
||||
def _make_cfg(**overrides):
|
||||
"""Crea un GenerationConfig sintetico para tests via model_validate / constructor."""
|
||||
defaults = dict(
|
||||
prompt="a dog in the park",
|
||||
seed=42,
|
||||
steps=30,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler_a",
|
||||
width=512,
|
||||
height=768,
|
||||
model={"name": "runwayml/stable-diffusion-v1-5", "model_type": "sd15"},
|
||||
)
|
||||
defaults.update(overrides)
|
||||
try:
|
||||
return GenerationConfig.model_validate(defaults)
|
||||
except AttributeError:
|
||||
# dataclass fallback: model y loras ya son dicts, construir manualmente
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.lora_ref import LoraRef
|
||||
m = defaults.pop("model")
|
||||
if isinstance(m, dict):
|
||||
m = ModelRef(**m)
|
||||
loras = defaults.pop("loras", [])
|
||||
built_loras = []
|
||||
for lr in loras:
|
||||
if isinstance(lr, dict):
|
||||
built_loras.append(LoraRef(**lr))
|
||||
else:
|
||||
built_loras.append(lr)
|
||||
return GenerationConfig(model=m, loras=tuple(built_loras), **defaults)
|
||||
|
||||
|
||||
class TestGenconfigToDiffusersKwargs(unittest.TestCase):
|
||||
|
||||
def test_kwargs_contiene_todas_las_claves_requeridas(self):
|
||||
"""kwargs contiene todas las claves requeridas"""
|
||||
cfg = _make_cfg()
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
required_keys = {
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"width",
|
||||
"height",
|
||||
"generator",
|
||||
}
|
||||
self.assertEqual(set(kwargs.keys()), required_keys)
|
||||
|
||||
def test_negative_prompt_none_se_pasa_tal_cual(self):
|
||||
"""negative_prompt None se pasa tal cual"""
|
||||
cfg = _make_cfg(negative_prompt=None)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertIsNone(kwargs["negative_prompt"])
|
||||
|
||||
def test_steps_y_cfg_scale_se_mapean_a_num_inference_steps_y_guidance_scale(self):
|
||||
"""steps y cfg_scale se mapean a num_inference_steps y guidance_scale"""
|
||||
cfg = _make_cfg(steps=20, cfg_scale=8.0)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["num_inference_steps"], 20)
|
||||
self.assertAlmostEqual(kwargs["guidance_scale"], 8.0)
|
||||
|
||||
def test_generator_siempre_es_none(self):
|
||||
"""generator siempre es None"""
|
||||
cfg = _make_cfg()
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertIsNone(kwargs["generator"])
|
||||
|
||||
def test_prompt_se_copia_sin_modificar(self):
|
||||
"""prompt se copia sin modificar"""
|
||||
cfg = _make_cfg(prompt="a cat on a roof")
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["prompt"], "a cat on a roof")
|
||||
|
||||
def test_width_y_height_se_preservan(self):
|
||||
"""width y height se preservan"""
|
||||
cfg = _make_cfg(width=1024, height=768)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["width"], 1024)
|
||||
self.assertEqual(kwargs["height"], 768)
|
||||
|
||||
def test_negative_prompt_string_se_pasa_tal_cual(self):
|
||||
"""negative_prompt string se pasa tal cual"""
|
||||
cfg = _make_cfg(negative_prompt="blurry, low quality")
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertEqual(kwargs["negative_prompt"], "blurry, low quality")
|
||||
|
||||
def test_no_incluye_seed_sampler_ni_loras(self):
|
||||
"""no incluye seed sampler ni loras en el dict"""
|
||||
cfg = _make_cfg(
|
||||
loras=[{"path": "/loras/detail.safetensors", "weight": 0.8}]
|
||||
)
|
||||
kwargs = genconfig_to_diffusers_kwargs(cfg)
|
||||
self.assertNotIn("seed", kwargs)
|
||||
self.assertNotIn("sampler", kwargs)
|
||||
self.assertNotIn("loras", kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Tests para genconfig_to_sdcpp_args."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.genconfig_to_sdcpp_args import genconfig_to_sdcpp_args, _SAMPLER_MAP
|
||||
from ml.generation_config import GenerationConfig
|
||||
|
||||
|
||||
def _make_cfg(**overrides):
|
||||
"""Crea un GenerationConfig sintetico para tests via model_validate / constructor."""
|
||||
defaults = dict(
|
||||
prompt="a cat",
|
||||
seed=1,
|
||||
steps=20,
|
||||
cfg_scale=7.0,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model={"name": "v1-5-pruned.ckpt", "model_type": "sd15", "path": "/models/v1-5.ckpt"},
|
||||
)
|
||||
defaults.update(overrides)
|
||||
# Normalizar loras a dicts si fueron pasados como LoraRef
|
||||
if "loras" in defaults:
|
||||
normalized = []
|
||||
for lr in defaults["loras"]:
|
||||
if hasattr(lr, "__dict__") and not isinstance(lr, dict):
|
||||
normalized.append({"path": lr.path, "weight": lr.weight, "scale": lr.scale})
|
||||
else:
|
||||
normalized.append(lr)
|
||||
defaults["loras"] = normalized
|
||||
try:
|
||||
return GenerationConfig.model_validate(defaults)
|
||||
except AttributeError:
|
||||
from ml.model_ref import ModelRef
|
||||
from ml.lora_ref import LoraRef
|
||||
m = defaults.pop("model")
|
||||
if isinstance(m, dict):
|
||||
m = ModelRef(**m)
|
||||
loras = defaults.pop("loras", [])
|
||||
built = [LoraRef(**lr) if isinstance(lr, dict) else lr for lr in loras]
|
||||
return GenerationConfig(model=m, loras=tuple(built), **defaults)
|
||||
|
||||
|
||||
def _get_flag_value(args: list[str], flag: str) -> str | None:
|
||||
"""Extrae el valor de un flag en la lista de args."""
|
||||
try:
|
||||
idx = args.index(flag)
|
||||
return args[idx + 1]
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
def _get_all_flag_values(args: list[str], flag: str) -> list[str]:
|
||||
"""Extrae todos los valores de un flag repetido (ej. --lora)."""
|
||||
values = []
|
||||
for i, arg in enumerate(args):
|
||||
if arg == flag and i + 1 < len(args):
|
||||
values.append(args[i + 1])
|
||||
return values
|
||||
|
||||
|
||||
class TestGenconfigToSdcppArgs(unittest.TestCase):
|
||||
|
||||
def test_sampler_euler_a_se_mapea_a_euler_a_en_el_flag_sampling_method(self):
|
||||
"""sampler euler_a se mapea a euler_a en el flag --sampling-method"""
|
||||
cfg = _make_cfg(sampler="euler_a")
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--sampling-method"), "euler_a")
|
||||
|
||||
def test_sampler_dpm_pp_2m_se_mapea_a_dpmpp2m(self):
|
||||
"""sampler dpm++2m se mapea a dpmpp2m"""
|
||||
cfg = _make_cfg(sampler="dpm++2m")
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--sampling-method"), "dpmpp2m")
|
||||
|
||||
def test_lora_con_path_y_weight_se_agrega_como_lora_path_weight(self):
|
||||
"""lora con path y weight se agrega como --lora path:weight"""
|
||||
cfg = _make_cfg(
|
||||
loras=[{"path": "/loras/detail.safetensors", "weight": 0.8}]
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
lora_values = _get_all_flag_values(args, "--lora")
|
||||
self.assertEqual(len(lora_values), 1)
|
||||
self.assertEqual(lora_values[0], "/loras/detail.safetensors:0.8")
|
||||
|
||||
def test_multiples_loras_generan_multiples_pares_lora(self):
|
||||
"""multiples loras generan multiples pares --lora"""
|
||||
cfg = _make_cfg(
|
||||
loras=[
|
||||
{"path": "/loras/a.safetensors", "weight": 0.5},
|
||||
{"path": "/loras/b.safetensors", "weight": 1.0},
|
||||
]
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
lora_values = _get_all_flag_values(args, "--lora")
|
||||
self.assertEqual(len(lora_values), 2)
|
||||
self.assertIn("/loras/a.safetensors:0.5", lora_values)
|
||||
self.assertIn("/loras/b.safetensors:1.0", lora_values)
|
||||
|
||||
def test_negative_prompt_none_produce_string_vacio_en_negative_prompt(self):
|
||||
"""negative_prompt None produce string vacio en --negative-prompt"""
|
||||
cfg = _make_cfg(negative_prompt=None)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--negative-prompt"), "")
|
||||
|
||||
def test_model_path_tiene_prioridad_sobre_model_name_en_m(self):
|
||||
"""model.path tiene prioridad sobre model.name en -m"""
|
||||
cfg = _make_cfg(
|
||||
model={"name": "hub-name", "model_type": "sd15", "path": "/local/path/model.ckpt"}
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "-m"), "/local/path/model.ckpt")
|
||||
|
||||
def test_sin_path_usa_model_name_en_m(self):
|
||||
"""sin path usa model.name en -m"""
|
||||
cfg = _make_cfg(
|
||||
model={"name": "runwayml/sd-v1-5", "model_type": "sd15", "path": None}
|
||||
)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "-m"), "runwayml/sd-v1-5")
|
||||
|
||||
def test_args_contiene_flags_obligatorios(self):
|
||||
"""args contiene --prompt --seed --steps --cfg-scale --sampling-method -W -H -m"""
|
||||
cfg = _make_cfg()
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
for flag in ["--prompt", "--seed", "--steps", "--cfg-scale", "--sampling-method", "-W", "-H", "-m"]:
|
||||
self.assertIn(flag, args, f"Flag faltante: {flag}")
|
||||
|
||||
def test_sampler_map_cubre_todos_los_samplers_canonicos(self):
|
||||
"""_SAMPLER_MAP cubre todos los samplers canonicos del dominio ml"""
|
||||
canonical = {"euler", "euler_a", "dpm++2m", "dpm++2m_v2", "heun", "dpm2", "lcm"}
|
||||
self.assertEqual(set(_SAMPLER_MAP.keys()), canonical)
|
||||
|
||||
def test_seed_steps_width_height_se_convierten_a_string(self):
|
||||
"""seed steps width height se convierten a string en los args"""
|
||||
cfg = _make_cfg(seed=42, steps=25, width=768, height=512)
|
||||
args = genconfig_to_sdcpp_args(cfg)
|
||||
self.assertEqual(_get_flag_value(args, "--seed"), "42")
|
||||
self.assertEqual(_get_flag_value(args, "--steps"), "25")
|
||||
self.assertEqual(_get_flag_value(args, "-W"), "768")
|
||||
self.assertEqual(_get_flag_value(args, "-H"), "512")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Tests para GenerationConfig — serialización, roundtrip y frozen."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
# Añadir python/functions/ml al path para que los imports internos del módulo
|
||||
# (from lora_ref import LoraRef, from model_ref import ModelRef) funcionen.
|
||||
# Los módulos se importan directamente desde el subdirectorio para evitar
|
||||
# colisiones de tipos entre ml.generation_config.* y generation_config.*.
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
from generation_config import GenerationConfig
|
||||
from lora_ref import LoraRef
|
||||
from model_ref import ModelRef
|
||||
|
||||
|
||||
def _make_model() -> ModelRef:
|
||||
return ModelRef(name="stabilityai/stable-diffusion-v1-5", model_type="sd15")
|
||||
|
||||
|
||||
def _make_config() -> GenerationConfig:
|
||||
return GenerationConfig(
|
||||
prompt="a cat in the moonlight",
|
||||
negative_prompt="blurry, low quality",
|
||||
seed=42,
|
||||
steps=30,
|
||||
cfg_scale=7.5,
|
||||
sampler="euler_a",
|
||||
width=512,
|
||||
height=512,
|
||||
model=_make_model(),
|
||||
loras=[],
|
||||
clip_skip=1,
|
||||
)
|
||||
|
||||
|
||||
def test_instancia_ok():
|
||||
"""GenerationConfig crea instancia sin errores"""
|
||||
cfg = _make_config()
|
||||
assert cfg.prompt == "a cat in the moonlight"
|
||||
assert cfg.seed == 42
|
||||
assert cfg.steps == 30
|
||||
assert cfg.cfg_scale == 7.5
|
||||
assert cfg.sampler == "euler_a"
|
||||
assert cfg.width == 512
|
||||
assert cfg.height == 512
|
||||
assert cfg.clip_skip == 1
|
||||
|
||||
|
||||
def test_model_dump_keys_snake_case():
|
||||
"""model_dump devuelve dict con keys snake_case incluyendo negative_prompt, cfg_scale, clip_skip"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
cfg = _make_config()
|
||||
d = cfg.model_dump()
|
||||
assert isinstance(d, dict)
|
||||
assert "negative_prompt" in d
|
||||
assert "cfg_scale" in d
|
||||
assert "clip_skip" in d
|
||||
assert d["negative_prompt"] == "blurry, low quality"
|
||||
assert d["cfg_scale"] == 7.5
|
||||
assert d["clip_skip"] == 1
|
||||
|
||||
|
||||
def test_model_dump_json_parseable():
|
||||
"""model_dump_json retorna str JSON parseable"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
cfg = _make_config()
|
||||
raw = cfg.model_dump_json()
|
||||
assert isinstance(raw, str)
|
||||
parsed = json.loads(raw)
|
||||
assert isinstance(parsed, dict)
|
||||
assert parsed["prompt"] == "a cat in the moonlight"
|
||||
assert parsed["seed"] == 42
|
||||
|
||||
|
||||
def test_roundtrip_model_validate():
|
||||
"""GenerationConfig.model_validate(json.loads(...)) roundtrip ok"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
cfg = _make_config()
|
||||
raw_json = cfg.model_dump_json()
|
||||
parsed = json.loads(raw_json)
|
||||
cfg2 = GenerationConfig.model_validate(parsed)
|
||||
assert cfg2.prompt == cfg.prompt
|
||||
assert cfg2.seed == cfg.seed
|
||||
assert cfg2.cfg_scale == cfg.cfg_scale
|
||||
assert cfg2.sampler == cfg.sampler
|
||||
assert cfg2.clip_skip == cfg.clip_skip
|
||||
assert cfg2.model.name == cfg.model.name
|
||||
assert cfg2.model.model_type == cfg.model.model_type
|
||||
|
||||
|
||||
def test_frozen_levanta_error_al_mutar():
|
||||
"""frozen: intentar mutar levanta AttributeError, ValidationError o FrozenInstanceError"""
|
||||
cfg = _make_config()
|
||||
raised = False
|
||||
try:
|
||||
# dataclass frozen y pydantic frozen levantan distintas excepciones
|
||||
cfg.prompt = "mutated" # type: ignore[misc]
|
||||
except Exception:
|
||||
raised = True
|
||||
|
||||
assert raised, "Se esperaba que mutar un campo frozen lanzara una excepcion"
|
||||
|
||||
|
||||
def test_negative_prompt_opcional():
|
||||
"""negative_prompt es opcional (default None)"""
|
||||
cfg = GenerationConfig(
|
||||
prompt="mountains",
|
||||
seed=0,
|
||||
steps=20,
|
||||
cfg_scale=7.0,
|
||||
sampler="euler",
|
||||
width=512,
|
||||
height=512,
|
||||
model=_make_model(),
|
||||
)
|
||||
assert cfg.negative_prompt is None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Tests para gpu_info."""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, "python/functions")
|
||||
|
||||
from ml.gpu_info import gpu_info
|
||||
|
||||
|
||||
class TestGpuInfo(unittest.TestCase):
|
||||
|
||||
def test_sin_nvidia_smi_devuelve_lista_vacia(self):
|
||||
"""sin nvidia-smi devuelve lista vacia"""
|
||||
with patch("subprocess.run", side_effect=FileNotFoundError()):
|
||||
result = gpu_info()
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_formato_CSV_correcto_devuelve_lista_con_un_dict_por_GPU(self):
|
||||
"""formato CSV correcto devuelve lista con un dict por GPU"""
|
||||
csv_output = " 0, NVIDIA RTX 4090, 24564, 22000, 535.183.01, 8.9\n"
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = csv_output
|
||||
with patch("subprocess.run", return_value=mock_result):
|
||||
result = gpu_info()
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["index"], 0)
|
||||
self.assertEqual(result[0]["name"], "NVIDIA RTX 4090")
|
||||
self.assertEqual(result[0]["vram_total_mb"], 24564)
|
||||
self.assertEqual(result[0]["vram_free_mb"], 22000)
|
||||
self.assertEqual(result[0]["driver_version"], "535.183.01")
|
||||
self.assertEqual(result[0]["cuda_version"], "8.9")
|
||||
|
||||
def test_fila_malformada_en_CSV_se_ignora_sin_excepcion(self):
|
||||
"""fila malformada en CSV se ignora sin excepcion"""
|
||||
csv_output = " 0, RTX 4090, NONNUMERIC, 22000, 535.183.01, 8.9\n"
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = csv_output
|
||||
with patch("subprocess.run", return_value=mock_result):
|
||||
result = gpu_info()
|
||||
self.assertEqual(result, [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tests para hf_snapshot_download — mockear snapshot_download y verificar args."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
# Saltar si huggingface_hub no esta disponible Y no podemos mockearlo
|
||||
# Usamos un mock inline para no requerir la lib real.
|
||||
# Si la lib esta disponible, monkeypatch la reemplaza. Si no, la inyectamos manualmente.
|
||||
|
||||
|
||||
def _inject_fake_hf_hub(monkeypatch, capture: list):
|
||||
"""Inyecta un modulo huggingface_hub falso con snapshot_download que captura kwargs."""
|
||||
|
||||
def fake_snapshot_download(**kwargs):
|
||||
capture.append(kwargs)
|
||||
return "/tmp/fake_snapshot"
|
||||
|
||||
fake_module = types.ModuleType("huggingface_hub")
|
||||
fake_module.snapshot_download = fake_snapshot_download
|
||||
monkeypatch.setitem(sys.modules, "huggingface_hub", fake_module)
|
||||
|
||||
|
||||
def test_args_minimos_repo_id(monkeypatch):
|
||||
"""repo_id se pasa correctamente a snapshot_download"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
result = hf_snapshot_download("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
assert len(capture) == 1
|
||||
assert capture[0]["repo_id"] == "runwayml/stable-diffusion-v1-5"
|
||||
assert result == "/tmp/fake_snapshot"
|
||||
|
||||
|
||||
def test_retorna_string(monkeypatch):
|
||||
"""hf_snapshot_download retorna un string (la ruta local)"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
result = hf_snapshot_download("some/repo")
|
||||
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_allow_patterns_se_pasa(monkeypatch):
|
||||
"""allow_patterns se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo", allow_patterns=["*.safetensors", "*.json"])
|
||||
|
||||
assert "allow_patterns" in capture[0]
|
||||
assert capture[0]["allow_patterns"] == ["*.safetensors", "*.json"]
|
||||
|
||||
|
||||
def test_ignore_patterns_se_pasa(monkeypatch):
|
||||
"""ignore_patterns se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo", ignore_patterns=["*.bin", "flax_*"])
|
||||
|
||||
assert "ignore_patterns" in capture[0]
|
||||
assert capture[0]["ignore_patterns"] == ["*.bin", "flax_*"]
|
||||
|
||||
|
||||
def test_local_dir_se_pasa(monkeypatch):
|
||||
"""local_dir se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo", local_dir="/models/sd15")
|
||||
|
||||
assert "local_dir" in capture[0]
|
||||
assert capture[0]["local_dir"] == "/models/sd15"
|
||||
|
||||
|
||||
def test_token_se_pasa(monkeypatch):
|
||||
"""token se incluye en los kwargs si se especifica"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("private/model", token="hf_mytoken123")
|
||||
|
||||
assert "token" in capture[0]
|
||||
assert capture[0]["token"] == "hf_mytoken123"
|
||||
|
||||
|
||||
def test_none_args_no_se_pasan(monkeypatch):
|
||||
"""args opcionales None no se incluyen en kwargs (no contaminar snapshot_download)"""
|
||||
capture = []
|
||||
_inject_fake_hf_hub(monkeypatch, capture)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
hf_snapshot_download("some/repo")
|
||||
|
||||
kwargs = capture[0]
|
||||
# Solo repo_id debe estar presente — los None no se incluyen
|
||||
assert "allow_patterns" not in kwargs
|
||||
assert "ignore_patterns" not in kwargs
|
||||
assert "local_dir" not in kwargs
|
||||
assert "token" not in kwargs
|
||||
|
||||
|
||||
def test_import_error_sin_huggingface_hub(monkeypatch):
|
||||
"""ImportError descriptivo si huggingface_hub no esta instalado"""
|
||||
import importlib
|
||||
|
||||
# Inyectar None en sys.modules para simular libreria no instalada
|
||||
monkeypatch.setitem(sys.modules, "huggingface_hub", None)
|
||||
|
||||
# Recargar el modulo para que el try/except del top-level vea el None
|
||||
import hf_snapshot_download as _mod
|
||||
importlib.reload(_mod)
|
||||
|
||||
from hf_snapshot_download import hf_snapshot_download
|
||||
with pytest.raises(ImportError, match="huggingface_hub"):
|
||||
hf_snapshot_download("any/repo")
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tests para image_compare_side_by_side."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
|
||||
PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping")
|
||||
|
||||
from PIL import Image
|
||||
from image_compare_side_by_side import image_compare_side_by_side
|
||||
|
||||
|
||||
def _black(w=16, h=16):
|
||||
return Image.new("RGB", (w, h), color=(0, 0, 0))
|
||||
|
||||
|
||||
def _white(w=16, h=16):
|
||||
return Image.new("RGB", (w, h), color=(255, 255, 255))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grid shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_grid_es_pil_image_con_dimensiones_correctas_show_diff_True():
|
||||
"""grid es PIL.Image con dimensiones correctas show_diff=True"""
|
||||
w, h = 16, 16
|
||||
gap = 16
|
||||
result = image_compare_side_by_side(_black(w, h), _white(w, h), gap_px=gap, show_diff=True)
|
||||
|
||||
grid = result["grid"]
|
||||
assert isinstance(grid, Image.Image), "grid debe ser PIL.Image.Image"
|
||||
|
||||
expected_w = 3 * w + 4 * gap # A + diff + B + 4 gaps
|
||||
expected_h = h + 2 * gap
|
||||
assert grid.size == (expected_w, expected_h), (
|
||||
f"Esperado ({expected_w}, {expected_h}), got {grid.size}"
|
||||
)
|
||||
|
||||
|
||||
def test_grid_es_pil_image_sin_diff_show_diff_False():
|
||||
"""grid es PIL.Image sin diff show_diff=False"""
|
||||
w, h = 16, 16
|
||||
gap = 8
|
||||
result = image_compare_side_by_side(_black(w, h), _white(w, h), gap_px=gap, show_diff=False)
|
||||
|
||||
grid = result["grid"]
|
||||
assert isinstance(grid, Image.Image), "grid debe ser PIL.Image.Image"
|
||||
|
||||
expected_w = 2 * w + 3 * gap # A + B + 3 gaps
|
||||
expected_h = h + 2 * gap
|
||||
assert grid.size == (expected_w, expected_h), (
|
||||
f"Esperado ({expected_w}, {expected_h}), got {grid.size}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MSE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_pixel_mse_positivo_para_imagenes_distintas():
|
||||
"""pixel_mse positivo para imagenes distintas"""
|
||||
result = image_compare_side_by_side(_black(), _white())
|
||||
mse = result["pixel_mse"]
|
||||
assert isinstance(mse, float), f"pixel_mse debe ser float, got {type(mse)}"
|
||||
assert mse > 0.0, f"pixel_mse debe ser > 0 para imagenes distintas, got {mse}"
|
||||
|
||||
|
||||
def test_pixel_mse_cero_para_imagen_identica():
|
||||
"""pixel_mse cero para imagen identica"""
|
||||
img = _black()
|
||||
result = image_compare_side_by_side(img, img.copy())
|
||||
mse = result["pixel_mse"]
|
||||
assert mse == 0.0, f"pixel_mse debe ser 0.0 para imagenes identicas, got {mse}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pHash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_phash_none_si_imagehash_no_disponible():
|
||||
"""phash None si imagehash no disponible"""
|
||||
try:
|
||||
import imagehash # noqa: F401
|
||||
pytest.skip("imagehash esta instalado — test de fallback no aplica")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
result = image_compare_side_by_side(_black(), _white(), show_phash=True)
|
||||
assert result["phash_a"] is None, "phash_a debe ser None si imagehash no instalado"
|
||||
assert result["phash_b"] is None, "phash_b debe ser None si imagehash no instalado"
|
||||
assert result["phash_distance"] is None, "phash_distance debe ser None si imagehash no instalado"
|
||||
|
||||
|
||||
def test_phash_presente_si_imagehash_disponible():
|
||||
"""phash presente si imagehash disponible"""
|
||||
try:
|
||||
import imagehash # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("imagehash no instalado")
|
||||
|
||||
result = image_compare_side_by_side(_black(), _white(), show_phash=True)
|
||||
assert isinstance(result["phash_a"], str), "phash_a debe ser str"
|
||||
assert isinstance(result["phash_b"], str), "phash_b debe ser str"
|
||||
assert isinstance(result["phash_distance"], int), "phash_distance debe ser int"
|
||||
assert len(result["phash_a"]) == 16, f"phash_a debe tener 16 hex chars, got {len(result['phash_a'])}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Campos del resultado
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_resultado_tiene_todas_las_claves():
|
||||
"""resultado tiene todas las claves esperadas"""
|
||||
result = image_compare_side_by_side(_black(), _white())
|
||||
for key in ("grid", "phash_a", "phash_b", "phash_distance", "pixel_mse"):
|
||||
assert key in result, f"Clave '{key}' faltante en resultado"
|
||||
|
||||
|
||||
def test_show_phash_false_deja_phash_none():
|
||||
"""show_phash=False deja phash* en None sin intentar import"""
|
||||
result = image_compare_side_by_side(_black(), _white(), show_phash=False)
|
||||
assert result["phash_a"] is None
|
||||
assert result["phash_b"] is None
|
||||
assert result["phash_distance"] is None
|
||||
@@ -0,0 +1,99 @@
|
||||
"""Tests para ImageGenResult — dump excluye image, meta viaja correctamente."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
from image_gen_result import ImageGenResult
|
||||
|
||||
|
||||
def _make_result(image=None, duration_ms=1234, vram_peak_mb=None, meta=None):
|
||||
if meta is None:
|
||||
meta = {
|
||||
"model": "sd15",
|
||||
"seed_used": 42,
|
||||
"sampler": "euler_a",
|
||||
"prompt": "a cat",
|
||||
}
|
||||
return ImageGenResult(
|
||||
image=image,
|
||||
meta=meta,
|
||||
duration_ms=duration_ms,
|
||||
vram_peak_mb=vram_peak_mb,
|
||||
)
|
||||
|
||||
|
||||
def test_instancia_ok():
|
||||
"""ImageGenResult crea instancia sin errores"""
|
||||
r = _make_result(duration_ms=500)
|
||||
assert r.duration_ms == 500
|
||||
assert isinstance(r.meta, dict)
|
||||
|
||||
|
||||
def test_dump_excluye_image():
|
||||
"""model_dump excluye el campo image automaticamente"""
|
||||
|
||||
class FakeImage:
|
||||
"""Objeto imagen simulado (no PIL real)."""
|
||||
pass
|
||||
|
||||
r = _make_result(image=FakeImage(), duration_ms=800)
|
||||
d = r.model_dump()
|
||||
assert isinstance(d, dict)
|
||||
assert "image" not in d, "image no debe aparecer en model_dump()"
|
||||
|
||||
|
||||
def test_dump_incluye_meta_duration_vram():
|
||||
"""model_dump incluye meta, duration_ms y vram_peak_mb"""
|
||||
meta = {"model": "sdxl", "seed_used": 99, "sampler": "dpm++2m"}
|
||||
r = _make_result(duration_ms=2000, vram_peak_mb=6144, meta=meta)
|
||||
d = r.model_dump()
|
||||
assert "meta" in d
|
||||
assert "duration_ms" in d
|
||||
assert "vram_peak_mb" in d
|
||||
assert d["duration_ms"] == 2000
|
||||
assert d["vram_peak_mb"] == 6144
|
||||
|
||||
|
||||
def test_meta_dict_viaja_completo():
|
||||
"""meta dict se conserva completo en model_dump"""
|
||||
meta = {
|
||||
"model": "flux_dev",
|
||||
"seed_used": 777,
|
||||
"sampler": "euler",
|
||||
"custom_key": "custom_value",
|
||||
"nested": {"a": 1},
|
||||
}
|
||||
r = _make_result(meta=meta)
|
||||
d = r.model_dump()
|
||||
assert d["meta"] == meta
|
||||
assert d["meta"]["custom_key"] == "custom_value"
|
||||
assert d["meta"]["nested"] == {"a": 1}
|
||||
|
||||
|
||||
def test_dump_json_parseable():
|
||||
"""model_dump_json retorna string JSON parseable sin image"""
|
||||
meta = {"model": "sd15", "seed_used": 1}
|
||||
r = _make_result(duration_ms=100, meta=meta)
|
||||
raw = r.model_dump_json()
|
||||
assert isinstance(raw, str)
|
||||
parsed = json.loads(raw)
|
||||
assert "meta" in parsed
|
||||
assert "duration_ms" in parsed
|
||||
assert "image" not in parsed
|
||||
|
||||
|
||||
def test_vram_peak_mb_none_serializa():
|
||||
"""vram_peak_mb=None se serializa correctamente a null"""
|
||||
r = _make_result(vram_peak_mb=None)
|
||||
d = r.model_dump()
|
||||
assert d["vram_peak_mb"] is None
|
||||
|
||||
|
||||
def test_image_none_permitido():
|
||||
"""image puede ser None (generacion fallida)"""
|
||||
r = _make_result(image=None)
|
||||
assert r.image is None
|
||||
d = r.model_dump()
|
||||
assert "image" not in d
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Tests para ImageGenerator Protocol — runtime_checkable y structural subtyping."""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
from image_gen_result import ImageGenResult
|
||||
from image_generator import ImageGenerator
|
||||
|
||||
|
||||
class MockGenerator:
|
||||
"""Implementacion dummy que satisface ImageGenerator sin herencia explicita."""
|
||||
|
||||
def generate(self, config):
|
||||
"""Retorna un ImageGenResult sin imagen real."""
|
||||
return ImageGenResult(
|
||||
image=None,
|
||||
meta={"model": "mock", "seed_used": 0, "sampler": "euler"},
|
||||
duration_ms=1,
|
||||
vram_peak_mb=None,
|
||||
)
|
||||
|
||||
|
||||
class NotAGenerator:
|
||||
"""Clase que NO implementa generate — no satisface el Protocol."""
|
||||
|
||||
def predict(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def test_dummy_satisface_protocol():
|
||||
"""clase dummy que implementa generate satisface isinstance(x, ImageGenerator)"""
|
||||
gen = MockGenerator()
|
||||
assert isinstance(gen, ImageGenerator), (
|
||||
"MockGenerator debe satisfacer ImageGenerator Protocol (runtime_checkable)"
|
||||
)
|
||||
|
||||
|
||||
def test_resultado_es_image_gen_result():
|
||||
"""generate() retorna ImageGenResult"""
|
||||
gen = MockGenerator()
|
||||
result = gen.generate(config=None)
|
||||
assert isinstance(result, ImageGenResult)
|
||||
|
||||
|
||||
def test_clase_sin_generate_no_satisface_protocol():
|
||||
"""clase sin metodo generate NO satisface isinstance check"""
|
||||
not_gen = NotAGenerator()
|
||||
assert not isinstance(not_gen, ImageGenerator), (
|
||||
"NotAGenerator no debe satisfacer ImageGenerator Protocol"
|
||||
)
|
||||
|
||||
|
||||
def test_multiples_instancias_satisfacen_protocol():
|
||||
"""multiples instancias del mismo dummy satisfacen el Protocol"""
|
||||
for _ in range(3):
|
||||
gen = MockGenerator()
|
||||
assert isinstance(gen, ImageGenerator)
|
||||
|
||||
|
||||
def test_lambda_con_callable_no_satisface_protocol():
|
||||
"""un callable lambda no satisface el Protocol (no tiene metodo .generate)"""
|
||||
|
||||
class LambdaLike:
|
||||
def __call__(self, config):
|
||||
return None
|
||||
|
||||
obj = LambdaLike()
|
||||
# __call__ no es lo mismo que .generate — no debe satisfacer el protocol
|
||||
assert not isinstance(obj, ImageGenerator)
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Tests para image_grid — combina imagenes en grid NxM."""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import math
|
||||
import pytest
|
||||
|
||||
PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping")
|
||||
|
||||
from image_grid import image_grid
|
||||
|
||||
|
||||
def _make_images(n: int, w: int = 16, h: int = 16):
|
||||
from PIL import Image
|
||||
return [Image.new("RGB", (w, h), color=(i * 10, i * 10, i * 10)) for i in range(n)]
|
||||
|
||||
|
||||
def test_grid_4_imagenes_2_cols_dimensiones_correctas():
|
||||
"""grid de 4 imagenes 16x16 cols=2 produce ancho/alto correcto"""
|
||||
images = _make_images(4, w=16, h=16)
|
||||
result = image_grid(images, cols=2, gap_px=0)
|
||||
|
||||
# rows = ceil(4/2) = 2
|
||||
# canvas_w = 2*16 + 3*0 = 32 (con gap_px=0: cols*w + (cols+1)*0)
|
||||
# canvas_h = 2*16 + 3*0 = 32
|
||||
assert result.width == 32, f"Ancho esperado 32, got {result.width}"
|
||||
assert result.height == 32, f"Alto esperado 32, got {result.height}"
|
||||
|
||||
|
||||
def test_grid_4_imagenes_2_cols_con_gap():
|
||||
"""grid de 4 imagenes cols=2 gap_px=8 tiene dimensiones correctas con gap"""
|
||||
images = _make_images(4, w=16, h=16)
|
||||
gap = 8
|
||||
cols = 2
|
||||
rows = math.ceil(4 / cols)
|
||||
expected_w = cols * 16 + (cols + 1) * gap
|
||||
expected_h = rows * 16 + (rows + 1) * gap
|
||||
|
||||
result = image_grid(images, cols=cols, gap_px=gap)
|
||||
assert result.width == expected_w, f"Ancho: expected {expected_w}, got {result.width}"
|
||||
assert result.height == expected_h, f"Alto: expected {expected_h}, got {result.height}"
|
||||
|
||||
|
||||
def test_grid_1_imagen_1_col():
|
||||
"""grid de 1 imagen 1 col = imagen sola mas gaps"""
|
||||
images = _make_images(1, w=32, h=32)
|
||||
result = image_grid(images, cols=1, gap_px=4)
|
||||
# rows=1, cols=1 → w = 1*32 + 2*4 = 40, h = 1*32 + 2*4 = 40
|
||||
assert result.width == 40
|
||||
assert result.height == 40
|
||||
|
||||
|
||||
def test_grid_retorna_imagen_rgb():
|
||||
"""el resultado es una imagen RGB"""
|
||||
from PIL import Image
|
||||
images = _make_images(2, w=8, h=8)
|
||||
result = image_grid(images, cols=2)
|
||||
assert isinstance(result, Image.Image)
|
||||
assert result.mode == "RGB"
|
||||
|
||||
|
||||
def test_grid_con_labels_no_falla():
|
||||
"""labels opcionales — no lanza excepcion"""
|
||||
images = _make_images(4, w=16, h=16)
|
||||
labels = ["a", "b", "c", "d"]
|
||||
result = image_grid(images, cols=2, labels=labels, gap_px=0)
|
||||
# Debe devolver imagen válida
|
||||
assert result.width > 0
|
||||
assert result.height > 0
|
||||
|
||||
|
||||
def test_grid_sin_labels_no_falla():
|
||||
"""sin labels funciona correctamente"""
|
||||
images = _make_images(3, w=16, h=16)
|
||||
result = image_grid(images, cols=3, labels=None, gap_px=0)
|
||||
assert result.width == 3 * 16
|
||||
assert result.height == 16 # 1 row
|
||||
|
||||
|
||||
def test_grid_lista_vacia_levanta_value_error():
|
||||
"""lista vacia levanta ValueError"""
|
||||
with pytest.raises(ValueError):
|
||||
image_grid([], cols=2)
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Tests para image_save_png — guarda PNG con metadata tEXt embebida."""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping")
|
||||
|
||||
from image_save_png import image_save_png
|
||||
|
||||
|
||||
def test_guarda_archivo_y_retorna_ruta_absoluta(tmp_path):
|
||||
"""crea imagen 8x8, guarda y retorna ruta absoluta"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8), color=(255, 0, 0))
|
||||
dest = str(tmp_path / "test.png")
|
||||
result = image_save_png(img, dest)
|
||||
|
||||
import os
|
||||
assert os.path.isfile(result), f"El archivo no existe: {result}"
|
||||
assert os.path.isabs(result), f"La ruta no es absoluta: {result}"
|
||||
|
||||
|
||||
def test_metadata_embebida_en_chunks_text(tmp_path):
|
||||
"""metadata se embebe en chunks tEXt y se puede releer con Image.text"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8), color=(0, 128, 0))
|
||||
dest = str(tmp_path / "with_meta.png")
|
||||
meta = {"prompt": "hi", "seed": "42"}
|
||||
image_save_png(img, dest, metadata=meta)
|
||||
|
||||
reopened = Image.open(dest)
|
||||
text_data = reopened.text # dict de chunks tEXt del PNG
|
||||
assert "prompt" in text_data, f"Falta clave 'prompt' en PNG text chunks: {text_data}"
|
||||
assert "seed" in text_data, f"Falta clave 'seed' en PNG text chunks: {text_data}"
|
||||
assert text_data["prompt"] == "hi"
|
||||
assert text_data["seed"] == "42"
|
||||
|
||||
|
||||
def test_sin_metadata_no_falla(tmp_path):
|
||||
"""sin metadata el PNG se guarda igualmente"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8))
|
||||
dest = str(tmp_path / "no_meta.png")
|
||||
result = image_save_png(img, dest, metadata=None)
|
||||
|
||||
import os
|
||||
assert os.path.isfile(result)
|
||||
|
||||
|
||||
def test_crea_directorio_padre(tmp_path):
|
||||
"""crea directorio padre si no existe"""
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
img = Image.new("RGB", (8, 8))
|
||||
dest = str(tmp_path / "subdir" / "deep" / "image.png")
|
||||
result = image_save_png(img, dest)
|
||||
assert os.path.isfile(result)
|
||||
|
||||
|
||||
def test_metadata_valores_numericos_se_convierten_a_str(tmp_path):
|
||||
"""valores numericos en metadata se convierten a str automaticamente"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (8, 8))
|
||||
dest = str(tmp_path / "numeric.png")
|
||||
meta = {"steps": 30, "cfg_scale": 7.5}
|
||||
image_save_png(img, dest, metadata=meta)
|
||||
|
||||
reopened = Image.open(dest)
|
||||
text_data = reopened.text
|
||||
assert "steps" in text_data
|
||||
assert "cfg_scale" in text_data
|
||||
assert text_data["steps"] == "30"
|
||||
assert text_data["cfg_scale"] == "7.5"
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Tests para ModelRef y LoraRef — instanciación, dump y validación."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
# Importar desde el subdirectorio ml directamente para evitar colisiones de tipos
|
||||
# entre ml.model_ref.ModelRef y model_ref.ModelRef en pydantic.
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
from lora_ref import LoraRef
|
||||
from model_ref import ModelRef
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ModelRef
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_model_ref_instancia_ok():
|
||||
"""ModelRef instancia sin errores"""
|
||||
m = ModelRef(name="stabilityai/sdxl-base-1.0", model_type="sdxl")
|
||||
assert m.name == "stabilityai/sdxl-base-1.0"
|
||||
assert m.model_type == "sdxl"
|
||||
|
||||
|
||||
def test_model_ref_quantization_default_fp16():
|
||||
"""quantization default es fp16"""
|
||||
m = ModelRef(name="runwayml/stable-diffusion-v1-5", model_type="sd15")
|
||||
assert m.quantization == "fp16"
|
||||
|
||||
|
||||
def test_model_ref_quantization_override():
|
||||
"""quantization se puede cambiar a otro valor válido"""
|
||||
m = ModelRef(name="some/model", model_type="flux_dev", quantization="bf16")
|
||||
assert m.quantization == "bf16"
|
||||
|
||||
|
||||
def test_model_ref_path_default_none():
|
||||
"""path es None por defecto"""
|
||||
m = ModelRef(name="some/model", model_type="sd15")
|
||||
assert m.path is None
|
||||
|
||||
|
||||
def test_model_ref_path_set():
|
||||
"""path se puede especificar"""
|
||||
m = ModelRef(name="some/model", model_type="sd15", path="/models/sd15.safetensors")
|
||||
assert m.path == "/models/sd15.safetensors"
|
||||
|
||||
|
||||
def test_model_ref_dump():
|
||||
"""model_dump devuelve dict con las claves esperadas"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
m = ModelRef(name="some/model", model_type="sdxl", quantization="q8_0")
|
||||
d = m.model_dump()
|
||||
assert isinstance(d, dict)
|
||||
assert d["name"] == "some/model"
|
||||
assert d["model_type"] == "sdxl"
|
||||
assert d["quantization"] == "q8_0"
|
||||
|
||||
|
||||
def test_model_ref_validate_roundtrip():
|
||||
"""roundtrip model_dump_json / model_validate ok"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
m = ModelRef(name="some/model", model_type="sd3", quantization="fp32")
|
||||
raw = json.loads(m.model_dump_json())
|
||||
m2 = ModelRef.model_validate(raw)
|
||||
assert m2.name == m.name
|
||||
assert m2.model_type == m.model_type
|
||||
assert m2.quantization == m.quantization
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoraRef
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_lora_ref_instancia_ok():
|
||||
"""LoraRef instancia con path obligatorio"""
|
||||
lr = LoraRef(path="/loras/anime.safetensors")
|
||||
assert lr.path == "/loras/anime.safetensors"
|
||||
|
||||
|
||||
def test_lora_ref_weight_default_1():
|
||||
"""LoraRef weight default es 1.0"""
|
||||
lr = LoraRef(path="/loras/style.safetensors")
|
||||
assert lr.weight == 1.0
|
||||
|
||||
|
||||
def test_lora_ref_weight_override():
|
||||
"""LoraRef weight se puede cambiar"""
|
||||
lr = LoraRef(path="/loras/style.safetensors", weight=0.7)
|
||||
assert lr.weight == 0.7
|
||||
|
||||
|
||||
def test_lora_ref_scale_default_none():
|
||||
"""LoraRef scale default es None"""
|
||||
lr = LoraRef(path="/loras/x.safetensors")
|
||||
assert lr.scale is None
|
||||
|
||||
|
||||
def test_lora_ref_dump():
|
||||
"""LoraRef model_dump devuelve dict con las claves esperadas"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
lr = LoraRef(path="/loras/x.safetensors", weight=0.8, scale=0.9)
|
||||
d = lr.model_dump()
|
||||
assert d["path"] == "/loras/x.safetensors"
|
||||
assert d["weight"] == 0.8
|
||||
assert d["scale"] == 0.9
|
||||
|
||||
|
||||
def test_lora_ref_validate_roundtrip():
|
||||
"""roundtrip dump / validate ok"""
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
except ImportError:
|
||||
pytest.skip("pydantic no disponible")
|
||||
|
||||
lr = LoraRef(path="/loras/x.safetensors", weight=0.5)
|
||||
raw = json.loads(lr.model_dump_json())
|
||||
lr2 = LoraRef.model_validate(raw)
|
||||
assert lr2.path == lr.path
|
||||
assert lr2.weight == lr.weight
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Tests para safetensors_inspect — parseo de header sin dependencias externas."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
sys.path.insert(0, "python/functions/ml")
|
||||
|
||||
import pytest
|
||||
|
||||
from safetensors_inspect import safetensors_inspect
|
||||
|
||||
|
||||
def _write_safetensors(path: str, header: dict, data: bytes = b"") -> None:
|
||||
"""Escribe un archivo safetensors mínimo siguiendo la spec oficial.
|
||||
|
||||
Spec: https://github.com/huggingface/safetensors#format
|
||||
- 8 bytes: uint64 little-endian con la longitud N del header JSON
|
||||
- N bytes: JSON del header
|
||||
- (opcional) bytes de datos de tensores
|
||||
"""
|
||||
header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
|
||||
header_len = len(header_bytes)
|
||||
with open(path, "wb") as f:
|
||||
f.write(struct.pack("<Q", header_len)) # uint64 LE
|
||||
f.write(header_bytes)
|
||||
f.write(data)
|
||||
|
||||
|
||||
def _make_minimal_header(n_tensors: int = 2) -> dict:
|
||||
"""Genera un header con n_tensors tensores sintéticos."""
|
||||
header = {
|
||||
"__metadata__": {"format": "pt", "creator": "test"},
|
||||
}
|
||||
for i in range(n_tensors):
|
||||
header[f"tensor_{i}"] = {
|
||||
"dtype": "F32",
|
||||
"shape": [4, 4],
|
||||
"data_offsets": [i * 64, (i + 1) * 64],
|
||||
}
|
||||
return header
|
||||
|
||||
|
||||
def test_n_tensors_correcto(tmp_path):
|
||||
"""n_tensors refleja el numero de tensores en el header"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
_write_safetensors(path, _make_minimal_header(n_tensors=3))
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["n_tensors"] == 3
|
||||
|
||||
|
||||
def test_total_size_bytes_correcto(tmp_path):
|
||||
"""total_size_bytes refleja el tamaño real del archivo"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
data = b"\x00" * 128 # 128 bytes de datos de tensor
|
||||
_write_safetensors(path, _make_minimal_header(2), data=data)
|
||||
|
||||
file_size = os.path.getsize(path)
|
||||
result = safetensors_inspect(path)
|
||||
assert result["total_size_bytes"] == file_size
|
||||
|
||||
|
||||
def test_metadata_campo_dunder_presente(tmp_path):
|
||||
"""metadata devuelve el contenido de __metadata__"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
header = {
|
||||
"__metadata__": {"format": "pt", "model_name": "test_model"},
|
||||
"weight": {"dtype": "BF16", "shape": [8], "data_offsets": [0, 16]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["metadata"] == {"format": "pt", "model_name": "test_model"}
|
||||
|
||||
|
||||
def test_tensors_lista_correcta(tmp_path):
|
||||
"""tensors es lista con una entrada por tensor del header"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
header = {
|
||||
"__metadata__": {},
|
||||
"embed.weight": {"dtype": "F16", "shape": [128, 64], "data_offsets": [0, 16384]},
|
||||
"proj.bias": {"dtype": "F32", "shape": [64], "data_offsets": [16384, 16640]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["n_tensors"] == 2
|
||||
names = {t["name"] for t in result["tensors"]}
|
||||
assert "embed.weight" in names
|
||||
assert "proj.bias" in names
|
||||
|
||||
|
||||
def test_tensor_campos_dtype_shape_offsets(tmp_path):
|
||||
"""cada tensor tiene dtype, shape y data_offsets"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
header = {
|
||||
"__metadata__": {},
|
||||
"my_tensor": {"dtype": "I32", "shape": [2, 3], "data_offsets": [0, 24]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
t = result["tensors"][0]
|
||||
assert t["dtype"] == "I32"
|
||||
assert t["shape"] == [2, 3]
|
||||
assert t["data_offsets"] == [0, 24]
|
||||
|
||||
|
||||
def test_path_absoluto_en_resultado(tmp_path):
|
||||
"""result['path'] es la ruta absoluta del archivo"""
|
||||
path = str(tmp_path / "model.safetensors")
|
||||
_write_safetensors(path, _make_minimal_header(1))
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert os.path.isabs(result["path"])
|
||||
assert result["path"].endswith("model.safetensors")
|
||||
|
||||
|
||||
def test_archivo_no_encontrado_levanta_file_not_found(tmp_path):
|
||||
"""FileNotFoundError si el archivo no existe"""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
safetensors_inspect(str(tmp_path / "nonexistent.safetensors"))
|
||||
|
||||
|
||||
def test_header_invalido_levanta_value_error(tmp_path):
|
||||
"""ValueError si el header no es JSON válido"""
|
||||
path = str(tmp_path / "bad.safetensors")
|
||||
with open(path, "wb") as f:
|
||||
bad_header = b"NOT JSON!!"
|
||||
f.write(struct.pack("<Q", len(bad_header)))
|
||||
f.write(bad_header)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
safetensors_inspect(path)
|
||||
|
||||
|
||||
def test_archivo_vacio_levanta_value_error(tmp_path):
|
||||
"""ValueError si el archivo está vacío (< 8 bytes)"""
|
||||
path = str(tmp_path / "empty.safetensors")
|
||||
with open(path, "wb") as f:
|
||||
f.write(b"")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
safetensors_inspect(path)
|
||||
|
||||
|
||||
def test_sin_metadata_dunder(tmp_path):
|
||||
"""si no hay __metadata__ en el header, metadata retorna dict vacio"""
|
||||
path = str(tmp_path / "no_meta.safetensors")
|
||||
header = {
|
||||
"weight": {"dtype": "F32", "shape": [4], "data_offsets": [0, 16]},
|
||||
}
|
||||
_write_safetensors(path, header)
|
||||
|
||||
result = safetensors_inspect(path)
|
||||
assert result["metadata"] == {}
|
||||
assert result["n_tensors"] == 1
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Tests para torch_device_select."""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
sys.path.insert(0, "python/functions")
|
||||
|
||||
import ml.torch_device_select as tds_module
|
||||
from ml.torch_device_select import torch_device_select
|
||||
|
||||
|
||||
class TestTorchDeviceSelect(unittest.TestCase):
|
||||
|
||||
def _patch(self, cuda=False, mps=False, cuda_count=0):
|
||||
"""Helper: parchea los helpers internos del modulo."""
|
||||
from unittest.mock import patch
|
||||
return [
|
||||
patch.object(tds_module, "_cuda_available", return_value=cuda),
|
||||
patch.object(tds_module, "_mps_available", return_value=mps),
|
||||
patch.object(tds_module, "_cuda_device_count", return_value=cuda_count),
|
||||
]
|
||||
|
||||
def test_preference_cpu_siempre_retorna_cpu(self):
|
||||
"""preference=cpu siempre retorna cpu"""
|
||||
self.assertEqual(torch_device_select("cpu"), "cpu")
|
||||
|
||||
def test_preference_auto_sin_cuda_ni_mps_retorna_cpu(self):
|
||||
"""preference=auto sin cuda ni mps retorna cpu"""
|
||||
patches = self._patch(cuda=False, mps=False)
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
self.assertEqual(torch_device_select("auto"), "cpu")
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
def test_preference_cuda_sin_cuda_disponible_retorna_cpu_con_warning(self):
|
||||
"""preference=cuda sin cuda disponible retorna cpu con warning"""
|
||||
patches = self._patch(cuda=False, mps=False)
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = torch_device_select("cuda")
|
||||
self.assertEqual(result, "cpu")
|
||||
self.assertTrue(any("CUDA" in str(warning.message) for warning in w))
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
def test_preference_cuda_N_con_solo_1_GPU_retorna_cpu_con_warning(self):
|
||||
"""preference=cuda:5 con solo 1 GPU retorna cpu con warning"""
|
||||
patches = self._patch(cuda=True, cuda_count=1)
|
||||
for p in patches:
|
||||
p.start()
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = torch_device_select("cuda:5")
|
||||
self.assertEqual(result, "cpu")
|
||||
self.assertTrue(len(w) > 0)
|
||||
finally:
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
def test_preference_desconocida_retorna_cpu_con_warning(self):
|
||||
"""preference desconocida retorna cpu con warning"""
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = torch_device_select("vulkan")
|
||||
self.assertEqual(result, "cpu")
|
||||
self.assertTrue(len(w) > 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests para vram_budget."""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Ajustar path para importar desde python/functions/
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ml.vram_budget import vram_budget
|
||||
|
||||
|
||||
def test_sdxl_fp16_fits_24gb():
|
||||
"""SDXL fp16 en una GPU de 24 GB deberia caber con headroom positivo."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=24576, # 24 GB
|
||||
model_type="sdxl",
|
||||
quantization="fp16",
|
||||
n_loras=0,
|
||||
width=1024,
|
||||
height=1024,
|
||||
batch_size=1,
|
||||
)
|
||||
assert result["required_mb"] > 0, "required_mb debe ser positivo"
|
||||
assert result["fits"] is True, f"SDXL fp16 debe caber en 24 GB, required={result['required_mb']} MB"
|
||||
assert result["headroom_mb"] > 0, f"headroom debe ser positivo, got {result['headroom_mb']}"
|
||||
assert result["warning"] is None, f"no debe haber warning, got: {result['warning']}"
|
||||
|
||||
|
||||
def test_flux_fp16_no_fits_8gb():
|
||||
"""Flux fp16 (~23 GB) no debe caber en una GPU de 8 GB."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=8192, # 8 GB
|
||||
model_type="flux_dev",
|
||||
quantization="fp16",
|
||||
n_loras=0,
|
||||
width=1024,
|
||||
height=1024,
|
||||
batch_size=1,
|
||||
)
|
||||
assert result["required_mb"] > 8192, f"Flux fp16 debe requerir mas de 8 GB, got {result['required_mb']} MB"
|
||||
assert result["fits"] is False, "Flux fp16 no debe caber en 8 GB"
|
||||
assert result["headroom_mb"] < 0, f"headroom debe ser negativo, got {result['headroom_mb']}"
|
||||
assert result["warning"] is not None, "debe haber warning con informacion de deficit"
|
||||
assert "+N MB" in result["warning"] or "+" in result["warning"], \
|
||||
f"warning debe indicar cuantos MB extra se necesitan: {result['warning']}"
|
||||
|
||||
|
||||
def test_lora_plus_quant_warning():
|
||||
"""LoRA con quantization q8_0 debe emitir warning de incompatibilidad."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=24576,
|
||||
model_type="sdxl",
|
||||
quantization="q8_0",
|
||||
n_loras=2,
|
||||
width=1024,
|
||||
height=1024,
|
||||
batch_size=1,
|
||||
)
|
||||
assert result["warning"] is not None, "debe haber warning por lora+quantization incompatible"
|
||||
assert "incompatible" in result["warning"].lower(), \
|
||||
f"warning debe mencionar incompatibilidad: {result['warning']}"
|
||||
assert "fp16" in result["warning"], \
|
||||
f"warning debe sugerir fp16: {result['warning']}"
|
||||
|
||||
|
||||
def test_unknown_combo():
|
||||
"""Combinacion (model_type, quant) desconocida debe retornar required_mb=-1 y warning."""
|
||||
result = vram_budget(
|
||||
gpu_vram_total_mb=24576,
|
||||
model_type="modelo_inventado",
|
||||
quantization="q99_k",
|
||||
n_loras=0,
|
||||
)
|
||||
assert result["required_mb"] == -1, \
|
||||
f"required_mb debe ser -1 para combo desconocido, got {result['required_mb']}"
|
||||
assert result["fits"] is False, "fits debe ser False para combo desconocido"
|
||||
assert result["warning"] is not None, "debe haber warning para combo desconocido"
|
||||
assert "unknown" in result["warning"].lower(), \
|
||||
f"warning debe mencionar 'unknown': {result['warning']}"
|
||||
@@ -0,0 +1,67 @@
|
||||
---
|
||||
name: torch_device_select
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def torch_device_select(preference: str = 'auto') -> str"
|
||||
description: "Selecciona el torch device optimo segun preferencia y disponibilidad real del hardware. 'auto' elige CUDA > MPS > CPU. Para preferencias explicitas valida disponibilidad y hace fallback a CPU con warnings.warn."
|
||||
tags: [torch, pytorch, cuda, mps, device, hardware, probe, ml, apple-silicon]
|
||||
uses_functions: [cuda_available_py_ml]
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: []
|
||||
params:
|
||||
- name: preference
|
||||
desc: "'auto' detecta el mejor device disponible (CUDA > MPS > CPU). 'cuda' fuerza cuda:0. 'cuda:N' fuerza GPU N. 'mps' fuerza Apple Silicon. 'cpu' siempre retorna cpu."
|
||||
output: "string de device listo para torch: 'cuda:0', 'cuda:N', 'mps' o 'cpu'. Nunca lanza excepcion — fallback a 'cpu' con warning si el device solicitado no esta disponible."
|
||||
tested: true
|
||||
tests:
|
||||
- "preference=cpu siempre retorna cpu"
|
||||
- "preference=auto sin cuda ni mps retorna cpu"
|
||||
- "preference=cuda sin cuda disponible retorna cpu con warning"
|
||||
- "preference=cuda:5 con solo 1 GPU retorna cpu con warning"
|
||||
- "preference desconocida retorna cpu con warning"
|
||||
test_file_path: "python/functions/ml/tests/test_torch_device_select.py"
|
||||
file_path: "python/functions/ml/torch_device_select.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.torch_device_select import torch_device_select
|
||||
|
||||
# Deteccion automatica (recomendado)
|
||||
device = torch_device_select() # "cuda:0" o "mps" o "cpu"
|
||||
|
||||
# Forzar CPU para reproducibilidad
|
||||
device = torch_device_select("cpu") # siempre "cpu"
|
||||
|
||||
# Preferencia explicita con fallback automatico
|
||||
device = torch_device_select("cuda") # "cuda:0" o "cpu" + warning
|
||||
|
||||
# Uso tipico al cargar un modelo
|
||||
import torch
|
||||
device_str = torch_device_select("auto")
|
||||
model = MyModel().to(torch.device(device_str))
|
||||
```
|
||||
|
||||
## Comparacion con gliner_load_model
|
||||
|
||||
`gliner_load_model` usa internamente `_resolve_device` con la misma logica
|
||||
CUDA/CPU. `torch_device_select` extiende ese patron con:
|
||||
- Soporte MPS (Apple Silicon M1/M2/M3).
|
||||
- Seleccion de GPU especifica (`cuda:N`).
|
||||
- Fallback con `warnings.warn` en vez de silencio.
|
||||
|
||||
## Notas
|
||||
|
||||
- No levanta excepcion si torch no esta instalado: todos los helpers internos
|
||||
capturan ImportError y tratan el device como no disponible.
|
||||
- `warnings.warn` en vez de logging para no imponer dependencia de logging al caller.
|
||||
- MPS requiere torch >= 1.12 y macOS 12.3+. En sistemas Linux/Windows
|
||||
`torch.backends.mps` puede no existir — el helper lo maneja con `hasattr`.
|
||||
- impure: depende del estado del hardware y de las librerias instaladas.
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Selecciona el mejor torch device disponible segun preferencia."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
"""Retorna True si torch esta instalado y CUDA disponible."""
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _mps_available() -> bool:
|
||||
"""Retorna True si torch esta instalado y MPS (Apple Silicon) disponible."""
|
||||
try:
|
||||
import torch
|
||||
return (
|
||||
hasattr(torch.backends, "mps")
|
||||
and torch.backends.mps.is_available()
|
||||
)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _cuda_device_count() -> int:
|
||||
"""Retorna el numero de dispositivos CUDA disponibles."""
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
except ImportError:
|
||||
return 0
|
||||
|
||||
|
||||
def torch_device_select(preference: str = "auto") -> str:
|
||||
"""Selecciona el torch device optimo segun preferencia y disponibilidad.
|
||||
|
||||
Con preference='auto': elige CUDA si disponible, luego MPS (Apple M1/M2),
|
||||
luego CPU. Para preferencias explicitas, valida disponibilidad y hace
|
||||
fallback a CPU con advertencia si el device solicitado no esta disponible.
|
||||
|
||||
Args:
|
||||
preference: 'auto' | 'cuda' | 'cuda:N' | 'mps' | 'cpu'.
|
||||
'auto': detecta automaticamente el mejor device.
|
||||
'cuda': usa cuda:0 si disponible, fallback a cpu.
|
||||
'cuda:N': usa el dispositivo N si existe, fallback a cpu.
|
||||
'mps': usa MPS si disponible (Mac Apple Silicon), fallback a cpu.
|
||||
'cpu': siempre retorna 'cpu'.
|
||||
|
||||
Returns:
|
||||
String de device para torch: 'cuda:0', 'cuda:N', 'mps' o 'cpu'.
|
||||
"""
|
||||
if preference == "cpu":
|
||||
return "cpu"
|
||||
|
||||
if preference == "auto":
|
||||
if _cuda_available():
|
||||
return "cuda:0"
|
||||
if _mps_available():
|
||||
return "mps"
|
||||
return "cpu"
|
||||
|
||||
if preference == "mps":
|
||||
if _mps_available():
|
||||
return "mps"
|
||||
warnings.warn(
|
||||
"MPS no esta disponible en este sistema. Usando 'cpu'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "cpu"
|
||||
|
||||
if preference == "cuda":
|
||||
if _cuda_available():
|
||||
return "cuda:0"
|
||||
warnings.warn(
|
||||
"CUDA no esta disponible en este sistema. Usando 'cpu'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "cpu"
|
||||
|
||||
if preference.startswith("cuda:"):
|
||||
try:
|
||||
device_idx = int(preference.split(":")[1])
|
||||
except (IndexError, ValueError):
|
||||
warnings.warn(
|
||||
f"Formato de device no valido: '{preference}'. Usando 'cpu'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "cpu"
|
||||
|
||||
count = _cuda_device_count()
|
||||
if _cuda_available() and device_idx < count:
|
||||
return preference
|
||||
warnings.warn(
|
||||
f"Device '{preference}' no disponible "
|
||||
f"(cuda_count={count}). Usando 'cpu'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "cpu"
|
||||
|
||||
warnings.warn(
|
||||
f"Preferencia desconocida: '{preference}'. Usando 'cpu'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "cpu"
|
||||
@@ -0,0 +1,73 @@
|
||||
---
|
||||
name: vram_budget
|
||||
kind: function
|
||||
lang: py
|
||||
domain: ml
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def vram_budget(gpu_vram_total_mb: int, model_type: str, quantization: str, n_loras: int = 0, width: int = 1024, height: int = 1024, batch_size: int = 1) -> dict"
|
||||
description: "Estima la VRAM requerida para ejecutar un modelo de generacion de imagen via heuristicas tabuladas por (model_type, quantization). Retorna VRAM estimada, si cabe en la GPU indicada, headroom disponible, y warnings por incompatibilidades (lora+quant) o falta de VRAM. Funcion pura: solo lookup y aritmetica, sin GPU ni runtime."
|
||||
tags: [ml, vram, gpu, budget, stable-diffusion, flux, sdxl, quantization, lora, estimation, pure]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
params:
|
||||
- name: gpu_vram_total_mb
|
||||
desc: "VRAM total de la GPU objetivo en MB. Obtener con gpu_info() o torch.cuda.get_device_properties()."
|
||||
- name: model_type
|
||||
desc: "Tipo de modelo. Valores soportados: sd15, sdxl, flux_dev, flux_schnell, sd3, qwen_image. Combinaciones fuera de la tabla retornan required_mb=-1."
|
||||
- name: quantization
|
||||
desc: "Esquema de cuantizacion. Valores: fp16, q8_0, q4_0 (y variantes q4_k_m, q5_k_m, q6_k). Afecta tanto el tamano base como la compatibilidad con LoRAs."
|
||||
- name: n_loras
|
||||
desc: "Numero de LoRAs a cargar simultaneamente en VRAM. Cada LoRA suma ~300 MB. Con quantization != fp16 se emite warning de incompatibilidad."
|
||||
- name: width
|
||||
desc: "Ancho en pixeles de la imagen a generar. Afecta el overhead de latentes (mayor resolucion = mas VRAM para activaciones)."
|
||||
- name: height
|
||||
desc: "Alto en pixeles de la imagen a generar."
|
||||
- name: batch_size
|
||||
desc: "Numero de imagenes generadas en paralelo. El overhead de latentes escala linealmente con batch_size."
|
||||
output: "dict con: required_mb (int, -1 si combo desconocido), fits (bool, True si cabe en gpu_vram_total_mb), headroom_mb (int, negativo si no cabe, 0 si combo desconocido), warning (str o None con aviso de incompatibilidad lora+quant o deficit de VRAM)."
|
||||
tested: true
|
||||
tests:
|
||||
- "sdxl fp16 cabe en 24gb con headroom positivo"
|
||||
- "flux fp16 no cabe en 8gb warning con deficit"
|
||||
- "lora con quantization incompatible emite warning"
|
||||
- "combo desconocido retorna required minus1 y warning"
|
||||
test_file_path: "python/functions/ml/tests/test_vram_budget.py"
|
||||
file_path: "python/functions/ml/vram_budget.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from ml.vram_budget import vram_budget
|
||||
|
||||
# SDXL fp16 en 24 GB — cabe
|
||||
r = vram_budget(24576, "sdxl", "fp16")
|
||||
# {"required_mb": 6960, "fits": True, "headroom_mb": 17616, "warning": None}
|
||||
|
||||
# Flux dev fp16 en 8 GB — no cabe
|
||||
r = vram_budget(8192, "flux_dev", "fp16")
|
||||
# {"required_mb": 23512, "fits": False, "headroom_mb": -15320, "warning": "needs +15320 MB ..."}
|
||||
|
||||
# Flux dev q4_0 en 8 GB con 1 LoRA — incompatible
|
||||
r = vram_budget(8192, "flux_dev", "q4_0", n_loras=1)
|
||||
# {"required_mb": 7300, "fits": True, "headroom_mb": 892,
|
||||
# "warning": "lora+quantization incompatible — usa fp16 para cargar LoRAs con flux_dev"}
|
||||
|
||||
# Combo desconocido
|
||||
r = vram_budget(24576, "mi_modelo", "q99_k")
|
||||
# {"required_mb": -1, "fits": False, "headroom_mb": 0,
|
||||
# "warning": "unknown model/quant combo: ('mi_modelo', 'q99_k')"}
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
- La tabla `_MODEL_VRAM_MB` es una estimacion inicial; el usuario debe calibrarla con mediciones reales (nvidia-smi durante inference).
|
||||
- El overhead de latentes se calcula como `w*h/64 MB` para SD/SDXL/SD3 y `w*h/32 MB` para modelos Flux (espacio latente con mas canales).
|
||||
- LoRA warning tiene prioridad sobre el warning de no-fits: si hay incompatibilidad lora+quant, ese warning se emite aunque el modelo no quepa.
|
||||
- Para obtener gpu_vram_total_mb en tiempo real usar `gpu_info_py_ml` (impure).
|
||||
- Funcion pura: misma entrada, misma salida. Sin I/O ni dependencias externas.
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Estimador de VRAM requerida para modelos de generacion de imagen."""
|
||||
|
||||
# Base weights por (model_type, quantization) en MB.
|
||||
# Incluye pesos del modelo + overhead tipico del contexto de inferencia.
|
||||
_MODEL_VRAM_MB: dict[tuple[str, str], int] = {
|
||||
("sd15", "fp16"): 2100,
|
||||
("sd15", "q8_0"): 1200,
|
||||
("sd15", "q4_0"): 700,
|
||||
("sdxl", "fp16"): 6800,
|
||||
("sdxl", "q8_0"): 3800,
|
||||
("sdxl", "q4_0"): 2200,
|
||||
("flux_dev", "fp16"): 23000,
|
||||
("flux_dev", "q8_0"): 13000,
|
||||
("flux_dev", "q4_0"): 7000,
|
||||
("flux_schnell", "fp16"): 23000,
|
||||
("flux_schnell", "q8_0"): 12500,
|
||||
("flux_schnell", "q4_0"): 6500,
|
||||
("sd3", "fp16"): 8500,
|
||||
("sd3", "q8_0"): 4800,
|
||||
("sd3", "q4_0"): 2800,
|
||||
("qwen_image", "fp16"): 8000,
|
||||
("qwen_image", "q8_0"): 4500,
|
||||
("qwen_image", "q4_0"): 2600,
|
||||
}
|
||||
|
||||
# MB por LoRA adicional (estimacion conservadora en fp16).
|
||||
_LORA_MB = 300
|
||||
|
||||
# Modelos que requieren overhead de latente mas alto (Flux usa bloques transformer mas grandes).
|
||||
_FLUX_MODELS = {"flux_dev", "flux_schnell"}
|
||||
|
||||
# Quantizaciones que son incompatibles con LoRA en la mayoria de runtimes.
|
||||
_QUANT_LORA_INCOMPATIBLE = {"q8_0", "q4_0", "q4_k_m", "q5_k_m", "q6_k"}
|
||||
|
||||
|
||||
def _latent_overhead_mb(model_type: str, width: int, height: int, batch_size: int) -> int:
|
||||
"""Estima el overhead de VRAM para activaciones y latentes en MB."""
|
||||
pixels = width * height
|
||||
if model_type in _FLUX_MODELS:
|
||||
# Flux usa un espacio latente 16x mas comprimido pero con mas canales.
|
||||
overhead = pixels // 32
|
||||
else:
|
||||
# SD 1.5 / SDXL / SD3: overhead aprox w*h/64 MB.
|
||||
overhead = pixels // 64
|
||||
return overhead * batch_size
|
||||
|
||||
|
||||
def vram_budget(
|
||||
gpu_vram_total_mb: int,
|
||||
model_type: str,
|
||||
quantization: str,
|
||||
n_loras: int = 0,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
batch_size: int = 1,
|
||||
) -> dict:
|
||||
"""Estima la VRAM requerida para ejecutar un modelo de generacion de imagen.
|
||||
|
||||
Usa heuristicas tabuladas por (model_type, quantization) mas overhead de
|
||||
latentes y LoRAs. No requiere GPU ni runtime — solo lookup y aritmetica.
|
||||
|
||||
Args:
|
||||
gpu_vram_total_mb: VRAM total de la GPU en MB.
|
||||
model_type: Tipo de modelo. Valores: sd15, sdxl, flux_dev, flux_schnell, sd3, qwen_image.
|
||||
quantization: Esquema de cuantizacion. Valores: fp16, q8_0, q4_0, etc.
|
||||
n_loras: Numero de LoRAs a cargar simultaneamente (default 0).
|
||||
width: Ancho de la imagen a generar en pixeles (default 1024).
|
||||
height: Alto de la imagen a generar en pixeles (default 1024).
|
||||
batch_size: Numero de imagenes en paralelo (default 1).
|
||||
|
||||
Returns:
|
||||
dict con:
|
||||
- required_mb (int): VRAM estimada necesaria en MB. -1 si combo desconocido.
|
||||
- fits (bool): True si required_mb <= gpu_vram_total_mb.
|
||||
- headroom_mb (int): MB sobrantes (negativo si no cabe). 0 si combo desconocido.
|
||||
- warning (str | None): Aviso sobre incompatibilidades o ajustes necesarios.
|
||||
None si no hay advertencias.
|
||||
"""
|
||||
key = (model_type, quantization)
|
||||
|
||||
if key not in _MODEL_VRAM_MB:
|
||||
return {
|
||||
"required_mb": -1,
|
||||
"fits": False,
|
||||
"headroom_mb": 0,
|
||||
"warning": f"unknown model/quant combo: ({model_type!r}, {quantization!r})",
|
||||
}
|
||||
|
||||
base_mb = _MODEL_VRAM_MB[key]
|
||||
latent_mb = _latent_overhead_mb(model_type, width, height, batch_size)
|
||||
lora_mb = n_loras * _LORA_MB
|
||||
|
||||
required_mb = base_mb + latent_mb + lora_mb
|
||||
fits = required_mb <= gpu_vram_total_mb
|
||||
headroom_mb = gpu_vram_total_mb - required_mb
|
||||
|
||||
warning: str | None = None
|
||||
|
||||
# LoRA + quantization incompatible en la mayoria de runtimes.
|
||||
if n_loras > 0 and quantization in _QUANT_LORA_INCOMPATIBLE:
|
||||
warning = f"lora+quantization incompatible — usa fp16 para cargar LoRAs con {model_type}"
|
||||
elif not fits:
|
||||
deficit = required_mb - gpu_vram_total_mb
|
||||
warning = f"needs +{deficit} MB (required {required_mb} MB, available {gpu_vram_total_mb} MB)"
|
||||
|
||||
return {
|
||||
"required_mb": required_mb,
|
||||
"fits": fits,
|
||||
"headroom_mb": headroom_mb,
|
||||
"warning": warning,
|
||||
}
|
||||
Reference in New Issue
Block a user