96da9e3015
Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION: - select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile. - groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo). - pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar. - suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
116 lines
4.3 KiB
Python
116 lines
4.3 KiB
Python
"""Tests para pivot_table_duckdb."""
|
|
|
|
import os
|
|
import sys
|
|
|
|
import duckdb
|
|
|
|
# Permitir importar funciones del registry (from infra import ..., from datascience import ...).
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions"))
|
|
|
|
from datascience.pivot_table_duckdb import pivot_table_duckdb
|
|
|
|
|
|
def _make_db(tmp_name: str) -> str:
    """Create a fresh DuckDB file with two categorical columns and one numeric.

    The table ``t`` has columns ``a``, ``b`` (VARCHAR) and ``val`` (DOUBLE)
    and contains exactly these rows:

        a='x', b='y', val=10
        a='x', b='y', val=20   -> mean(x,y) = 15, count(x,y) = 2
        a='x', b='z', val=5    -> mean(x,z) = 5
        a='w', b='y', val=100  -> mean(w,y) = 100

    Observations per ``a``: x=3, w=1. Per ``b``: y=3, z=1.
    The combination (w, z) does not exist, so that pivot cell is expected
    to be ``None``.

    Args:
        tmp_name: File name (not a path) of the database to create inside
            the system temporary directory.

    Returns:
        The full path of the newly created database file.
    """
    # Local import so this helper does not depend on the file-level imports.
    import tempfile

    # Use the platform temp directory instead of a hard-coded "/tmp".
    db = os.path.join(tempfile.gettempdir(), tmp_name)

    # Start from a clean slate. The ".wal" sibling is removed as well because
    # a leftover write-ahead log from an aborted run could presumably be
    # replayed against the new file; removing a missing file is a no-op.
    for stale in (db, db + ".wal"):
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    con = duckdb.connect(db)
    try:
        con.execute("CREATE TABLE t (a VARCHAR, b VARCHAR, val DOUBLE)")
        con.execute(
            "INSERT INTO t VALUES "
            "('x','y',10),('x','y',20),('x','z',5),('w','y',100)"
        )
    finally:
        # Always release the file handle, even when a statement fails, so a
        # later test can reopen or delete the database.
        con.close()
    return db
|
|
|
|
|
|
def test_pivot_mean_labels_y_celda_conocida():
    """Mean pivot with default limits: labels, known cells and matrix shape."""
    db_path = _make_db("pivot_test_mean.duckdb")
    res = pivot_table_duckdb(
        db_path, "t", index="a", columns="b", value="val", agg="mean"
    )
    assert res["status"] == "ok", res

    # Rows are expected in descending observation order: x (3) before w (1).
    assert res["row_labels"] == ["x", "w"], res["row_labels"]
    # Columns are expected in descending observation order: y (3) before z (1).
    assert res["col_labels"] == ["y", "z"], res["col_labels"]

    # (row, col) -> expected mean for every combination present in the data:
    #   (x, y) = (10 + 20) / 2 = 15, (x, z) = 5, (w, y) = 100.
    expected_means = (
        ((0, 0), 15.0),
        ((0, 1), 5.0),
        ((1, 0), 100.0),
    )
    for (row_idx, col_idx), expected in expected_means:
        assert abs(res["matrix"][row_idx][col_idx] - expected) < 1e-9, res["matrix"]

    # (w, z) has no rows in the table -> None.
    assert res["matrix"][1][1] is None, res["matrix"]

    # No truncation is expected with the defaults (top_rows=10, top_cols=8,
    # per the original author's note).
    assert res["truncated_rows"] is False
    assert res["truncated_cols"] is False

    # The matrix must be rectangular and consistent with the labels.
    assert len(res["matrix"]) == len(res["row_labels"])
    assert all(len(row) == len(res["col_labels"]) for row in res["matrix"])
|
|
|
|
|
|
def test_pivot_trunca_a_top_rows_y_top_cols():
    """With top_rows=1 and top_cols=1 only the most frequent row/column stay."""
    limits = {"top_rows": 1, "top_cols": 1}
    res = pivot_table_duckdb(
        _make_db("pivot_test_trunc.duckdb"),
        "t",
        index="a",
        columns="b",
        value="val",
        agg="mean",
        **limits,
    )
    assert res["status"] == "ok", res

    # Only the most frequent row and column survive the cut.
    for axis_key, expected_labels in (("row_labels", ["x"]), ("col_labels", ["y"])):
        assert res[axis_key] == expected_labels, res[axis_key]
    assert res["matrix"] == [[15.0]], res["matrix"]

    # The data had more than one row and more than one column, so both axes
    # must be reported as truncated.
    for flag in ("truncated_rows", "truncated_cols"):
        assert res[flag] is True
|
|
|
|
|
|
def test_pivot_count_no_necesita_value_real():
    """Count pivot: cells hold the number of observations per (a, b) pair.

    Per the original author's note, ``value`` points at a real column here
    but the count aggregation is expected to ignore it (any name should do);
    this test only verifies that counting works with a real column.
    """
    db_path = _make_db("pivot_test_count.duckdb")
    res = pivot_table_duckdb(
        db_path, "t", index="a", columns="b", value="val", agg="count"
    )
    assert res["status"] == "ok", res
    assert res["row_labels"] == ["x", "w"]
    assert res["col_labels"] == ["y", "z"]

    # (row, col) -> expected number of observations:
    #   (x, y) = 2, (x, z) = 1, (w, y) = 1.
    expected_counts = (
        ((0, 0), 2),
        ((0, 1), 1),
        ((1, 0), 1),
    )
    for (row_idx, col_idx), expected in expected_counts:
        assert res["matrix"][row_idx][col_idx] == expected, res["matrix"]

    # (w, z) has no rows in the table -> None.
    assert res["matrix"][1][1] is None, res["matrix"]
|
|
|
|
|
|
def test_pivot_db_inexistente_devuelve_error_sin_lanzar():
    """A missing database file yields an error result instead of raising."""
    missing_db = "/nonexistent/path/does_not_exist.duckdb"
    res = pivot_table_duckdb(
        missing_db, "t", index="a", columns="b", value="val", agg="mean"
    )
    assert res["status"] == "error", res
    # The error must be reported as a plain string.
    assert isinstance(res["error"], str)
|
|
|
|
|
|
def test_pivot_agg_invalido_devuelve_error():
    """An unsupported aggregation name yields an 'invalid agg' error result."""
    db_path = _make_db("pivot_test_badagg.duckdb")
    unsupported_agg = "stddev"
    res = pivot_table_duckdb(
        db_path, "t", index="a", columns="b", value="val", agg=unsupported_agg
    )
    assert res["status"] == "error", res
    assert "invalid agg" in res["error"]
|