feat(eda): funciones de agregación/OLAP para AutomaticEDA (groupby/pivot push-down + selección LLM)

Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION:
- select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile.
- groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo).
- pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar.
- suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-30 15:33:55 +02:00
parent 415154d9a3
commit 96da9e3015
13 changed files with 2146 additions and 0 deletions
+6
View File
@@ -25,6 +25,7 @@ from .describe_numeric import describe_numeric
from .summarize_categorical import summarize_categorical from .summarize_categorical import summarize_categorical
from .infer_semantic_type import infer_semantic_type from .infer_semantic_type import infer_semantic_type
from .column_quality_score import column_quality_score from .column_quality_score import column_quality_score
from .select_groupby_keys import select_groupby_keys
from .render_eda_markdown import render_eda_markdown from .render_eda_markdown import render_eda_markdown
from .detect_distribution_type import detect_distribution_type from .detect_distribution_type import detect_distribution_type
from .spearman_corr import spearman_corr from .spearman_corr import spearman_corr
@@ -36,6 +37,8 @@ from .infer_fk_containment_duckdb import infer_fk_containment_duckdb
from .build_join_graph import build_join_graph from .build_join_graph import build_join_graph
from .association_matrix import association_matrix from .association_matrix import association_matrix
from .correlation_matrix_duckdb import correlation_matrix_duckdb from .correlation_matrix_duckdb import correlation_matrix_duckdb
from .pivot_table_duckdb import pivot_table_duckdb
from .groupby_stats_duckdb import groupby_stats_duckdb
from .pca_explained import pca_explained from .pca_explained import pca_explained
from .kmeans_segments import kmeans_segments from .kmeans_segments import kmeans_segments
from .isolation_forest_outliers import isolation_forest_outliers from .isolation_forest_outliers import isolation_forest_outliers
@@ -82,6 +85,8 @@ __all__ = [
"build_join_graph", "build_join_graph",
"association_matrix", "association_matrix",
"correlation_matrix_duckdb", "correlation_matrix_duckdb",
"pivot_table_duckdb",
"groupby_stats_duckdb",
"pca_explained", "pca_explained",
"kmeans_segments", "kmeans_segments",
"isolation_forest_outliers", "isolation_forest_outliers",
@@ -96,6 +101,7 @@ __all__ = [
"summarize_categorical", "summarize_categorical",
"infer_semantic_type", "infer_semantic_type",
"column_quality_score", "column_quality_score",
"select_groupby_keys",
"render_eda_markdown", "render_eda_markdown",
"detect_distribution_type", "detect_distribution_type",
"pull_gsc_search_analytics", "pull_gsc_search_analytics",
@@ -0,0 +1,87 @@
---
name: groupby_stats_duckdb
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: impure
signature: "def groupby_stats_duckdb(db_path: str, table: str, group_by: str, measures: list, aggs: list = None, top_n: int = 15) -> dict"
description: "Agregaciones GROUP BY con push-down SQL en DuckDB: para cada measure numerica calcula mean/median/std/min/max por grupo (split-apply-combine en el motor), trayendo solo una fila por grupo. Nucleo de un capitulo de agregacion/OLAP de un EDA. count = tamanio del grupo, independiente de measures."
tags: [eda, groupby, aggregation, olap, duckdb, datascience, push-down, split-apply-combine]
uses_functions: [duckdb_query_readonly_py_infra]
uses_types: []
returns: []
returns_optional: false
error_type: "error_go_core"
imports: []
params:
- name: db_path
desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base. Path inexistente -> {status:'error'} sin lanzar."
- name: table
desc: "Nombre de la tabla. Se interpola citado con dobles comillas (soporta nombres con espacios; las comillas internas se escapan)."
- name: group_by
desc: "Columna por la que agrupar. Se interpola citada. Sus valores distintos son las claves de los grupos."
- name: measures
desc: "Lista de columnas numericas a agregar. Lista vacia es valida: cada grupo trae solo su tamanio `n` y `stats` vacio."
- name: aggs
desc: "Lista de agregaciones. None (default) = ['count','mean','median','std','min','max']. Validas: count (tamanio del grupo, va a `n`), mean->avg, median, std->stddev_samp, min, max (estas cinco por measure). Agg desconocido -> error."
- name: top_n
desc: "Maximo de grupos a devolver, ordenados por tamanio de grupo descendente (default 15). Internamente se piden top_n+1 para detectar truncado."
output: "dict. En exito {status:'ok', group_by, measures:[...], aggs:[...], n_groups:int, truncated:bool, groups:[{key:<valor grupo>, n:int, stats:{<measure>:{mean,median,std,min,max}}}], note:str}. Las estadisticas son float o None (p.ej. std de un grupo de 1 fila -> NULL -> None). En error {status:'error', error:str} (no lanza)."
tested: true
tests: ["agrega por grupo con valores conocidos", "db inexistente devuelve error sin lanzar", "measures vacias agrega solo count", "columna con espacio agrupa bien"]
test_file_path: "python/functions/datascience/groupby_stats_duckdb_test.py"
file_path: "python/functions/datascience/groupby_stats_duckdb.py"
---
## Ejemplo
```python
import duckdb
from datascience import groupby_stats_duckdb
# Cargar el titanic en una tabla DuckDB de prueba.
db = "/tmp/titanic.duckdb"
con = duckdb.connect(db)
con.execute(
"CREATE TABLE titanic AS "
"SELECT * FROM read_csv_auto('https://raw.githubusercontent.com/"
"datasciencedojo/datasets/master/titanic.csv')"
)
con.close()
# Agrupar por sexo midiendo edad y tarifa.
res = groupby_stats_duckdb(db, "titanic", "Sex", ["Age", "Fare"])
print(res["status"]) # ok
print(res["n_groups"]) # 2 (male, female)
for g in res["groups"]:
print(g["key"], g["n"], round(g["stats"]["Fare"]["mean"], 2))
# female 314 44.48
# male 577 25.52
```
## Cuando usarla
Cuando en un EDA necesitas el clasico split-apply-combine: "para cada categoria de X,
¿cuanto vale en media/mediana/desviacion/min/max la metrica Y?". Es el nucleo de un
capitulo de agregacion/OLAP. Usala antes de pintar barras o boxplots por grupo, para
detectar segmentos con comportamiento distinto, o para resumir una tabla grande sin
traer las filas a RAM: todo el GROUP BY ocurre push-down en el motor de DuckDB y solo
viaja una fila por grupo. `top_n` te deja quedarte con los grupos mas poblados.
## Gotchas
- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica). La
tabla debe existir ya en el `.db` (no carga CSV; para eso crea la tabla antes).
- Identificadores (tabla, group_by, measures) se interpolan citados con dobles comillas
y escapando las internas: soporta nombres con espacios y evita inyeccion. No pases
expresiones SQL como group_by/measure — solo nombres de columna.
- `count` es el tamanio del grupo (`COUNT(*)`), independiente de las measures: se
refleja en el campo `n` de cada grupo, NO como clave dentro de `stats`. Las claves de
`stats[measure]` son las measure-aggs efectivas (mean/median/std/min/max menos count).
- `std` usa `stddev_samp` (muestral, n-1): un grupo con una sola fila da `NULL` -> `None`.
Las measures pueden contener NULLs; cada agregada los ignora segun la semantica de DuckDB.
- `truncated:True` indica que habia mas grupos que `top_n` (se devolvieron los `top_n`
mayores por tamanio). Sube `top_n` si necesitas todos los grupos.
- Si `measures` esta vacio, cada grupo trae solo `n` y `stats == {}` (valido, util para
un simple conteo por categoria).
@@ -0,0 +1,184 @@
"""groupby_stats_duckdb — agregaciones GROUP BY con push-down SQL en DuckDB.
Funcion impura: lee de disco a traves de DuckDB (via la primitiva read-only
`duckdb_query_readonly` del grupo `duckdb`). Pertenece al grupo de capacidad `eda`.
Ejecuta un `GROUP BY <group_by>` en el motor de DuckDB (split-apply-combine con
push-down) calculando, para cada columna numerica de `measures`, las agregaciones
pedidas (mean/median/std/min/max). Solo trae al cliente una fila por grupo, nunca
las filas crudas: apto para tablas grandes. Es el nucleo de un capitulo de
agregacion/OLAP de un EDA.
Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y
devuelve {status:'error', error:str}.
"""
from infra import duckdb_query_readonly
# Mapeo agg -> funcion agregada SQL de DuckDB. `count` se trata aparte: es
# COUNT(*) (tamanio del grupo), independiente de las measures.
_AGG_SQL = {
"mean": "avg",
"median": "median",
"std": "stddev_samp",
"min": "min",
"max": "max",
}
# Aggs por defecto cuando aggs=None. count primero (tamanio del grupo) + las
# cinco estadisticas por measure.
_DEFAULT_AGGS = ["count", "mean", "median", "std", "min", "max"]
def _quote_ident(ident: str) -> str:
"""Cita un identificador SQL con dobles comillas, escapando las internas.
Soporta nombres con espacios o caracteres especiales y evita inyeccion: dentro
de un identificador entrecomillado el unico caracter peligroso es la propia
comilla doble, que se duplica ("") segun el estandar SQL. DuckDB no admite
parametros posicionales para nombres de tabla/columna, asi que esta es la via
segura de interpolarlos.
"""
return '"' + str(ident).replace('"', '""') + '"'
def groupby_stats_duckdb(
db_path: str,
table: str,
group_by: str,
measures: list,
aggs: list = None,
top_n: int = 15,
) -> dict:
"""GROUP BY con agregaciones por measure, todo push-down en DuckDB.
Args:
db_path: ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la
base. Un path inexistente devuelve {status:'error', ...} sin lanzar.
table: nombre de la tabla. Se interpola citado con dobles comillas (soporta
nombres con espacios).
group_by: columna por la que agrupar. Se interpola citada.
measures: lista de columnas numericas a agregar. Lista vacia es valida:
cada grupo trae solo su tamanio `n` y `stats` vacio.
aggs: lista de agregaciones a calcular. None (default) =
["count", "mean", "median", "std", "min", "max"]. Valores validos:
count (tamanio del grupo, va a `n`), mean, median, std, min, max
(estas cinco se calculan por cada measure). Un agg desconocido devuelve
error.
top_n: numero maximo de grupos a devolver, ordenados por tamanio de grupo
descendente (default 15). Se pide top_n+1 internamente para detectar si
habia mas grupos y marcar `truncated`.
Returns:
dict. En exito:
{status:'ok',
group_by:str,
measures:[...],
aggs:[...], # las efectivas (incluye count si se pidio)
n_groups:int, # nº de grupos devueltos (<= top_n)
truncated:bool, # True si habia mas de top_n grupos
groups:[{key:<valor grupo>, n:int,
stats:{<measure>:{mean,median,std,min,max}}}, ...],
note:str}
Las estadisticas son float o None (p.ej. stddev_samp de un grupo de una
sola fila -> NULL -> None). En error (sin lanzar): {status:'error', error:str}.
"""
try:
# 1. Validar entradas.
if not isinstance(table, str) or table == "":
return {"status": "error", "error": "table must be a non-empty string"}
if not isinstance(group_by, str) or group_by == "":
return {"status": "error", "error": "group_by must be a non-empty string"}
if measures is None:
measures = []
if not isinstance(measures, list):
return {"status": "error", "error": "measures must be a list"}
for m in measures:
if not isinstance(m, str) or m == "":
return {
"status": "error",
"error": f"invalid measure identifier: {m!r}",
}
if aggs is None:
aggs = list(_DEFAULT_AGGS)
if not isinstance(aggs, list) or len(aggs) == 0:
return {
"status": "error",
"error": "aggs must be a non-empty list or None",
}
for a in aggs:
if a != "count" and a not in _AGG_SQL:
return {
"status": "error",
"error": f"unknown agg {a!r}; valid: count, "
+ ", ".join(_AGG_SQL),
}
if not isinstance(top_n, int) or isinstance(top_n, bool) or top_n < 1:
return {"status": "error", "error": "top_n must be a positive int"}
# 2. Aggs por measure = todas menos count (count es el tamanio del grupo,
# se mapea siempre a la columna `n`).
measure_aggs = [a for a in aggs if a != "count"]
# 3. Construir el SELECT. grp y n primero; luego un termino por measure x agg
# con alias posicional (m{idx}_{agg}) para no chocar con nombres de columna
# que lleven espacios o caracteres raros.
select_terms = [f"{_quote_ident(group_by)} AS grp", "COUNT(*) AS n"]
agg_index = [] # (measure_name, agg_name, alias)
for mi, m in enumerate(measures):
for a in measure_aggs:
alias = f"m{mi}_{a}"
fn = _AGG_SQL[a]
select_terms.append(f"{fn}({_quote_ident(m)}) AS {alias}")
agg_index.append((m, a, alias))
# Pedimos top_n+1 grupos para detectar truncado (habia mas que top_n).
sql = (
f"SELECT {', '.join(select_terms)} "
f"FROM {_quote_ident(table)} "
f"GROUP BY {_quote_ident(group_by)} "
f"ORDER BY n DESC "
f"LIMIT {top_n + 1}"
)
# 4. Ejecutar push-down. sandbox=True (default) basta: la tabla ya existe en
# el .db, no necesitamos read_csv/read_blob ni acceso al filesystem.
result = duckdb_query_readonly(db_path, sql, max_rows=top_n + 1)
if result.get("status") != "ok":
return {
"status": "error",
"error": "groupby query failed: "
+ str(result.get("error", "unknown")),
}
rows = result.get("rows", [])
truncated = len(rows) > top_n
if truncated:
rows = rows[:top_n]
# 5. Reconstruir la estructura por grupo.
groups = []
for row in rows:
stats = {m: {} for m in measures}
for (m, a, alias) in agg_index:
stats[m][a] = row.get(alias)
groups.append(
{"key": row.get("grp"), "n": row.get("n"), "stats": stats}
)
return {
"status": "ok",
"group_by": group_by,
"measures": list(measures),
"aggs": list(aggs),
"n_groups": len(groups),
"truncated": truncated,
"groups": groups,
"note": f"GROUP BY {group_by}: top {len(groups)} grupos por tamanio sobre "
f"{len(measures)} measure(s)",
}
except Exception as e: # noqa: BLE001
return {"status": "error", "error": str(e)}
@@ -0,0 +1,106 @@
"""Tests para groupby_stats_duckdb."""
import os
import sys
import duckdb
# Permitir importar funciones del registry (from infra import ..., from datascience import ...).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions"))
from datascience.groupby_stats_duckdb import groupby_stats_duckdb
def _make_db(tmp_path, rows):
"""Crea una DuckDB con tabla t(g VARCHAR, x DOUBLE) e inserta `rows`."""
db = os.path.join(str(tmp_path), "t.duckdb")
con = duckdb.connect(db)
con.execute("CREATE TABLE t(g VARCHAR, x DOUBLE)")
con.executemany("INSERT INTO t VALUES (?, ?)", rows)
con.close()
return db
def test_agrega_por_grupo_con_valores_conocidos(tmp_path):
# Grupo a: [10, 20, 30] -> n=3, mean=20, min=10, max=30, median=20, std=10.
# Grupo b: [5, 15] -> n=2, mean=10, median=10.
# Grupo c: [100] -> n=1, mean=100, std=None (1 sola fila).
rows = [
("a", 10.0), ("a", 20.0), ("a", 30.0),
("b", 5.0), ("b", 15.0),
("c", 100.0),
]
db = _make_db(tmp_path, rows)
res = groupby_stats_duckdb(db, "t", "g", ["x"])
assert res["status"] == "ok", res
assert res["n_groups"] == 3
assert res["truncated"] is False
assert res["aggs"] == ["count", "mean", "median", "std", "min", "max"]
by_key = {g["key"]: g for g in res["groups"]}
assert set(by_key) == {"a", "b", "c"}
# Grupo a: comprobacion manual de mean/min/max/median/std.
sa = by_key["a"]["stats"]["x"]
assert by_key["a"]["n"] == 3
assert abs(sa["mean"] - 20.0) < 1e-9
assert abs(sa["min"] - 10.0) < 1e-9
assert abs(sa["max"] - 30.0) < 1e-9
assert abs(sa["median"] - 20.0) < 1e-9
assert "std" in sa and sa["std"] is not None
assert abs(sa["std"] - 10.0) < 1e-9 # stddev_samp([10,20,30]) = 10
# Grupo b: mean y median conocidas.
sb = by_key["b"]["stats"]["x"]
assert by_key["b"]["n"] == 2
assert abs(sb["mean"] - 10.0) < 1e-9
assert abs(sb["median"] - 10.0) < 1e-9
assert "median" in sb and "std" in sb
# Grupo c: una sola fila -> std None (stddev_samp NULL), mean/min/max definidos.
sc = by_key["c"]["stats"]["x"]
assert by_key["c"]["n"] == 1
assert abs(sc["mean"] - 100.0) < 1e-9
assert sc["std"] is None
def test_db_inexistente_devuelve_error_sin_lanzar(tmp_path):
db = os.path.join(str(tmp_path), "no_existe.duckdb")
res = groupby_stats_duckdb(db, "t", "g", ["x"])
assert res["status"] == "error", res
assert isinstance(res["error"], str) and res["error"]
def test_measures_vacias_agrega_solo_count(tmp_path):
rows = [("a", 1.0), ("a", 2.0), ("b", 3.0)]
db = _make_db(tmp_path, rows)
res = groupby_stats_duckdb(db, "t", "g", [])
assert res["status"] == "ok", res
by_key = {g["key"]: g for g in res["groups"]}
assert by_key["a"]["n"] == 2
assert by_key["b"]["n"] == 1
# Sin measures, stats por grupo es un dict vacio (valido).
assert by_key["a"]["stats"] == {}
assert by_key["b"]["stats"] == {}
def test_columna_con_espacio_agrupa_bien(tmp_path):
# Tabla con nombres de columna con espacios -> prueba el quoting con dobles
# comillas tanto en group_by como en la measure.
db = os.path.join(str(tmp_path), "space.duckdb")
con = duckdb.connect(db)
con.execute('CREATE TABLE t("my col" VARCHAR, "the val" DOUBLE)')
con.executemany(
'INSERT INTO t VALUES (?, ?)',
[("x", 1.0), ("x", 3.0), ("y", 10.0)],
)
con.close()
res = groupby_stats_duckdb(db, "t", "my col", ["the val"])
assert res["status"] == "ok", res
by_key = {g["key"]: g for g in res["groups"]}
assert by_key["x"]["n"] == 2
assert abs(by_key["x"]["stats"]["the val"]["mean"] - 2.0) < 1e-9
assert by_key["y"]["n"] == 1
assert abs(by_key["y"]["stats"]["the val"]["mean"] - 10.0) < 1e-9
@@ -0,0 +1,92 @@
---
name: pivot_table_duckdb
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: impure
signature: "def pivot_table_duckdb(db_path: str, table: str, index: str, columns: str, value: str, agg: str = 'mean', top_rows: int = 10, top_cols: int = 8) -> dict"
description: "Pivot table (index x columns -> agg(value)) calculada con push-down SQL en DuckDB (GROUP BY en el motor, sin traer filas a RAM) y recortada a las top_rows filas y top_cols columnas con mas observaciones para que quepa entera en un PDF movil / slide PPTX sin cortarse. Version push-down para tablas grandes de la funcion pura `pivot` (que pivota list[dict] en memoria)."
tags: [eda, pivot, duckdb, aggregate, datascience, push-down, report]
uses_functions: [duckdb_query_readonly_py_infra]
uses_types: []
returns: []
returns_optional: false
error_type: "error_go_core"
imports: []
params:
- name: db_path
desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base."
- name: table
desc: "Nombre de la tabla a pivotar. Se interpola citado con dobles comillas (DuckDB no admite parametros para identificadores)."
- name: index
desc: "Columna cuyos valores forman las filas de la pivot (eje vertical)."
- name: columns
desc: "Columna cuyos valores forman las columnas de la pivot (eje horizontal)."
- name: value
desc: "Columna numerica a agregar en cada celda. Ignorada cuando agg='count'."
- name: agg
desc: "Funcion de agregacion: mean, sum, count, min, max, median. mean->avg(), count->COUNT(*). Otro valor devuelve {status:'error'}."
- name: top_rows
desc: "Numero maximo de filas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de index). Default 10."
- name: top_cols
desc: "Numero maximo de columnas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de columns). Default 8."
output: "dict. En exito {status:'ok', index, columns, value, agg, row_labels:[...], col_labels:[...], matrix:[[...]], truncated_rows:bool, truncated_cols:bool, note:str}. matrix tiene len(row_labels) filas y cada fila len(col_labels) celdas (valor agregado o None si la combinacion no existe). truncated_* indica si hubo mas filas/columnas que el top. En error {status:'error', error:str} (no lanza)."
tested: true
tests: ["pivot mean labels y celda conocida", "pivot trunca a top rows y top cols", "pivot count no necesita value real", "pivot db inexistente devuelve error sin lanzar", "pivot agg invalido devuelve error"]
test_file_path: "python/functions/datascience/pivot_table_duckdb_test.py"
file_path: "python/functions/datascience/pivot_table_duckdb.py"
---
## Ejemplo
```python
import duckdb
from datascience import pivot_table_duckdb
# Tabla DuckDB de prueba estilo titanic: sex x pclass -> mean(fare).
db = "/tmp/pivot_demo.duckdb"
con = duckdb.connect(db)
con.execute(
"CREATE TABLE titanic AS SELECT * FROM (VALUES "
"('male',1,211.3),('female',1,151.5),('male',3,7.9),"
"('female',3,16.7),('male',1,52.0),('female',2,41.6)"
") t(sex, pclass, fare)"
)
con.close()
res = pivot_table_duckdb(db, "titanic", index="sex", columns="pclass", value="fare", agg="mean")
print(res["status"]) # ok
print(res["row_labels"]) # ['female', 'male'] (orden por nº de observaciones desc; empate -> etiqueta)
print(res["col_labels"]) # [1, 3, 2] (pclass=1 tiene 3 obs, pclass=3 -> 2, pclass=2 -> 1)
print(res["matrix"]) # [[151.5, 16.7, 41.6], [131.65, 7.9, None]] (male/pclass=2 no existe -> None)
```
## Cuando usarla
Cuando quieres una pivot table (`index` x `columns` -> `agg(value)`) de una tabla
DuckDB con MUCHAS filas y necesitas que el resultado quepa entero en un informe: un
PDF abierto en el movil o un slide PPTX, donde una matriz de 50x30 se cortaria. La
agregacion se hace push-down en el motor (no traes las filas a RAM) y el resultado se
limita a las `top_rows` x `top_cols` combinaciones con mas observaciones. Encaja en el
flujo `eda` para resumir el cruce de dos categoricas (sexo x clase, region x producto)
contra una metrica. Para pivotar un `list[dict]` ya cargado en memoria usa la funcion
pura `pivot_py_datascience`; esta es la version push-down sobre disco.
## Gotchas
- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica).
- Recorta a `top_rows` x `top_cols` por numero de observaciones (suma de `COUNT(*)`),
NO por magnitud del valor agregado. Si habia mas filas/columnas, `truncated_rows` /
`truncated_cols` quedan en True y esas combinaciones NO aparecen en la matriz.
- Las celdas sin datos (combinacion `index` x `columns` que no existe en la tabla) se
rellenan con `None`, no con 0: distinguir "cero medido" de "sin observaciones".
- `agg='count'` cuenta filas por celda con `COUNT(*)` e ignora `value` (puedes pasar
cualquier nombre de columna). Para el resto de aggs, `value` debe ser una columna
numerica real o la query fallara con `{status:'error'}`.
- `agg` solo admite mean, sum, count, min, max, median; cualquier otro valor devuelve
`{status:'error'}` sin tocar la base.
- Orden de `row_labels` / `col_labels`: por numero de observaciones descendente, con
desempate estable por etiqueta. No es orden alfabetico ni el de aparicion.
- La query se ejecuta con `sandbox=False` en `duckdb_query_readonly` (uso interno
confiable: el SQL lo construye esta funcion, no un cliente externo).
@@ -0,0 +1,176 @@
"""pivot_table_duckdb — pivot table (index x columns -> agg(value)) con push-down SQL.
Funcion impura: lee de disco a traves de DuckDB reusando la primitiva read-only del
grupo `duckdb` (`duckdb_query_readonly`). Pertenece al grupo de capacidad `eda`
(exploratory data analysis).
A diferencia de la funcion pura `pivot` (que pivota un `list[dict]` ya cargado en
memoria), esta version empuja la agregacion al motor de DuckDB (push-down): el
GROUP BY lo resuelve DuckDB y solo se traen los valores agregados, nunca las filas
crudas. Esto la hace apta para tablas grandes.
Ademas reduce el resultado a las `top_rows` filas y `top_cols` columnas con mas
observaciones, de modo que la pivot quepa entera en un PDF movil / slide PPTX sin
cortarse. Marca `truncated_rows`/`truncated_cols` cuando hubo recorte.
Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y
devuelve {status:'error', error:str}.
"""
from collections import defaultdict
from infra import duckdb_query_readonly
# Funciones de agregacion permitidas y su nombre en SQL DuckDB.
# mean -> avg; el resto mapea directo. count se trata aparte (COUNT(*), sin value).
_AGG_SQL = {
"mean": "avg",
"sum": "sum",
"count": "count",
"min": "min",
"max": "max",
"median": "median",
}
def _quote_ident(ident: str) -> str:
"""Cita un identificador SQL con dobles comillas, escapando `"` -> `""`.
DuckDB no admite parametros posicionales para nombres de tabla/columna, asi que
hay que interpolarlos. El quoting con `"` y el doblado de comillas internas evita
que un nombre rompa la sentencia (mismo patron que correlation_matrix_duckdb).
"""
return '"' + str(ident).replace('"', '""') + '"'
def pivot_table_duckdb(
db_path: str,
table: str,
index: str,
columns: str,
value: str,
agg: str = "mean",
top_rows: int = 10,
top_cols: int = 8,
) -> dict:
"""Pivot table push-down en DuckDB, recortada a top_rows x top_cols.
Construye una pivot (filas = valores de `index`, columnas = valores de `columns`,
celda = `agg(value)`) agregando en el motor de DuckDB, y la reduce a las filas y
columnas con mas observaciones para que quepa en un PDF / slide.
Args:
db_path: ruta al archivo DuckDB. Debe existir (read_only NO crea la base).
table: nombre de la tabla a pivotar.
index: columna cuyos valores forman las filas de la pivot.
columns: columna cuyos valores forman las columnas de la pivot.
value: columna numerica a agregar. Ignorada cuando agg="count".
agg: funcion de agregacion. Una de: "mean", "sum", "count", "min", "max",
"median". mean se traduce a avg(); count a COUNT(*).
top_rows: numero maximo de filas a conservar, elegidas por mayor numero de
observaciones (suma de COUNT(*) por valor de index). Default 10.
top_cols: numero maximo de columnas a conservar, elegidas por mayor numero de
observaciones (suma de COUNT(*) por valor de columns). Default 8.
Returns:
dict. En exito:
{status:'ok',
index, columns, value, agg,
row_labels:[...], # valores de index, en orden de freq desc
col_labels:[...], # valores de columns, en orden de freq desc
matrix:[[...], ...], # len == len(row_labels); cada fila
# len == len(col_labels); celda = agg o None
truncated_rows:bool, truncated_cols:bool,
note:str}
En error (sin lanzar): {status:'error', error:str}.
"""
try:
if not isinstance(agg, str) or agg not in _AGG_SQL:
return {
"status": "error",
"error": "invalid agg "
+ repr(agg)
+ "; allowed: "
+ ", ".join(sorted(_AGG_SQL)),
}
# Paso 1 (push-down): agregar (index, columns) -> agg(value) + COUNT(*).
if agg == "count":
agg_expr = "COUNT(*)"
else:
agg_expr = f"{_AGG_SQL[agg]}({_quote_ident(value)})"
sql = (
f"SELECT {_quote_ident(index)} AS r, "
f"{_quote_ident(columns)} AS c, "
f"{agg_expr} AS v, "
f"COUNT(*) AS n "
f"FROM {_quote_ident(table)} "
f"GROUP BY {_quote_ident(index)}, {_quote_ident(columns)}"
)
# max_rows alto: queremos todos los grupos (index x columns) para elegir el
# top con criterio global. sandbox=False igual que correlation_matrix_duckdb,
# porque db_path es una ruta interna de confianza.
result = duckdb_query_readonly(
db_path, sql, max_rows=1_000_000, sandbox=False
)
if result.get("status") != "ok":
return {
"status": "error",
"error": "pivot query failed: "
+ str(result.get("error", "unknown")),
}
# Paso 2 (en python): contar observaciones por fila y por columna, y guardar
# el valor agregado de cada celda (r, c).
row_obs: dict = defaultdict(int)
col_obs: dict = defaultdict(int)
cell: dict = {}
for row in result.get("rows", []):
r = row.get("r")
c = row.get("c")
n = row.get("n") or 0
row_obs[r] += n
col_obs[c] += n
cell[(r, c)] = row.get("v")
def _top(obs: dict, limit: int):
# Orden: mas observaciones primero; desempate estable por etiqueta.
ranked = sorted(obs.items(), key=lambda kv: (-kv[1], str(kv[0])))
selected = [label for label, _ in ranked[:limit]]
return selected, len(ranked) > limit
row_labels, truncated_rows = _top(row_obs, top_rows)
col_labels, truncated_cols = _top(col_obs, top_cols)
# Paso 3: materializar la matriz; None donde la combinacion no existe.
matrix = [
[cell.get((r, c)) for c in col_labels] for r in row_labels
]
note = (
f"pivot {agg}({value}) reducida a {len(row_labels)}x{len(col_labels)} "
"(top por observaciones) para caber en PDF/slide"
)
if agg == "count":
note = (
f"pivot count(*) reducida a {len(row_labels)}x{len(col_labels)} "
"(top por observaciones) para caber en PDF/slide"
)
return {
"status": "ok",
"index": index,
"columns": columns,
"value": value,
"agg": agg,
"row_labels": row_labels,
"col_labels": col_labels,
"matrix": matrix,
"truncated_rows": truncated_rows,
"truncated_cols": truncated_cols,
"note": note,
}
except Exception as e: # noqa: BLE001
return {"status": "error", "error": str(e)}
@@ -0,0 +1,115 @@
"""Tests para pivot_table_duckdb."""
import os
import sys
import duckdb
# Permitir importar funciones del registry (from infra import ..., from datascience import ...).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions"))
from datascience.pivot_table_duckdb import pivot_table_duckdb
def _make_db(tmp_name: str) -> str:
"""Crea una DuckDB con dos categoricas (a, b) y un valor numerico conocido.
Filas:
a='x', b='y', val=10
a='x', b='y', val=20 -> mean(x,y) = 15, count(x,y) = 2
a='x', b='z', val=5 -> mean(x,z) = 5
a='w', b='y', val=100 -> mean(w,y) = 100
Observaciones por a: x=3, w=1. Por b: y=3, z=1.
La combinacion (w, z) no existe -> celda None.
"""
db = os.path.join("/tmp", tmp_name)
if os.path.exists(db):
os.remove(db)
con = duckdb.connect(db)
con.execute("CREATE TABLE t (a VARCHAR, b VARCHAR, val DOUBLE)")
con.execute(
"INSERT INTO t VALUES "
"('x','y',10),('x','y',20),('x','z',5),('w','y',100)"
)
con.close()
return db
def test_pivot_mean_labels_y_celda_conocida():
db = _make_db("pivot_test_mean.duckdb")
res = pivot_table_duckdb(db, "t", index="a", columns="b", value="val", agg="mean")
assert res["status"] == "ok", res
# Filas ordenadas por observaciones desc: x (3) antes que w (1).
assert res["row_labels"] == ["x", "w"], res["row_labels"]
# Columnas ordenadas por observaciones desc: y (3) antes que z (1).
assert res["col_labels"] == ["y", "z"], res["col_labels"]
# matrix[0][0] = mean(a='x', b='y') = (10 + 20) / 2 = 15.
assert abs(res["matrix"][0][0] - 15.0) < 1e-9, res["matrix"]
# matrix[0][1] = mean(a='x', b='z') = 5.
assert abs(res["matrix"][0][1] - 5.0) < 1e-9, res["matrix"]
# matrix[1][0] = mean(a='w', b='y') = 100.
assert abs(res["matrix"][1][0] - 100.0) < 1e-9, res["matrix"]
# (w, z) no existe -> None.
assert res["matrix"][1][1] is None, res["matrix"]
# Sin truncado con los defaults (top_rows=10, top_cols=8).
assert res["truncated_rows"] is False
assert res["truncated_cols"] is False
# La matriz es rectangular consistente con las etiquetas.
assert len(res["matrix"]) == len(res["row_labels"])
for fila in res["matrix"]:
assert len(fila) == len(res["col_labels"])
def test_pivot_trunca_a_top_rows_y_top_cols():
db = _make_db("pivot_test_trunc.duckdb")
res = pivot_table_duckdb(
db, "t", index="a", columns="b", value="val", agg="mean",
top_rows=1, top_cols=1,
)
assert res["status"] == "ok", res
# Solo la fila/columna mas frecuente sobrevive.
assert res["row_labels"] == ["x"], res["row_labels"]
assert res["col_labels"] == ["y"], res["col_labels"]
assert res["matrix"] == [[15.0]], res["matrix"]
# Habia mas de 1 fila y mas de 1 columna -> truncado en ambos ejes.
assert res["truncated_rows"] is True
assert res["truncated_cols"] is True
def test_pivot_count_no_necesita_value_real():
db = _make_db("pivot_test_count.duckdb")
# value apunta a una columna real pero count(*) la ignora; tambien valdria un
# nombre cualquiera. Verificamos que count funciona igualmente.
res = pivot_table_duckdb(
db, "t", index="a", columns="b", value="val", agg="count"
)
assert res["status"] == "ok", res
assert res["row_labels"] == ["x", "w"]
assert res["col_labels"] == ["y", "z"]
# count(a='x', b='y') = 2 observaciones.
assert res["matrix"][0][0] == 2, res["matrix"]
# count(a='x', b='z') = 1.
assert res["matrix"][0][1] == 1, res["matrix"]
# count(a='w', b='y') = 1.
assert res["matrix"][1][0] == 1, res["matrix"]
# (w, z) no existe -> None.
assert res["matrix"][1][1] is None, res["matrix"]
def test_pivot_db_inexistente_devuelve_error_sin_lanzar():
res = pivot_table_duckdb(
"/nonexistent/path/does_not_exist.duckdb",
"t", index="a", columns="b", value="val", agg="mean",
)
assert res["status"] == "error", res
assert isinstance(res["error"], str)
def test_pivot_agg_invalido_devuelve_error():
db = _make_db("pivot_test_badagg.duckdb")
res = pivot_table_duckdb(
db, "t", index="a", columns="b", value="val", agg="stddev"
)
assert res["status"] == "error", res
assert "invalid agg" in res["error"]
@@ -0,0 +1,158 @@
---
id: select_groupby_keys_py_datascience
name: select_groupby_keys
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: pure
signature: "def select_groupby_keys(profile: dict, max_keys: int = 3, max_card: int = 20, max_measures: int = 4) -> dict"
description: "Elige deterministicamente las columnas categoricas mas interesantes para GROUP BY, las numericas medida y pares pivote a partir de un TableProfile del grupo eda. Respaldo cuantitativo para el capitulo de agregacion/OLAP de un EDA. Funcion pura, no muta el input, nunca lanza."
tags: [eda, aggregation, groupby, olap, profiling, datascience]
uses_functions: []
uses_types: []
returns: []
returns_optional: false
error_type: ""
imports: []
example: |
from datascience import select_groupby_keys
profile = {
"n_rows": 891,
"key_candidates": ["passenger_id"],
"columns": [
{"name": "sex", "inferred_type": "categorical", "distinct_count": 2,
"unique_pct": 0.002, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 1.8}, "numeric": None},
{"name": "pclass", "inferred_type": "categorical", "distinct_count": 3,
"unique_pct": 0.003, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 2.5}, "numeric": None},
{"name": "fare", "inferred_type": "numeric", "distinct_count": 200,
"unique_pct": 0.2, "null_pct": 0.0, "flags": [],
"numeric": {"std": 49.7, "cv": 1.54}, "categorical": None},
],
}
select_groupby_keys(profile)
# {"group_keys": [{"col": "sex", ...}, {"col": "pclass", ...}],
# "measures": ["fare"],
# "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}],
# "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s)."}
tested: true
tests:
- "test_titanic_picks_good_cats_excludes_id_and_constant"
- "test_titanic_measures_exclude_id_constant_and_keep_numerics"
- "test_titanic_generates_one_pivot"
- "test_empty_profile_returns_all_empty_and_does_not_crash"
- "test_none_profile_does_not_crash"
- "test_only_numerics_yields_empty_group_keys_and_no_pivots"
- "test_high_cardinality_and_max_card_are_excluded"
- "test_max_keys_limits_group_keys"
- "test_three_keys_cap_pivots_to_two"
- "test_does_not_mutate_input"
test_file_path: "python/functions/datascience/select_groupby_keys_test.py"
file_path: "python/functions/datascience/select_groupby_keys.py"
params:
- name: profile
desc: >
TableProfile dict del grupo eda (p.ej. salida de summarize_table_duckdb).
Se lee de forma defensiva (.get / or [] / isinstance). Claves usadas:
columns (list[ColumnProfile]), key_candidates (list de nombres de columna
o dicts {name}), n_rows. Cada ColumnProfile usa: name, inferred_type
("numeric"|"categorical"|"datetime"|"text"|"boolean"), distinct_count,
unique_pct (0..1), null_pct (0..1), flags (list[str], reconoce
"possible_id"/"high_cardinality"/"constant"), numeric ({std, cv, ...}|None)
y categorical ({imbalance, mode_pct, ...}|None).
- name: max_keys
desc: "Numero maximo de claves de grupo (group_keys) a devolver. Default 3."
- name: max_card
desc: >
Cardinalidad maxima (distinct_count) que una columna categorica puede
tener para seguir siendo candidata a clave de grupo. Default 20.
- name: max_measures
desc: "Numero maximo de columnas medida (nombres) a devolver. Default 4."
output: >
dict con group_keys (list de {col, cardinality, score} ordenada por score
desc), measures (list[str] de nombres de columnas numericas ordenadas por
dispersion), pivots (list de {index, columns, value}, hasta 2 pares
categorica x categorica con la primera measure como valor) y note (str,
resumen corto en espanol de lo elegido). Ante profile vacio/None devuelve
todas las listas vacias y una note descriptiva; nunca lanza.
---
## Ejemplo
```python
from datascience import select_groupby_keys
# TableProfile estilo titanic: 2 categoricas buenas, 1 numerica medida,
# 1 id secuencial (descartado) y un key_candidate declarado.
profile = {
"n_rows": 891,
"key_candidates": ["passenger_id"],
"columns": [
{"name": "sex", "inferred_type": "categorical", "distinct_count": 2,
"unique_pct": 0.002, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 1.8}, "numeric": None},
{"name": "pclass", "inferred_type": "categorical", "distinct_count": 3,
"unique_pct": 0.003, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 2.5}, "numeric": None},
{"name": "fare", "inferred_type": "numeric", "distinct_count": 200,
"unique_pct": 0.2, "null_pct": 0.0, "flags": [],
"numeric": {"std": 49.7, "cv": 1.54}, "categorical": None},
{"name": "passenger_id", "inferred_type": "numeric", "distinct_count": 891,
"unique_pct": 1.0, "null_pct": 0.0, "flags": ["possible_id"],
"numeric": {"std": 257.4, "cv": 0.58}, "categorical": None},
],
}
select_groupby_keys(profile)
# {
# "group_keys": [
# {"col": "sex", "cardinality": 2, "score": 0.5556},
# {"col": "pclass", "cardinality": 3, "score": 0.4},
# ],
# "measures": ["fare"], # passenger_id excluido (id secuencial)
# "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}],
# "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s).",
# }
```
## Cuando usarla
Cuando hayas perfilado una tabla con el grupo `eda` (p.ej.
`summarize_table_duckdb`) y necesites decidir, sin mirar los datos, por qué
columnas merece la pena agrupar (GROUP BY) y qué métricas numéricas agregar:
el respaldo cuantitativo del capítulo de agregación/OLAP de un AutomaticEDA, o
para proponer pivotes en un dashboard. Es la capa de selección sobre el
TableProfile crudo: lee el perfil, ordena candidatos de forma determinista y
no toca los datos.
## Notas
Función pura, sin I/O ni dependencias externas (solo stdlib), no muta
`profile`. Lectura defensiva total (`.get`, `or []`, `isinstance`): un `{}` o
`None` produce `{"group_keys": [], "measures": [], "pivots": [], "note": ...}`
y nunca lanza.
Criterios de selección (deterministas):
- **group_keys** — candidatas con `inferred_type` en `("categorical","boolean")`.
Se descartan las que estén en `key_candidates`, con flag
`possible_id`/`high_cardinality`/`constant`, con `distinct_count` fuera de
`[2, max_card]`, o all-null (`null_pct >= 0.999`). `score = card_score *
balance_score`: `card_score` mantiene un plateau para cardinalidad moderada
(2..12) y decae hacia `max_card`; `balance_score = 1/imbalance` usando
`categorical.imbalance` si está, aproximando con `mode_pct` si no, o un valor
neutro (0.5) en último caso. Devuelve hasta `max_keys`, ordenadas por score
desc (empates por orden de columna).
- **measures** — candidatas con `inferred_type` en
`("numeric","integer","float")`. Se descartan id-like (flag `possible_id` y
`unique_pct >= 0.99`) y constantes (`numeric.std` == 0 o None). Se rankean por
dispersión informativa: `abs(cv)` si está, si no `abs(std)`. Devuelve hasta
`max_measures` **nombres** (strings).
- **pivots** — hasta 2 pares `(group_keys[i].col, group_keys[j].col)` con i<j y
la primera measure como valor. Vacío si hay menos de 2 group_keys.
Caveat de ranking de measures: mezclar `cv` (adimensional) con `std` (en
unidades de la columna) cuando una columna carece de `cv` puede dar órdenes
poco comparables entre columnas; se prefiere `cv` siempre que esté disponible.
@@ -0,0 +1,310 @@
"""Pure EDA helper: pick GROUP BY keys and measures from a TableProfile.
Given a ``TableProfile`` of the ``eda`` group (the dict produced by, e.g.,
``summarize_table_duckdb``), this function deterministically selects the most
interesting categorical columns to group by (GROUP BY), the numeric measure
columns to aggregate, and a couple of categorical x categorical pivot pairs.
It is the quantitative backbone for the aggregation / OLAP chapter of an
AutomaticEDA: a pure, deterministic ranking over the profile, with no I/O, no
mutation of the input and no external dependencies (stdlib only). It never
raises — a missing or malformed profile yields an empty, well-formed result.
"""
def select_groupby_keys(
profile: dict,
max_keys: int = 3,
max_card: int = 20,
max_measures: int = 4,
) -> dict:
"""Select GROUP BY keys, measures and pivot pairs from a TableProfile.
Reads everything defensively (``.get(...)``, ``or []``, ``isinstance``) and
never raises. With an empty/None profile it returns every list empty.
Selection rules (deterministic):
- **group_keys** (categorical columns to group by): candidates have
``inferred_type`` in ``("categorical", "boolean")``. Discarded if they are
in ``profile['key_candidates']``, carry a ``possible_id`` /
``high_cardinality`` / ``constant`` flag, have ``distinct_count`` outside
``[2, max_card]``, or are all-null (``null_pct >= 0.999``). Each survivor
gets ``score = card_score * balance_score`` where ``card_score`` keeps a
plateau for moderate cardinality (2..12) and decays towards ``max_card``,
and ``balance_score = 1 / imbalance`` (``categorical.imbalance`` if
present, else approximated from ``mode_pct``, else a neutral default).
The top ``max_keys`` by score (desc, ties by column order) are returned.
- **measures** (numeric columns to aggregate): candidates have
``inferred_type`` in ``("numeric", "integer", "float")``. Discarded if
id-like (``possible_id`` flag *and* ``unique_pct >= 0.99``) or constant
(``numeric.std`` is ``0`` or ``None``). Ranked by informative dispersion:
``abs(cv)`` when available, else ``abs(std)``. The top ``max_measures``
**names** are returned.
- **pivots**: up to 2 ``(group_keys[i].col, group_keys[j].col)`` pairs with
``i < j``, using the first measure as the aggregated value. Empty when
fewer than 2 group keys were selected.
Args:
profile: TableProfile dict of the ``eda`` group. Relevant keys:
``columns`` (list[ColumnProfile]), ``key_candidates`` (list of
column names or ``{name}`` dicts), ``n_rows``. Each ColumnProfile
uses: ``name``, ``inferred_type``, ``distinct_count``,
``unique_pct`` (0..1), ``null_pct`` (0..1), ``flags`` (list[str]),
``numeric`` ({std, cv, ...}|None), ``categorical``
({imbalance, mode_pct, ...}|None).
max_keys: Maximum number of group-by keys to return. Default 3.
max_card: Maximum cardinality (``distinct_count``) a categorical column
may have to still qualify as a group key. Default 20.
max_measures: Maximum number of measure names to return. Default 4.
Returns:
dict with:
group_keys (list[{col, cardinality, score}], ordered by score desc),
measures (list[str], numeric column names ordered by dispersion),
pivots (list[{index, columns, value}], up to 2 pairs),
note (str, short summary of what was chosen).
"""
if not isinstance(profile, dict):
profile = {}
try:
max_keys = int(max_keys)
except (TypeError, ValueError):
max_keys = 3
try:
max_card = int(max_card)
except (TypeError, ValueError):
max_card = 20
try:
max_measures = int(max_measures)
except (TypeError, ValueError):
max_measures = 4
max_keys = max(max_keys, 0)
max_card = max(max_card, 2)
max_measures = max(max_measures, 0)
columns = profile.get("columns") or []
if not isinstance(columns, (list, tuple)):
columns = []
key_names = _key_candidate_names(profile.get("key_candidates"))
group_keys = _select_group_keys(columns, key_names, max_keys, max_card)
measures = _select_measures(columns, max_measures)
pivots = _select_pivots(group_keys, measures)
return {
"group_keys": group_keys,
"measures": measures,
"pivots": pivots,
"note": _build_note(group_keys, measures, pivots),
}
# ---------------------------------------------------------------------------
# group_keys
# ---------------------------------------------------------------------------
_GROUP_TYPES = ("categorical", "boolean")
_DISQUALIFYING_FLAGS = frozenset({"possible_id", "high_cardinality", "constant"})
_CARD_PLATEAU_HI = 12 # cardinalities 2..12 are all "moderate" (best).
def _select_group_keys(columns, key_names, max_keys, max_card) -> list:
"""Rank categorical/boolean columns suitable for GROUP BY."""
scored = []
for idx, col in enumerate(columns):
if not isinstance(col, dict):
continue
if (col.get("inferred_type") or "") not in _GROUP_TYPES:
continue
name = col.get("name")
if name is None:
continue
if name in key_names:
continue
flags = _as_set(col.get("flags"))
if flags & _DISQUALIFYING_FLAGS:
continue
if _num(col.get("null_pct"), 0.0) >= 0.999:
continue
card = _num(col.get("distinct_count"), 0.0)
if card < 2 or card > max_card:
continue
card_i = int(card)
score = _card_score(card_i, max_card) * _balance_score(col.get("categorical"))
scored.append((round(score, 6), idx, name, card_i))
# Deterministic: higher score first, ties broken by original column order.
scored.sort(key=lambda t: (-t[0], t[1]))
out = []
for score, _idx, name, card_i in scored[:max_keys]:
out.append({"col": name, "cardinality": card_i, "score": score})
return out
def _card_score(card: int, max_card: int) -> float:
"""Prefer moderate cardinality; plateau at 2..12, decay towards max_card."""
if card <= 1:
return 0.0
if card <= _CARD_PLATEAU_HI:
return 1.0
denom = max(max_card - _CARD_PLATEAU_HI, 1)
over = card - _CARD_PLATEAU_HI
return max(0.1, 1.0 - over / denom)
def _balance_score(categorical) -> float:
"""1.0 for a perfectly balanced category, decaying as imbalance grows.
Uses ``categorical.imbalance`` (max_count/min_count, >= 1) when available;
otherwise approximates from ``mode_pct`` (top-class dominance); otherwise a
neutral default so the column is still selectable.
"""
if isinstance(categorical, dict):
imbalance = categorical.get("imbalance")
if isinstance(imbalance, (int, float)) and imbalance >= 1.0:
return 1.0 / float(imbalance)
mode_pct = categorical.get("mode_pct")
if isinstance(mode_pct, (int, float)):
return _clamp(1.0 - float(mode_pct), 0.0, 1.0)
return 0.5
# ---------------------------------------------------------------------------
# measures
# ---------------------------------------------------------------------------
_NUMERIC_TYPES = ("numeric", "integer", "float")
def _select_measures(columns, max_measures) -> list:
"""Rank numeric columns by informative dispersion (cv, else std)."""
scored = []
for idx, col in enumerate(columns):
if not isinstance(col, dict):
continue
if (col.get("inferred_type") or "") not in _NUMERIC_TYPES:
continue
name = col.get("name")
if name is None:
continue
flags = _as_set(col.get("flags"))
unique_pct = _num(col.get("unique_pct"), 0.0)
if "possible_id" in flags and unique_pct >= 0.99:
continue # sequential id, not a measure.
numeric = col.get("numeric")
std = numeric.get("std") if isinstance(numeric, dict) else None
if not isinstance(std, (int, float)) or std == 0:
continue # constant or unknown spread -> not informative.
cv = numeric.get("cv") if isinstance(numeric, dict) else None
if isinstance(cv, (int, float)):
dispersion = abs(float(cv))
else:
dispersion = abs(float(std))
scored.append((dispersion, idx, name))
# Higher dispersion first, ties broken by original column order.
scored.sort(key=lambda t: (-t[0], t[1]))
return [name for _disp, _idx, name in scored[:max_measures]]
# ---------------------------------------------------------------------------
# pivots
# ---------------------------------------------------------------------------
def _select_pivots(group_keys, measures) -> list:
"""Up to 2 (cat_a, cat_b) pairs from the chosen group keys."""
if not isinstance(group_keys, list) or len(group_keys) < 2:
return []
value = measures[0] if measures else None
pairs = []
n = len(group_keys)
for i in range(n):
for j in range(i + 1, n):
pairs.append({
"index": group_keys[i].get("col"),
"columns": group_keys[j].get("col"),
"value": value,
})
if len(pairs) >= 2:
return pairs
return pairs
# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------
def _build_note(group_keys, measures, pivots) -> str:
"""One-line Spanish summary of the selection."""
parts = []
if group_keys:
cols = ", ".join(str(g.get("col")) for g in group_keys)
parts.append(f"{len(group_keys)} clave(s) de grupo: {cols}")
else:
parts.append("sin categóricas agrupables")
if measures:
parts.append(f"{len(measures)} medida(s): " + ", ".join(str(m) for m in measures))
else:
parts.append("sin medidas numéricas")
if pivots:
parts.append(f"{len(pivots)} pivot(s)")
return "; ".join(parts) + "."
def _key_candidate_names(key_candidates) -> set:
"""Normalize ``key_candidates`` (strings or ``{name}`` dicts) to a name set."""
names = set()
if not isinstance(key_candidates, (list, tuple)):
return names
for entry in key_candidates:
if isinstance(entry, str):
names.add(entry)
elif isinstance(entry, dict):
nm = entry.get("name") or entry.get("col")
if nm is not None:
names.add(nm)
return names
def _as_set(flags) -> set:
"""Coerce a flags value into a set, tolerating None / non-iterables."""
if isinstance(flags, (list, tuple, set)):
return set(flags)
return set()
def _num(value, default: float) -> float:
"""Best-effort float conversion with a fallback default."""
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def _clamp(x: float, lo: float, hi: float) -> float:
"""Recorta x al rango [lo, hi]."""
if x < lo:
return lo
if x > hi:
return hi
return x
@@ -0,0 +1,213 @@
"""Tests para select_groupby_keys (grupo eda, dominio datascience)."""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from select_groupby_keys import select_groupby_keys
def _cat_col(name, card, *, imbalance=2.0, flags=None, null_pct=0.0):
"""ColumnProfile categorico minimo con bloque categorical."""
return {
"name": name,
"inferred_type": "categorical",
"distinct_count": card,
"unique_pct": card / 1000.0,
"null_pct": null_pct,
"flags": flags or [],
"numeric": None,
"categorical": {"imbalance": imbalance, "mode_pct": 0.5, "n_distinct": card},
}
def _num_col(name, *, std, cv, flags=None, unique_pct=0.1):
"""ColumnProfile numerico minimo con bloque numeric."""
return {
"name": name,
"inferred_type": "numeric",
"distinct_count": 200,
"unique_pct": unique_pct,
"null_pct": 0.0,
"flags": flags or [],
"numeric": {"std": std, "cv": cv},
"categorical": None,
}
def _titanic_like_profile() -> dict:
"""Perfil estilo titanic: 2 categoricas buenas, 2 numericas, 1 id, 1 constante."""
return {
"n_rows": 891,
"key_candidates": ["passenger_id"],
"columns": [
_cat_col("sex", 2, imbalance=1.8),
_cat_col("pclass", 3, imbalance=2.5),
_num_col("age", std=14.5, cv=0.49),
_num_col("fare", std=49.7, cv=1.54),
# id secuencial: flag possible_id + unique_pct alto.
{
"name": "passenger_id",
"inferred_type": "numeric",
"distinct_count": 891,
"unique_pct": 1.0,
"null_pct": 0.0,
"flags": ["possible_id"],
"numeric": {"std": 257.4, "cv": 0.58},
"categorical": None,
},
# columna constante: flag constant + std 0.
{
"name": "embarked_const",
"inferred_type": "categorical",
"distinct_count": 1,
"unique_pct": 0.001,
"null_pct": 0.0,
"flags": ["constant"],
"numeric": None,
"categorical": {"imbalance": 1.0},
},
],
}
def test_titanic_picks_good_cats_excludes_id_and_constant():
out = select_groupby_keys(_titanic_like_profile())
# Elige las dos categoricas buenas.
chosen_cols = {g["col"] for g in out["group_keys"]}
assert chosen_cols == {"sex", "pclass"}
# Excluye la constante y el key_candidate.
assert "embarked_const" not in chosen_cols
assert "passenger_id" not in chosen_cols
# Cada group key trae col, cardinality y score.
for g in out["group_keys"]:
assert set(g.keys()) == {"col", "cardinality", "score"}
assert isinstance(g["score"], float)
by_col = {g["col"]: g for g in out["group_keys"]}
assert by_col["sex"]["cardinality"] == 2
assert by_col["pclass"]["cardinality"] == 3
# Ordenadas por score descendente.
scores = [g["score"] for g in out["group_keys"]]
assert scores == sorted(scores, reverse=True)
def test_titanic_measures_exclude_id_constant_and_keep_numerics():
out = select_groupby_keys(_titanic_like_profile())
# Solo nombres (strings) de numericas informativas, sin el id secuencial.
assert all(isinstance(m, str) for m in out["measures"])
assert "passenger_id" not in out["measures"]
assert set(out["measures"]) == {"age", "fare"}
# fare tiene mayor cv (1.54 > 0.49) -> primero.
assert out["measures"][0] == "fare"
def test_titanic_generates_one_pivot():
out = select_groupby_keys(_titanic_like_profile())
# Con 2 group keys -> exactamente 1 pivot.
assert len(out["pivots"]) == 1
pivot = out["pivots"][0]
assert set(pivot.keys()) == {"index", "columns", "value"}
assert {pivot["index"], pivot["columns"]} == {"sex", "pclass"}
# El valor es la primera measure (fare).
assert pivot["value"] == "fare"
def test_empty_profile_returns_all_empty_and_does_not_crash():
out = select_groupby_keys({})
assert out["group_keys"] == []
assert out["measures"] == []
assert out["pivots"] == []
assert isinstance(out["note"], str)
def test_none_profile_does_not_crash():
out = select_groupby_keys(None)
assert out == {
"group_keys": [],
"measures": [],
"pivots": [],
"note": out["note"],
}
assert isinstance(out["note"], str)
def test_only_numerics_yields_empty_group_keys_and_no_pivots():
profile = {
"n_rows": 500,
"key_candidates": [],
"columns": [
_num_col("price", std=12.0, cv=0.6),
_num_col("weight", std=3.0, cv=0.2),
],
}
out = select_groupby_keys(profile)
assert out["group_keys"] == []
assert out["pivots"] == []
# Las numericas si se eligen como measures.
assert set(out["measures"]) == {"price", "weight"}
assert out["measures"][0] == "price" # mayor cv.
def test_high_cardinality_and_max_card_are_excluded():
profile = {
"n_rows": 1000,
"key_candidates": [],
"columns": [
_cat_col("city", 50, flags=["high_cardinality"]), # flag -> fuera.
_cat_col("zone", 35), # card 35 > max_card 20 -> fuera.
_cat_col("region", 5), # valida.
],
}
out = select_groupby_keys(profile, max_card=20)
assert {g["col"] for g in out["group_keys"]} == {"region"}
def test_max_keys_limits_group_keys():
profile = {
"n_rows": 1000,
"key_candidates": [],
"columns": [
_cat_col("a", 4, imbalance=1.0),
_cat_col("b", 5, imbalance=1.2),
_cat_col("c", 6, imbalance=1.5),
_cat_col("d", 7, imbalance=2.0),
],
}
out = select_groupby_keys(profile, max_keys=2)
assert len(out["group_keys"]) == 2
# Hasta 2 pivots con >=2 keys (aqui exactamente 1 par posible entre 2 keys).
assert len(out["pivots"]) == 1
def test_three_keys_cap_pivots_to_two():
profile = {
"n_rows": 1000,
"key_candidates": [],
"columns": [
_cat_col("a", 4, imbalance=1.0),
_cat_col("b", 5, imbalance=1.1),
_cat_col("c", 6, imbalance=1.2),
_num_col("m", std=10.0, cv=0.5),
],
}
out = select_groupby_keys(profile, max_keys=3)
assert len(out["group_keys"]) == 3
# 3 keys -> 3 pares posibles, capado a 2.
assert len(out["pivots"]) == 2
for p in out["pivots"]:
assert p["value"] == "m"
def test_does_not_mutate_input():
profile = _titanic_like_profile()
before = repr(profile)
select_groupby_keys(profile)
assert repr(profile) == before
@@ -0,0 +1,96 @@
---
name: suggest_aggregations_llm
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: impure
signature: "def suggest_aggregations_llm(profile: dict, candidates: dict, max_aggs: int = 4, model: str = \"claude-haiku-4-5-20251001\") -> dict"
description: "MUST-11.1 del capitulo AGREGACION del AutomaticEDA (grupo eda). Dado el TableProfile de una tabla y los candidatos cuantitativos de select_groupby_keys ({group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}), con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una razon corta cada uno, evitando la explosion combinatoria (no todo contra todo). Privacidad/coste: NO envia filas crudas, solo el resumen AGREGADO de los candidatos (tabla, columnas categoricas con cardinalidad/score, medidas, pivots). Reusa ask_llm del grupo claude-direct (API directa con token OAuth de Claude). Impura, dict-no-throw: NUNCA lanza y SIEMPRE devuelve algo usable; si el LLM falla, el JSON no parsea o no hay seleccion valida, cae a un fallback determinista construido desde los candidatos (source='fallback'). Toda columna que el LLM invente se descarta."
tags: [eda, claude-direct, llm, aggregation, groupby, pivot, datascience, automatic-eda]
params:
- name: profile
desc: "TableProfile del grupo eda. Solo se usa profile['table'] para nombrar la tabla en el prompt; puede ir vacio o sin esa clave (se usa '(tabla sin nombre)')."
- name: candidates
desc: "Salida de select_groupby_keys: {group_keys:[{col, cardinality, score}], measures:[str], pivots:[{index, columns, value}]}. group_keys = columnas categoricas candidatas para GROUP BY; measures = columnas numericas a agregar (sum/avg); pivots = cruces index x columns -> value sugeridos. Cualquier columna que el LLM elija debe existir aqui o se descarta. None o no-dict se trata como vacio."
- name: max_aggs
desc: "Tope de agregaciones a devolver. Default 4. Valores <1 o no-int se normalizan a 4. Limita tanto la seleccion del LLM como el fallback determinista, para evitar la explosion combinatoria."
- name: model
desc: "id del modelo Anthropic a usar en la unica llamada. Default 'claude-haiku-4-5-20251001' (haiku, coste bajo, ~2-3s). Para razones mas finas, pasar p.ej. 'claude-opus-4-8'."
output: "dict dict-no-throw: {status:'ok', source:'llm'|'fallback', aggregations:[{group_by:str, measures:[str], why:str}], pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}. source=='llm' si el LLM produjo al menos una agregacion valida (columnas existentes en candidates); en cualquier otro caso (LLM caido, JSON invalido, seleccion vacia, sin candidatos) source=='fallback' y aggregations/pivots se derivan de candidates con why='selección cuantitativa (sin LLM)'. NUNCA lanza."
uses_functions: [ask_llm_py_core, select_groupby_keys_py_datascience]
uses_types: []
returns: []
returns_optional: false
error_type: "error_go_core"
imports: []
tested: true
tests: ["test_extract_json_object", "test_extract_json_wrapped_in_fences_and_junk", "test_extract_json_non_json_returns_none", "test_validate_aggregations_drops_invalid_columns", "test_llm_path_uses_selection", "test_llm_path_respects_max_aggs", "test_llm_invented_column_is_discarded", "test_fallback_on_empty_llm_response", "test_fallback_on_unparseable_response", "test_fallback_respects_max_aggs", "test_fallback_when_llm_raises", "test_no_candidates_returns_empty_fallback", "test_non_dict_candidates_does_not_raise"]
test_file_path: "python/functions/datascience/suggest_aggregations_llm_test.py"
file_path: "python/functions/datascience/suggest_aggregations_llm.py"
---
## Ejemplo
```python
import sys, os
sys.path.insert(0, os.path.join("python", "functions"))
from datascience.suggest_aggregations_llm import suggest_aggregations_llm
profile = {"table": "ventas"}
# candidates = salida de select_groupby_keys (aqui literal de ejemplo).
candidates = {
"group_keys": [
{"col": "categoria", "cardinality": 8, "score": 0.91},
{"col": "region", "cardinality": 5, "score": 0.74},
{"col": "canal", "cardinality": 3, "score": 0.60},
],
"measures": ["importe", "unidades"],
"pivots": [
{"index": "categoria", "columns": "region", "value": "importe"},
],
}
out = suggest_aggregations_llm(profile, candidates, max_aggs=4) # haiku por defecto
print("fuente:", out["source"]) # "llm" o "fallback" si no hay red
for agg in out["aggregations"]:
print(f"GROUP BY {agg['group_by']} -> {agg['measures']} ({agg['why']})")
for piv in out["pivots"]:
print(f"pivot {piv['index']} x {piv['columns']} = {piv['value']} ({piv['why']})")
```
## Cuando usarla
Justo despues de `select_groupby_keys` en el capitulo AGREGACION del AutomaticEDA:
cuando ya tienes los candidatos cuantitativos (columnas categoricas con cardinalidad,
medidas numericas y pivots posibles) y quieres que un LLM se quede con las K
agregaciones y pivots MAS INFORMATIVOS en vez de generar "todo contra todo". Usala para
priorizar el plan de analisis de grupos antes de materializar las tablas con
`aggregate_by_group` / pivots, manteniendo el coste y el ruido bajos. Si no hay red o
credenciales, sigue funcionando con un fallback determinista, asi que es seguro
ponerla en un pipeline.
## Gotchas
- **Impura: hace 1 llamada de red al LLM.** No es determinista ni gratis. Latencia
tipica ~2-3s con haiku. Una sola llamada cubre toda la seleccion.
- **Requiere token OAuth de Claude** en `~/.claude/.credentials.json` (via `ask_llm` /
grupo `claude-direct`). Sin token / sin red NO lanza: cae al **fallback
determinista** (`source="fallback"`) construido desde `candidates`
(group_keys x measures hasta `max_aggs`, pivots tal cual) con
`why="selección cuantitativa (sin LLM)"`. Comprueba `out["source"]` para saber si la
seleccion vino del LLM o del fallback.
- **NO envia filas crudas al LLM**, solo el resumen AGREGADO de los candidatos. Esto
exige que `candidates` venga ya calculado por `select_groupby_keys` (cardinalidades,
scores, medidas, pivots).
- **Valida columnas inventadas**: si el LLM propone un `group_by`/`measure`/`index`/
`columns` que no esta en `candidates`, esa entrada se descarta (las medidas se
recortan a las validas). Si tras validar no queda ninguna agregacion, cae al
fallback completo.
- **`max_aggs` acota la explosion combinatoria** tanto en el camino LLM como en el
fallback. Subirlo demasiado reintroduce el ruido que esta funcion evita.
- **Modelo `haiku` por defecto** para coste bajo; sube a `claude-opus-4-8` si necesitas
razones (`why`) mas finas (mas caro y lento).
@@ -0,0 +1,405 @@
"""suggest_aggregations_llm — el LLM elige las agregaciones mas informativas (grupo `eda`).
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Dado el `TableProfile` de una
tabla y los CANDIDATOS cuantitativos que produce `select_groupby_keys`
(`{group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}`),
con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x
medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una
razon corta cada uno. El objetivo es evitar la explosion combinatoria: en vez de
"todo contra todo", el LLM se queda con lo que mas informa.
Privacidad y coste: NO se envian filas crudas al LLM. El prompt solo lleva el resumen
AGREGADO de los candidatos (nombre de la tabla, columnas categoricas con su
cardinalidad/score, medidas y pivots posibles). Una sola llamada barata.
Reusa `ask_llm` del registry (grupo claude-direct, API directa con el token OAuth de
Claude en ~/.claude/.credentials.json, arranque 0). Impura: una llamada de red.
Estilo dict-no-throw con FALLBACK DETERMINISTA: la funcion NUNCA lanza y SIEMPRE
devuelve algo usable. Si `ask_llm` falla (devuelve ""), el JSON no parsea, o el LLM no
produce ninguna seleccion valida, se construye la respuesta directamente desde los
candidatos (group_keys x measures hasta max_aggs, pivots tal cual) con
`source="fallback"`. Ademas, toda columna que el LLM invente (no presente en los
candidatos) se descarta.
"""
import json
from core.ask_llm import ask_llm
_SYSTEM = (
"Eres un analista de datos conciso. Te dan los CANDIDATOS AGREGADOS de una tabla "
"(columnas categoricas para GROUP BY con su cardinalidad, medidas numericas y "
"pivots posibles) y eliges las agregaciones y pivots MAS INFORMATIVOS para "
"entender los grupos, evitando la explosion combinatoria (no todo contra todo). "
"No recibes filas crudas. Responde en espanol. Responde SIEMPRE y SOLO con un "
"unico objeto JSON valido, sin texto alrededor ni fences de markdown, con la forma "
'{"aggregations": [{"group_by": "<col categorica>", "measures": ["<medida>", ...], '
'"why": "<razon corta>"}], "pivots": [{"index": "<col>", "columns": "<col>", '
'"value": "<medida o null>", "why": "<razon corta>"}]}. Usa SOLO nombres de columna '
"que aparezcan en los candidatos; no inventes nombres."
)
def _fmt_num(value) -> str:
"""Formatea un numero de forma compacta para el prompt (None -> '?')."""
if value is None:
return "?"
if isinstance(value, bool):
return str(value)
if isinstance(value, float):
if value == int(value):
return str(int(value))
return f"{value:.4g}"
return str(value)
def _candidate_view(candidates: dict):
"""Extrae las vistas utiles de los candidatos. Funcion interna PURA.
Devuelve la tupla (group_cols, measures, measure_set, pivots, group_keys):
- group_cols: set de nombres de columna categorica validas (de group_keys[].col).
- measures: lista de medidas numericas (str) preservando orden.
- measure_set: set de las medidas para validar pertenencia rapido.
- pivots: lista de pivots candidatos (dicts) tal cual vienen.
- group_keys: lista de dicts {col, cardinality, score} ya filtrada a entradas validas.
Tolera estructuras incompletas o de tipo incorrecto sin lanzar.
"""
candidates = candidates if isinstance(candidates, dict) else {}
gk_raw = candidates.get("group_keys")
group_keys = []
if isinstance(gk_raw, list):
for gk in gk_raw:
if isinstance(gk, dict) and isinstance(gk.get("col"), str):
group_keys.append(gk)
group_cols = {gk["col"] for gk in group_keys}
m_raw = candidates.get("measures")
measures = [m for m in m_raw if isinstance(m, str)] if isinstance(m_raw, list) else []
measure_set = set(measures)
p_raw = candidates.get("pivots")
pivots = p_raw if isinstance(p_raw, list) else []
return group_cols, measures, measure_set, pivots, group_keys
def _sorted_group_cols(group_keys: list) -> list:
"""Nombres de columna categorica ordenados por score descendente. PURA."""
def _score(gk):
s = gk.get("score")
if isinstance(s, (int, float)) and not isinstance(s, bool):
return s
return 0.0
return [gk["col"] for gk in sorted(group_keys, key=_score, reverse=True)]
def _build_prompt(profile: dict, candidates: dict, max_aggs: int) -> str:
"""Construye el prompt compacto SOLO con agregados. Funcion interna PURA.
No toca red ni disco: testeable sin credenciales. Incluye el nombre de la tabla,
las columnas categoricas candidatas (con cardinalidad y score), las medidas
numericas y los pivots candidatos. Nunca filas crudas.
Args:
profile: TableProfile (se usa solo profile['table'] para nombrar la tabla).
candidates: salida de select_groupby_keys.
max_aggs: tope de agregaciones a pedir.
Returns:
El texto del prompt.
"""
profile = profile if isinstance(profile, dict) else {}
candidates = candidates if isinstance(candidates, dict) else {}
table = profile.get("table")
table = str(table) if table is not None else "(tabla sin nombre)"
lines = [
f"Tabla: {table}",
(
"Tarea: elegir las agregaciones (GROUP BY categorica x medidas numericas) y "
"los pivots MAS INFORMATIVOS para un analisis de grupos. Evita la explosion "
"combinatoria: NO combines todo contra todo, prioriza lo que mas informa."
),
f"Devuelve a lo sumo {max_aggs} agregaciones.",
"",
"Columnas categoricas candidatas para GROUP BY (col: cardinalidad, score):",
]
group_keys = candidates.get("group_keys") or []
for gk in group_keys:
if not isinstance(gk, dict) or not isinstance(gk.get("col"), str):
continue
lines.append(
f" - {gk['col']}: cardinalidad={_fmt_num(gk.get('cardinality'))}, "
f"score={_fmt_num(gk.get('score'))}"
)
measures = candidates.get("measures") or []
lines.append("")
lines.append("Medidas numericas disponibles (para sum/avg por grupo):")
lines.append(" " + ", ".join(str(m) for m in measures if isinstance(m, str)))
pivots = candidates.get("pivots") or []
if pivots:
lines.append("")
lines.append("Pivots candidatos (index x columns -> value):")
for p in pivots:
if not isinstance(p, dict):
continue
lines.append(
f" - index={p.get('index')}, columns={p.get('columns')}, "
f"value={p.get('value')}"
)
lines.append("")
lines.append(
"Usa SOLO columnas de las listas anteriores; no inventes nombres. Responde "
"SOLO con el JSON descrito en las instrucciones del sistema."
)
return "\n".join(lines)
def _extract_json(text: str):
"""Extrae el primer bloque JSON (objeto o array) de la respuesta. PURA.
Localiza el bloque que empieza antes (el primer '{' o el primer '[') y, para ese
delimitador, hace json.loads del rango hasta su ultimo cierre. Tolera texto basura
alrededor y fences ```json. NUNCA lanza: ante cualquier fallo devuelve None.
Args:
text: respuesta cruda del LLM.
Returns:
El objeto/lista deserializado, o None si no se pudo parsear.
"""
if not text or not isinstance(text, str):
return None
opens = []
i_obj = text.find("{")
if i_obj != -1:
opens.append((i_obj, "{", "}"))
i_arr = text.find("[")
if i_arr != -1:
opens.append((i_arr, "[", "]"))
opens.sort()
for _, open_c, close_c in opens:
start = text.find(open_c)
end = text.rfind(close_c)
if start != -1 and end != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except (ValueError, TypeError):
continue
return None
def _validate_aggregations(raw_aggs, group_cols: set, measure_set: set, max_aggs: int) -> list:
"""Filtra las agregaciones del LLM a las que usan SOLO columnas candidatas. PURA.
Descarta cualquier agregacion cuyo group_by no este en group_cols o que no tenga
al menos una medida valida. Recorta las medidas a las presentes en measure_set.
Limita el resultado a max_aggs entradas.
"""
out = []
if not isinstance(raw_aggs, list):
return out
for item in raw_aggs:
if not isinstance(item, dict):
continue
gb = item.get("group_by")
if not isinstance(gb, str) or gb not in group_cols:
continue # columna inventada -> se descarta
raw_measures = item.get("measures")
if isinstance(raw_measures, str):
raw_measures = [raw_measures]
if not isinstance(raw_measures, list):
continue
measures = [m for m in raw_measures if isinstance(m, str) and m in measure_set]
if not measures:
continue # sin medidas validas -> agregacion inutil
why = item.get("why")
why = str(why) if why is not None else ""
out.append({"group_by": gb, "measures": measures, "why": why})
if len(out) >= max_aggs:
break
return out
def _validate_pivots(raw_pivots, group_cols: set, measure_set: set) -> list:
"""Filtra los pivots del LLM a los que usan SOLO columnas candidatas. PURA.
Descarta el pivot si index o columns no son columnas categoricas validas. Si el
value no es una medida valida, lo deja en None (un pivot de conteo sigue siendo util).
"""
out = []
if not isinstance(raw_pivots, list):
return out
for item in raw_pivots:
if not isinstance(item, dict):
continue
idx = item.get("index")
cols = item.get("columns")
if not (isinstance(idx, str) and idx in group_cols):
continue
if not (isinstance(cols, str) and cols in group_cols):
continue
val = item.get("value")
if not (isinstance(val, str) and val in measure_set):
val = None
why = item.get("why")
why = str(why) if why is not None else ""
out.append({"index": idx, "columns": cols, "value": val, "why": why})
return out
def _fallback_aggregations(group_cols_sorted: list, measures: list, max_aggs: int) -> list:
"""Agregaciones deterministas: cada columna categorica x todas las medidas. PURA."""
out = []
for col in group_cols_sorted:
out.append(
{
"group_by": col,
"measures": list(measures),
"why": "selección cuantitativa (sin LLM)",
}
)
if len(out) >= max_aggs:
break
return out
def _fallback_pivots(cand_pivots: list) -> list:
"""Normaliza los pivots candidatos a la forma de salida (tal cual + why). PURA."""
out = []
if not isinstance(cand_pivots, list):
return out
for p in cand_pivots:
if not isinstance(p, dict):
continue
idx = p.get("index")
cols = p.get("columns")
if not (isinstance(idx, str) and isinstance(cols, str)):
continue
val = p.get("value")
if not isinstance(val, str):
val = None
out.append(
{
"index": idx,
"columns": cols,
"value": val,
"why": "selección cuantitativa (sin LLM)",
}
)
return out
def suggest_aggregations_llm(
profile: dict,
candidates: dict,
max_aggs: int = 4,
model: str = "claude-haiku-4-5-20251001",
) -> dict:
"""Elige las agregaciones y pivots mas informativos con UNA llamada al LLM.
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Toma el perfil de la tabla y
los candidatos cuantitativos (salida de select_groupby_keys) y deja que el LLM
seleccione/ordene las K agregaciones (GROUP BY categorica x medidas) y los pivots
mas utiles, con una razon corta cada uno, evitando la explosion combinatoria.
Privacidad/coste: solo viaja al LLM el resumen AGREGADO de los candidatos, nunca
filas crudas. Una sola llamada barata.
dict-no-throw con fallback determinista: NUNCA lanza. Si el LLM falla, el JSON no
parsea, o no produce seleccion valida -> construye la respuesta desde los candidatos
(group_keys x measures hasta max_aggs, pivots tal cual) con source="fallback". Las
columnas que el LLM invente (no presentes en los candidatos) se descartan.
Args:
profile: TableProfile del grupo eda. Solo se usa profile['table'] para nombrar
la tabla en el prompt; puede ir vacio.
candidates: salida de select_groupby_keys, con la forma
{group_keys:[{col,cardinality,score}], measures:[str],
pivots:[{index,columns,value}]}.
max_aggs: tope de agregaciones a devolver. Default 4. Valores <1 o no-int se
normalizan a 4.
model: id del modelo Anthropic. Default 'claude-haiku-4-5-20251001' (haiku,
coste bajo, ~2-3s).
Returns:
dict {status:"ok", source:"llm"|"fallback",
aggregations:[{group_by:str, measures:[str], why:str}],
pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}.
source=="llm" si el LLM produjo al menos una agregacion valida; en cualquier
otro caso "fallback". NUNCA lanza.
"""
if not isinstance(candidates, dict):
candidates = {}
if isinstance(max_aggs, bool) or not isinstance(max_aggs, int) or max_aggs < 1:
max_aggs = 4
group_cols, measures, measure_set, cand_pivots, group_keys = _candidate_view(candidates)
group_cols_sorted = _sorted_group_cols(group_keys)
# Sin material suficiente para agregar: no merece la pena llamar al LLM.
if not group_cols or not measures:
return {
"status": "ok",
"source": "fallback",
"aggregations": [],
"pivots": _fallback_pivots(cand_pivots),
"note": "sin candidatos suficientes para agregar",
}
prompt = _build_prompt(profile, candidates, max_aggs)
try:
text = ask_llm(prompt, model=model, system=_SYSTEM, echo=False)
except Exception: # noqa: BLE001 — degradacion: cualquier fallo de red/LLM.
text = ""
parsed = _extract_json(text)
if parsed is not None:
if isinstance(parsed, dict):
raw_aggs = parsed.get("aggregations")
raw_pivots = parsed.get("pivots")
elif isinstance(parsed, list):
raw_aggs = parsed
raw_pivots = None
else:
raw_aggs = None
raw_pivots = None
aggs = _validate_aggregations(raw_aggs, group_cols, measure_set, max_aggs)
if aggs:
pivots = _validate_pivots(raw_pivots, group_cols, measure_set)
if not pivots:
pivots = _fallback_pivots(cand_pivots)
return {
"status": "ok",
"source": "llm",
"aggregations": aggs,
"pivots": pivots,
"note": f"{len(aggs)} agregaciones y {len(pivots)} pivots seleccionados por el LLM",
}
# Fallback determinista.
note = (
"LLM no disponible; selección cuantitativa determinista"
if not text
else "LLM sin selección válida; selección cuantitativa determinista"
)
return {
"status": "ok",
"source": "fallback",
"aggregations": _fallback_aggregations(group_cols_sorted, measures, max_aggs),
"pivots": _fallback_pivots(cand_pivots),
"note": note,
}
@@ -0,0 +1,198 @@
"""Tests para suggest_aggregations_llm.
NO acceden a red ni a credenciales: las funciones internas (_build_prompt,
_extract_json, _validate_*, _fallback_*) son puras y testeables aisladas; la unica
via que llamaria al LLM (suggest_aggregations_llm) se prueba reemplazando el simbolo
`ask_llm` del modulo bajo prueba con una funcion simulada. Los candidatos van
literales en el test: NO se importa select_groupby_keys.
Cubre golden (LLM ok con columnas validas), edge (max_aggs respetado, sin candidatos)
y error (LLM caido -> fallback, JSON invalido -> fallback, columna inventada -> se
descarta). Todos sin tocar la red.
"""
import json
import datascience.suggest_aggregations_llm as M
from datascience.suggest_aggregations_llm import (
_extract_json,
_validate_aggregations,
suggest_aggregations_llm,
)
# Candidatos de ejemplo con la forma que produce select_groupby_keys (literales).
_CANDIDATES = {
"group_keys": [
{"col": "categoria", "cardinality": 8, "score": 0.91},
{"col": "region", "cardinality": 5, "score": 0.74},
{"col": "canal", "cardinality": 3, "score": 0.60},
],
"measures": ["importe", "unidades"],
"pivots": [
{"index": "categoria", "columns": "region", "value": "importe"},
],
}
_PROFILE = {"table": "ventas"}
def _fake_returner(text):
"""Devuelve un ask_llm simulado que ignora args y retorna `text`."""
def _fake(prompt, model="x", system="", echo=True, **kwargs):
return text
return _fake
# --- _extract_json (parser puro, sin red) ---
def test_extract_json_object():
obj = {"aggregations": [{"group_by": "categoria", "measures": ["importe"], "why": "x"}]}
assert _extract_json(json.dumps(obj)) == obj
def test_extract_json_wrapped_in_fences_and_junk():
obj = {"aggregations": [], "pivots": []}
text = "Claro, aqui tienes:\n```json\n" + json.dumps(obj) + "\n```\nFin."
assert _extract_json(text) == obj
def test_extract_json_non_json_returns_none():
assert _extract_json("no hay json aqui") is None
assert _extract_json("") is None
assert _extract_json(None) is None
# --- _validate_aggregations (puro) ---
def test_validate_aggregations_drops_invalid_columns():
group_cols = {"categoria", "region"}
measure_set = {"importe", "unidades"}
raw = [
{"group_by": "categoria", "measures": ["importe", "inventada"], "why": "ok"},
{"group_by": "no_existe", "measures": ["importe"], "why": "mala"},
{"group_by": "region", "measures": ["solo_inventada"], "why": "sin medidas"},
]
out = _validate_aggregations(raw, group_cols, measure_set, max_aggs=4)
# Solo sobrevive la primera, con las medidas recortadas a las validas.
assert out == [{"group_by": "categoria", "measures": ["importe"], "why": "ok"}]
# --- suggest_aggregations_llm: camino LLM (golden) ---
def test_llm_path_uses_selection(monkeypatch):
llm_obj = {
"aggregations": [
{"group_by": "categoria", "measures": ["importe"], "why": "ventas por familia"},
{"group_by": "region", "measures": ["importe", "unidades"], "why": "reparto geografico"},
],
"pivots": [
{"index": "categoria", "columns": "region", "value": "importe", "why": "cruce clave"},
],
}
monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj)))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["status"] == "ok"
assert out["source"] == "llm"
assert out["aggregations"] == llm_obj["aggregations"]
assert out["pivots"][0]["index"] == "categoria"
assert out["pivots"][0]["why"] == "cruce clave"
def test_llm_path_respects_max_aggs(monkeypatch):
llm_obj = {
"aggregations": [
{"group_by": "categoria", "measures": ["importe"], "why": "a"},
{"group_by": "region", "measures": ["importe"], "why": "b"},
{"group_by": "canal", "measures": ["unidades"], "why": "c"},
],
"pivots": [],
}
monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj)))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2)
assert out["source"] == "llm"
assert len(out["aggregations"]) == 2
def test_llm_invented_column_is_discarded(monkeypatch):
# El LLM mezcla una agregacion valida con otra de columna inexistente.
llm_obj = {
"aggregations": [
{"group_by": "categoria", "measures": ["importe"], "why": "valida"},
{"group_by": "columna_fantasma", "measures": ["importe"], "why": "inventada"},
],
"pivots": [
{"index": "fantasma", "columns": "region", "value": "importe", "why": "mala"},
],
}
monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj)))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["source"] == "llm"
# La agregacion inventada se descarta; queda solo la valida.
assert [a["group_by"] for a in out["aggregations"]] == ["categoria"]
# El pivot con index fantasma se descarta -> cae a los pivots de candidates.
assert all(p["index"] in {"categoria", "region", "canal"} for p in out["pivots"])
# --- suggest_aggregations_llm: fallback determinista (error paths) ---
def test_fallback_on_empty_llm_response(monkeypatch):
monkeypatch.setattr(M, "ask_llm", _fake_returner(""))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=4)
assert out["status"] == "ok"
assert out["source"] == "fallback"
# Las agregaciones se derivan de candidates (una por group_key, con todas las medidas).
assert out["aggregations"][0]["group_by"] in {"categoria", "region", "canal"}
assert out["aggregations"][0]["measures"] == ["importe", "unidades"]
assert out["aggregations"][0]["why"] == "selección cuantitativa (sin LLM)"
# Pivots tal cual de candidates.
assert out["pivots"][0]["index"] == "categoria"
def test_fallback_on_unparseable_response(monkeypatch):
monkeypatch.setattr(M, "ask_llm", _fake_returner("esto no es JSON {roto"))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["source"] == "fallback"
assert len(out["aggregations"]) >= 1
def test_fallback_respects_max_aggs(monkeypatch):
monkeypatch.setattr(M, "ask_llm", _fake_returner(""))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2)
assert out["source"] == "fallback"
assert len(out["aggregations"]) == 2
def test_fallback_when_llm_raises(monkeypatch):
def _boom(*args, **kwargs):
raise RuntimeError("sin red")
monkeypatch.setattr(M, "ask_llm", _boom)
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["source"] == "fallback"
assert out["aggregations"] # no vacio, no lanza
def test_no_candidates_returns_empty_fallback():
# Sin red porque ni siquiera se llama al LLM (no hay material).
out = suggest_aggregations_llm(_PROFILE, {"group_keys": [], "measures": [], "pivots": []})
assert out["status"] == "ok"
assert out["source"] == "fallback"
assert out["aggregations"] == []
def test_non_dict_candidates_does_not_raise():
out = suggest_aggregations_llm(_PROFILE, None)
assert out["status"] == "ok"
assert out["aggregations"] == []