From 96da9e3015c3eee89a9ea5fd2dd613d28009e5b2 Mon Sep 17 00:00:00 2001 From: Egutierrez Date: Tue, 30 Jun 2026 15:33:55 +0200 Subject: [PATCH 1/2] =?UTF-8?q?feat(eda):=20funciones=20de=20agregaci?= =?UTF-8?q?=C3=B3n/OLAP=20para=20AutomaticEDA=20(groupby/pivot=20push-down?= =?UTF-8?q?=20+=20selecci=C3=B3n=20LLM)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION: - select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile. - groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo). - pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar. - suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista. Co-Authored-By: Claude Opus 4.8 (1M context) --- python/functions/datascience/__init__.py | 6 + .../datascience/groupby_stats_duckdb.md | 87 ++++ .../datascience/groupby_stats_duckdb.py | 184 ++++++++ .../datascience/groupby_stats_duckdb_test.py | 106 +++++ .../datascience/pivot_table_duckdb.md | 92 ++++ .../datascience/pivot_table_duckdb.py | 176 ++++++++ .../datascience/pivot_table_duckdb_test.py | 115 +++++ .../datascience/select_groupby_keys.md | 158 +++++++ .../datascience/select_groupby_keys.py | 310 ++++++++++++++ .../datascience/select_groupby_keys_test.py | 213 +++++++++ .../datascience/suggest_aggregations_llm.md | 96 +++++ .../datascience/suggest_aggregations_llm.py | 405 ++++++++++++++++++ .../suggest_aggregations_llm_test.py | 198 +++++++++ 13 files changed, 2146 insertions(+) create mode 100644 python/functions/datascience/groupby_stats_duckdb.md create mode 100644 python/functions/datascience/groupby_stats_duckdb.py create mode 100644 python/functions/datascience/groupby_stats_duckdb_test.py create mode 100644 python/functions/datascience/pivot_table_duckdb.md create mode 100644 
python/functions/datascience/pivot_table_duckdb.py create mode 100644 python/functions/datascience/pivot_table_duckdb_test.py create mode 100644 python/functions/datascience/select_groupby_keys.md create mode 100644 python/functions/datascience/select_groupby_keys.py create mode 100644 python/functions/datascience/select_groupby_keys_test.py create mode 100644 python/functions/datascience/suggest_aggregations_llm.md create mode 100644 python/functions/datascience/suggest_aggregations_llm.py create mode 100644 python/functions/datascience/suggest_aggregations_llm_test.py diff --git a/python/functions/datascience/__init__.py b/python/functions/datascience/__init__.py index 9fc8c206..f2746d11 100644 --- a/python/functions/datascience/__init__.py +++ b/python/functions/datascience/__init__.py @@ -25,6 +25,7 @@ from .describe_numeric import describe_numeric from .summarize_categorical import summarize_categorical from .infer_semantic_type import infer_semantic_type from .column_quality_score import column_quality_score +from .select_groupby_keys import select_groupby_keys from .render_eda_markdown import render_eda_markdown from .detect_distribution_type import detect_distribution_type from .spearman_corr import spearman_corr @@ -36,6 +37,8 @@ from .infer_fk_containment_duckdb import infer_fk_containment_duckdb from .build_join_graph import build_join_graph from .association_matrix import association_matrix from .correlation_matrix_duckdb import correlation_matrix_duckdb +from .pivot_table_duckdb import pivot_table_duckdb +from .groupby_stats_duckdb import groupby_stats_duckdb from .pca_explained import pca_explained from .kmeans_segments import kmeans_segments from .isolation_forest_outliers import isolation_forest_outliers @@ -82,6 +85,8 @@ __all__ = [ "build_join_graph", "association_matrix", "correlation_matrix_duckdb", + "pivot_table_duckdb", + "groupby_stats_duckdb", "pca_explained", "kmeans_segments", "isolation_forest_outliers", @@ -96,6 +101,7 @@ __all__ = [ 
"summarize_categorical", "infer_semantic_type", "column_quality_score", + "select_groupby_keys", "render_eda_markdown", "detect_distribution_type", "pull_gsc_search_analytics", diff --git a/python/functions/datascience/groupby_stats_duckdb.md b/python/functions/datascience/groupby_stats_duckdb.md new file mode 100644 index 00000000..faf22222 --- /dev/null +++ b/python/functions/datascience/groupby_stats_duckdb.md @@ -0,0 +1,87 @@ +--- +name: groupby_stats_duckdb +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def groupby_stats_duckdb(db_path: str, table: str, group_by: str, measures: list, aggs: list = None, top_n: int = 15) -> dict" +description: "Agregaciones GROUP BY con push-down SQL en DuckDB: para cada measure numerica calcula mean/median/std/min/max por grupo (split-apply-combine en el motor), trayendo solo una fila por grupo. Nucleo de un capitulo de agregacion/OLAP de un EDA. count = tamanio del grupo, independiente de measures." +tags: [eda, groupby, aggregation, olap, duckdb, datascience, push-down, split-apply-combine] +uses_functions: [duckdb_query_readonly_py_infra] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: db_path + desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base. Path inexistente -> {status:'error'} sin lanzar." + - name: table + desc: "Nombre de la tabla. Se interpola citado con dobles comillas (soporta nombres con espacios; las comillas internas se escapan)." + - name: group_by + desc: "Columna por la que agrupar. Se interpola citada. Sus valores distintos son las claves de los grupos." + - name: measures + desc: "Lista de columnas numericas a agregar. Lista vacia es valida: cada grupo trae solo su tamanio `n` y `stats` vacio." + - name: aggs + desc: "Lista de agregaciones. None (default) = ['count','mean','median','std','min','max']. 
Validas: count (tamanio del grupo, va a `n`), mean->avg, median, std->stddev_samp, min, max (estas cinco por measure). Agg desconocido -> error." + - name: top_n + desc: "Maximo de grupos a devolver, ordenados por tamanio de grupo descendente (default 15). Internamente se piden top_n+1 para detectar truncado." +output: "dict. En exito {status:'ok', group_by, measures:[...], aggs:[...], n_groups:int, truncated:bool, groups:[{key:any, n:int, stats:{measure:{mean,median,std,min,max}}}], note:str}. Las estadisticas son float o None (p.ej. std de un grupo de 1 fila -> NULL -> None). En error {status:'error', error:str} (no lanza)." +tested: true +tests: ["agrega por grupo con valores conocidos", "db inexistente devuelve error sin lanzar", "measures vacias agrega solo count", "columna con espacio agrupa bien"] +test_file_path: "python/functions/datascience/groupby_stats_duckdb_test.py" +file_path: "python/functions/datascience/groupby_stats_duckdb.py" +--- + +## Ejemplo + +```python +import duckdb +from datascience import groupby_stats_duckdb + +# Cargar el titanic en una tabla DuckDB de prueba. +db = "/tmp/titanic.duckdb" +con = duckdb.connect(db) +con.execute( + "CREATE TABLE titanic AS " + "SELECT * FROM read_csv_auto('https://raw.githubusercontent.com/" + "datasciencedojo/datasets/master/titanic.csv')" +) +con.close() + +# Agrupar por sexo midiendo edad y tarifa. +res = groupby_stats_duckdb(db, "titanic", "Sex", ["Age", "Fare"]) +print(res["status"]) # ok +print(res["n_groups"]) # 2 (male, female) +for g in res["groups"]: + print(g["key"], g["n"], round(g["stats"]["Fare"]["mean"], 2)) +# male 577 25.52 +# female 314 44.48 +``` + +## Cuando usarla + +Cuando en un EDA necesitas el clasico split-apply-combine: "para cada categoria de X, +¿cuanto vale en media/mediana/desviacion/min/max la metrica Y?". Es el nucleo de un +capitulo de agregacion/OLAP.
Usala antes de pintar barras o boxplots por grupo, para +detectar segmentos con comportamiento distinto, o para resumir una tabla grande sin +traer las filas a RAM: todo el GROUP BY ocurre push-down en el motor de DuckDB y solo +viaja una fila por grupo. `top_n` te deja quedarte con los grupos mas poblados. + +## Gotchas + +- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica). La + tabla debe existir ya en el `.db` (no carga CSV; para eso crea la tabla antes). +- Identificadores (tabla, group_by, measures) se interpolan citados con dobles comillas + y escapando las internas: soporta nombres con espacios y evita inyeccion. No pases + expresiones SQL como group_by/measure — solo nombres de columna. +- `count` es el tamanio del grupo (`COUNT(*)`), independiente de las measures: se + refleja en el campo `n` de cada grupo, NO como clave dentro de `stats`. Las claves de + `stats[measure]` son las measure-aggs efectivas (mean/median/std/min/max menos count). +- `std` usa `stddev_samp` (muestral, n-1): un grupo con una sola fila da `NULL` -> `None`. + Las measures pueden contener NULLs; cada agregada los ignora segun la semantica de DuckDB. +- `truncated:True` indica que habia mas grupos que `top_n` (se devolvieron los `top_n` + mayores por tamanio). Sube `top_n` si necesitas todos los grupos. +- Si `measures` esta vacio, cada grupo trae solo `n` y `stats == {}` (valido, util para + un simple conteo por categoria). diff --git a/python/functions/datascience/groupby_stats_duckdb.py b/python/functions/datascience/groupby_stats_duckdb.py new file mode 100644 index 00000000..9a2bfd1e --- /dev/null +++ b/python/functions/datascience/groupby_stats_duckdb.py @@ -0,0 +1,184 @@ +"""groupby_stats_duckdb — agregaciones GROUP BY con push-down SQL en DuckDB. + +Funcion impura: lee de disco a traves de DuckDB (via la primitiva read-only +`duckdb_query_readonly` del grupo `duckdb`). Pertenece al grupo de capacidad `eda`. 
+ +Ejecuta un `GROUP BY ` en el motor de DuckDB (split-apply-combine con +push-down) calculando, para cada columna numerica de `measures`, las agregaciones +pedidas (mean/median/std/min/max). Solo trae al cliente una fila por grupo, nunca +las filas crudas: apto para tablas grandes. Es el nucleo de un capitulo de +agregacion/OLAP de un EDA. + +Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y +devuelve {status:'error', error:str}. +""" + +from infra import duckdb_query_readonly + +# Mapeo agg -> funcion agregada SQL de DuckDB. `count` se trata aparte: es +# COUNT(*) (tamanio del grupo), independiente de las measures. +_AGG_SQL = { + "mean": "avg", + "median": "median", + "std": "stddev_samp", + "min": "min", + "max": "max", +} + +# Aggs por defecto cuando aggs=None. count primero (tamanio del grupo) + las +# cinco estadisticas por measure. +_DEFAULT_AGGS = ["count", "mean", "median", "std", "min", "max"] + + +def _quote_ident(ident: str) -> str: + """Cita un identificador SQL con dobles comillas, escapando las internas. + + Soporta nombres con espacios o caracteres especiales y evita inyeccion: dentro + de un identificador entrecomillado el unico caracter peligroso es la propia + comilla doble, que se duplica ("") segun el estandar SQL. DuckDB no admite + parametros posicionales para nombres de tabla/columna, asi que esta es la via + segura de interpolarlos. + """ + return '"' + str(ident).replace('"', '""') + '"' + + +def groupby_stats_duckdb( + db_path: str, + table: str, + group_by: str, + measures: list, + aggs: list = None, + top_n: int = 15, +) -> dict: + """GROUP BY con agregaciones por measure, todo push-down en DuckDB. + + Args: + db_path: ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la + base. Un path inexistente devuelve {status:'error', ...} sin lanzar. + table: nombre de la tabla. Se interpola citado con dobles comillas (soporta + nombres con espacios). + group_by: columna por la que agrupar. 
Se interpola citada. + measures: lista de columnas numericas a agregar. Lista vacia es valida: + cada grupo trae solo su tamanio `n` y `stats` vacio. + aggs: lista de agregaciones a calcular. None (default) = + ["count", "mean", "median", "std", "min", "max"]. Valores validos: + count (tamanio del grupo, va a `n`), mean, median, std, min, max + (estas cinco se calculan por cada measure). Un agg desconocido devuelve + error. + top_n: numero maximo de grupos a devolver, ordenados por tamanio de grupo + descendente (default 15). Se pide top_n+1 internamente para detectar si + habia mas grupos y marcar `truncated`. + + Returns: + dict. En exito: + {status:'ok', + group_by:str, + measures:[...], + aggs:[...], # las efectivas (incluye count si se pidio) + n_groups:int, # nº de grupos devueltos (<= top_n) + truncated:bool, # True si habia mas de top_n grupos + groups:[{key:, n:int, + stats:{:{mean,median,std,min,max}}}, ...], + note:str} + Las estadisticas son float o None (p.ej. stddev_samp de un grupo de una + sola fila -> NULL -> None). En error (sin lanzar): {status:'error', error:str}. + """ + try: + # 1. Validar entradas. 
+ if not isinstance(table, str) or table == "": + return {"status": "error", "error": "table must be a non-empty string"} + if not isinstance(group_by, str) or group_by == "": + return {"status": "error", "error": "group_by must be a non-empty string"} + + if measures is None: + measures = [] + if not isinstance(measures, list): + return {"status": "error", "error": "measures must be a list"} + for m in measures: + if not isinstance(m, str) or m == "": + return { + "status": "error", + "error": f"invalid measure identifier: {m!r}", + } + + if aggs is None: + aggs = list(_DEFAULT_AGGS) + if not isinstance(aggs, list) or len(aggs) == 0: + return { + "status": "error", + "error": "aggs must be a non-empty list or None", + } + for a in aggs: + if a != "count" and a not in _AGG_SQL: + return { + "status": "error", + "error": f"unknown agg {a!r}; valid: count, " + + ", ".join(_AGG_SQL), + } + + if not isinstance(top_n, int) or isinstance(top_n, bool) or top_n < 1: + return {"status": "error", "error": "top_n must be a positive int"} + + # 2. Aggs por measure = todas menos count (count es el tamanio del grupo, + # se mapea siempre a la columna `n`). + measure_aggs = [a for a in aggs if a != "count"] + + # 3. Construir el SELECT. grp y n primero; luego un termino por measure x agg + # con alias posicional (m{idx}_{agg}) para no chocar con nombres de columna + # que lleven espacios o caracteres raros. + select_terms = [f"{_quote_ident(group_by)} AS grp", "COUNT(*) AS n"] + agg_index = [] # (measure_name, agg_name, alias) + for mi, m in enumerate(measures): + for a in measure_aggs: + alias = f"m{mi}_{a}" + fn = _AGG_SQL[a] + select_terms.append(f"{fn}({_quote_ident(m)}) AS {alias}") + agg_index.append((m, a, alias)) + + # Pedimos top_n+1 grupos para detectar truncado (habia mas que top_n). + sql = ( + f"SELECT {', '.join(select_terms)} " + f"FROM {_quote_ident(table)} " + f"GROUP BY {_quote_ident(group_by)} " + f"ORDER BY n DESC " + f"LIMIT {top_n + 1}" + ) + + # 4. 
Ejecutar push-down. sandbox=True (default) basta: la tabla ya existe en + # el .db, no necesitamos read_csv/read_blob ni acceso al filesystem. + result = duckdb_query_readonly(db_path, sql, max_rows=top_n + 1) + if result.get("status") != "ok": + return { + "status": "error", + "error": "groupby query failed: " + + str(result.get("error", "unknown")), + } + + rows = result.get("rows", []) + truncated = len(rows) > top_n + if truncated: + rows = rows[:top_n] + + # 5. Reconstruir la estructura por grupo. + groups = [] + for row in rows: + stats = {m: {} for m in measures} + for (m, a, alias) in agg_index: + stats[m][a] = row.get(alias) + groups.append( + {"key": row.get("grp"), "n": row.get("n"), "stats": stats} + ) + + return { + "status": "ok", + "group_by": group_by, + "measures": list(measures), + "aggs": list(aggs), + "n_groups": len(groups), + "truncated": truncated, + "groups": groups, + "note": f"GROUP BY {group_by}: top {len(groups)} grupos por tamanio sobre " + f"{len(measures)} measure(s)", + } + except Exception as e: # noqa: BLE001 + return {"status": "error", "error": str(e)} diff --git a/python/functions/datascience/groupby_stats_duckdb_test.py b/python/functions/datascience/groupby_stats_duckdb_test.py new file mode 100644 index 00000000..e0857b3d --- /dev/null +++ b/python/functions/datascience/groupby_stats_duckdb_test.py @@ -0,0 +1,106 @@ +"""Tests para groupby_stats_duckdb.""" + +import os +import sys + +import duckdb + +# Permitir importar funciones del registry (from infra import ..., from datascience import ...). 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions")) + +from datascience.groupby_stats_duckdb import groupby_stats_duckdb + + +def _make_db(tmp_path, rows): + """Crea una DuckDB con tabla t(g VARCHAR, x DOUBLE) e inserta `rows`.""" + db = os.path.join(str(tmp_path), "t.duckdb") + con = duckdb.connect(db) + con.execute("CREATE TABLE t(g VARCHAR, x DOUBLE)") + con.executemany("INSERT INTO t VALUES (?, ?)", rows) + con.close() + return db + + +def test_agrega_por_grupo_con_valores_conocidos(tmp_path): + # Grupo a: [10, 20, 30] -> n=3, mean=20, min=10, max=30, median=20, std=10. + # Grupo b: [5, 15] -> n=2, mean=10, median=10. + # Grupo c: [100] -> n=1, mean=100, std=None (1 sola fila). + rows = [ + ("a", 10.0), ("a", 20.0), ("a", 30.0), + ("b", 5.0), ("b", 15.0), + ("c", 100.0), + ] + db = _make_db(tmp_path, rows) + res = groupby_stats_duckdb(db, "t", "g", ["x"]) + assert res["status"] == "ok", res + assert res["n_groups"] == 3 + assert res["truncated"] is False + assert res["aggs"] == ["count", "mean", "median", "std", "min", "max"] + + by_key = {g["key"]: g for g in res["groups"]} + assert set(by_key) == {"a", "b", "c"} + + # Grupo a: comprobacion manual de mean/min/max/median/std. + sa = by_key["a"]["stats"]["x"] + assert by_key["a"]["n"] == 3 + assert abs(sa["mean"] - 20.0) < 1e-9 + assert abs(sa["min"] - 10.0) < 1e-9 + assert abs(sa["max"] - 30.0) < 1e-9 + assert abs(sa["median"] - 20.0) < 1e-9 + assert "std" in sa and sa["std"] is not None + assert abs(sa["std"] - 10.0) < 1e-9 # stddev_samp([10,20,30]) = 10 + + # Grupo b: mean y median conocidas. + sb = by_key["b"]["stats"]["x"] + assert by_key["b"]["n"] == 2 + assert abs(sb["mean"] - 10.0) < 1e-9 + assert abs(sb["median"] - 10.0) < 1e-9 + assert "median" in sb and "std" in sb + + # Grupo c: una sola fila -> std None (stddev_samp NULL), mean/min/max definidos. 
+ sc = by_key["c"]["stats"]["x"] + assert by_key["c"]["n"] == 1 + assert abs(sc["mean"] - 100.0) < 1e-9 + assert sc["std"] is None + + +def test_db_inexistente_devuelve_error_sin_lanzar(tmp_path): + db = os.path.join(str(tmp_path), "no_existe.duckdb") + res = groupby_stats_duckdb(db, "t", "g", ["x"]) + assert res["status"] == "error", res + assert isinstance(res["error"], str) and res["error"] + + +def test_measures_vacias_agrega_solo_count(tmp_path): + rows = [("a", 1.0), ("a", 2.0), ("b", 3.0)] + db = _make_db(tmp_path, rows) + res = groupby_stats_duckdb(db, "t", "g", []) + assert res["status"] == "ok", res + by_key = {g["key"]: g for g in res["groups"]} + assert by_key["a"]["n"] == 2 + assert by_key["b"]["n"] == 1 + # Sin measures, stats por grupo es un dict vacio (valido). + assert by_key["a"]["stats"] == {} + assert by_key["b"]["stats"] == {} + + +def test_columna_con_espacio_agrupa_bien(tmp_path): + # Tabla con nombres de columna con espacios -> prueba el quoting con dobles + # comillas tanto en group_by como en la measure. 
+ db = os.path.join(str(tmp_path), "space.duckdb") + con = duckdb.connect(db) + con.execute('CREATE TABLE t("my col" VARCHAR, "the val" DOUBLE)') + con.executemany( + 'INSERT INTO t VALUES (?, ?)', + [("x", 1.0), ("x", 3.0), ("y", 10.0)], + ) + con.close() + + res = groupby_stats_duckdb(db, "t", "my col", ["the val"]) + assert res["status"] == "ok", res + by_key = {g["key"]: g for g in res["groups"]} + assert by_key["x"]["n"] == 2 + assert abs(by_key["x"]["stats"]["the val"]["mean"] - 2.0) < 1e-9 + assert by_key["y"]["n"] == 1 + assert abs(by_key["y"]["stats"]["the val"]["mean"] - 10.0) < 1e-9 diff --git a/python/functions/datascience/pivot_table_duckdb.md b/python/functions/datascience/pivot_table_duckdb.md new file mode 100644 index 00000000..a14d55eb --- /dev/null +++ b/python/functions/datascience/pivot_table_duckdb.md @@ -0,0 +1,92 @@ +--- +name: pivot_table_duckdb +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def pivot_table_duckdb(db_path: str, table: str, index: str, columns: str, value: str, agg: str = 'mean', top_rows: int = 10, top_cols: int = 8) -> dict" +description: "Pivot table (index x columns -> agg(value)) calculada con push-down SQL en DuckDB (GROUP BY en el motor, sin traer filas a RAM) y recortada a las top_rows filas y top_cols columnas con mas observaciones para que quepa entera en un PDF movil / slide PPTX sin cortarse. Version push-down para tablas grandes de la funcion pura `pivot` (que pivota list[dict] en memoria)." +tags: [eda, pivot, duckdb, aggregate, datascience, push-down, report] +uses_functions: [duckdb_query_readonly_py_infra] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: db_path + desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base." + - name: table + desc: "Nombre de la tabla a pivotar. Se interpola citado con dobles comillas (DuckDB no admite parametros para identificadores)." 
+ - name: index + desc: "Columna cuyos valores forman las filas de la pivot (eje vertical)." + - name: columns + desc: "Columna cuyos valores forman las columnas de la pivot (eje horizontal)." + - name: value + desc: "Columna numerica a agregar en cada celda. Ignorada cuando agg='count'." + - name: agg + desc: "Funcion de agregacion: mean, sum, count, min, max, median. mean->avg(), count->COUNT(*). Otro valor devuelve {status:'error'}." + - name: top_rows + desc: "Numero maximo de filas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de index). Default 10." + - name: top_cols + desc: "Numero maximo de columnas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de columns). Default 8." +output: "dict. En exito {status:'ok', index, columns, value, agg, row_labels:[...], col_labels:[...], matrix:[[...]], truncated_rows:bool, truncated_cols:bool, note:str}. matrix tiene len(row_labels) filas y cada fila len(col_labels) celdas (valor agregado o None si la combinacion no existe). truncated_* indica si hubo mas filas/columnas que el top. En error {status:'error', error:str} (no lanza)." +tested: true +tests: ["pivot mean labels y celda conocida", "pivot trunca a top rows y top cols", "pivot count no necesita value real", "pivot db inexistente devuelve error sin lanzar", "pivot agg invalido devuelve error"] +test_file_path: "python/functions/datascience/pivot_table_duckdb_test.py" +file_path: "python/functions/datascience/pivot_table_duckdb.py" +--- + +## Ejemplo + +```python +import duckdb +from datascience import pivot_table_duckdb + +# Tabla DuckDB de prueba estilo titanic: sex x pclass -> mean(fare). 
+db = "/tmp/pivot_demo.duckdb" +con = duckdb.connect(db) +con.execute( + "CREATE TABLE titanic AS SELECT * FROM (VALUES " + "('male',1,211.3),('female',1,151.5),('male',3,7.9)," + "('female',3,16.7),('male',1,52.0),('female',2,41.6)" + ") t(sex, pclass, fare)" +) +con.close() + +res = pivot_table_duckdb(db, "titanic", index="sex", columns="pclass", value="fare", agg="mean") +print(res["status"]) # ok +print(res["row_labels"]) # ['female', 'male'] (orden por nº de observaciones desc; empate -> etiqueta) +print(res["col_labels"]) # [1, 3, 2] (pclass=1 tiene 3 obs, pclass=3 -> 2, pclass=2 -> 1) +print(res["matrix"]) # [[151.5, 16.7, 41.6], [131.65, 7.9, None]] (male/pclass=2 no existe -> None) +``` + +## Cuando usarla + +Cuando quieres una pivot table (`index` x `columns` -> `agg(value)`) de una tabla +DuckDB con MUCHAS filas y necesitas que el resultado quepa entero en un informe: un +PDF abierto en el movil o un slide PPTX, donde una matriz de 50x30 se cortaria. La +agregacion se hace push-down en el motor (no traes las filas a RAM) y el resultado se +limita a las `top_rows` filas y `top_cols` columnas con mas observaciones. Encaja en el +flujo `eda` para resumir el cruce de dos categoricas (sexo x clase, region x producto) +contra una metrica. Para pivotar un `list[dict]` ya cargado en memoria usa la funcion +pura `pivot_py_datascience`; esta es la version push-down sobre disco. + +## Gotchas + +- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica). +- Recorta a `top_rows` x `top_cols` por numero de observaciones (suma de `COUNT(*)`), + NO por magnitud del valor agregado. Si habia mas filas/columnas, `truncated_rows` / + `truncated_cols` quedan en True y esas combinaciones NO aparecen en la matriz. +- Las celdas sin datos (combinacion `index` x `columns` que no existe en la tabla) se + rellenan con `None`, no con 0: distinguir "cero medido" de "sin observaciones".
+- `agg='count'` cuenta filas por celda con `COUNT(*)` e ignora `value` (puedes pasar + cualquier nombre de columna). Para el resto de aggs, `value` debe ser una columna + numerica real o la query fallara con `{status:'error'}`. +- `agg` solo admite mean, sum, count, min, max, median; cualquier otro valor devuelve + `{status:'error'}` sin tocar la base. +- Orden de `row_labels` / `col_labels`: por numero de observaciones descendente, con + desempate estable por etiqueta. No es orden alfabetico ni el de aparicion. +- La query se ejecuta con `sandbox=False` en `duckdb_query_readonly` (uso interno + confiable: el SQL lo construye esta funcion, no un cliente externo). diff --git a/python/functions/datascience/pivot_table_duckdb.py b/python/functions/datascience/pivot_table_duckdb.py new file mode 100644 index 00000000..cee20464 --- /dev/null +++ b/python/functions/datascience/pivot_table_duckdb.py @@ -0,0 +1,176 @@ +"""pivot_table_duckdb — pivot table (index x columns -> agg(value)) con push-down SQL. + +Funcion impura: lee de disco a traves de DuckDB reusando la primitiva read-only del +grupo `duckdb` (`duckdb_query_readonly`). Pertenece al grupo de capacidad `eda` +(exploratory data analysis). + +A diferencia de la funcion pura `pivot` (que pivota un `list[dict]` ya cargado en +memoria), esta version empuja la agregacion al motor de DuckDB (push-down): el +GROUP BY lo resuelve DuckDB y solo se traen los valores agregados, nunca las filas +crudas. Esto la hace apta para tablas grandes. + +Ademas reduce el resultado a las `top_rows` filas y `top_cols` columnas con mas +observaciones, de modo que la pivot quepa entera en un PDF movil / slide PPTX sin +cortarse. Marca `truncated_rows`/`truncated_cols` cuando hubo recorte. + +Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y +devuelve {status:'error', error:str}. 
+""" + +from collections import defaultdict + +from infra import duckdb_query_readonly + +# Funciones de agregacion permitidas y su nombre en SQL DuckDB. +# mean -> avg; el resto mapea directo. count se trata aparte (COUNT(*), sin value). +_AGG_SQL = { + "mean": "avg", + "sum": "sum", + "count": "count", + "min": "min", + "max": "max", + "median": "median", +} + + +def _quote_ident(ident: str) -> str: + """Cita un identificador SQL con dobles comillas, escapando `"` -> `""`. + + DuckDB no admite parametros posicionales para nombres de tabla/columna, asi que + hay que interpolarlos. El quoting con `"` y el doblado de comillas internas evita + que un nombre rompa la sentencia (mismo patron que correlation_matrix_duckdb). + """ + return '"' + str(ident).replace('"', '""') + '"' + + +def pivot_table_duckdb( + db_path: str, + table: str, + index: str, + columns: str, + value: str, + agg: str = "mean", + top_rows: int = 10, + top_cols: int = 8, +) -> dict: + """Pivot table push-down en DuckDB, recortada a top_rows x top_cols. + + Construye una pivot (filas = valores de `index`, columnas = valores de `columns`, + celda = `agg(value)`) agregando en el motor de DuckDB, y la reduce a las filas y + columnas con mas observaciones para que quepa en un PDF / slide. + + Args: + db_path: ruta al archivo DuckDB. Debe existir (read_only NO crea la base). + table: nombre de la tabla a pivotar. + index: columna cuyos valores forman las filas de la pivot. + columns: columna cuyos valores forman las columnas de la pivot. + value: columna numerica a agregar. Ignorada cuando agg="count". + agg: funcion de agregacion. Una de: "mean", "sum", "count", "min", "max", + "median". mean se traduce a avg(); count a COUNT(*). + top_rows: numero maximo de filas a conservar, elegidas por mayor numero de + observaciones (suma de COUNT(*) por valor de index). Default 10. 
+ top_cols: numero maximo de columnas a conservar, elegidas por mayor numero de + observaciones (suma de COUNT(*) por valor de columns). Default 8. + + Returns: + dict. En exito: + {status:'ok', + index, columns, value, agg, + row_labels:[...], # valores de index, en orden de freq desc + col_labels:[...], # valores de columns, en orden de freq desc + matrix:[[...], ...], # len == len(row_labels); cada fila + # len == len(col_labels); celda = agg o None + truncated_rows:bool, truncated_cols:bool, + note:str} + En error (sin lanzar): {status:'error', error:str}. + """ + try: + if not isinstance(agg, str) or agg not in _AGG_SQL: + return { + "status": "error", + "error": "invalid agg " + + repr(agg) + + "; allowed: " + + ", ".join(sorted(_AGG_SQL)), + } + + # Paso 1 (push-down): agregar (index, columns) -> agg(value) + COUNT(*). + if agg == "count": + agg_expr = "COUNT(*)" + else: + agg_expr = f"{_AGG_SQL[agg]}({_quote_ident(value)})" + + sql = ( + f"SELECT {_quote_ident(index)} AS r, " + f"{_quote_ident(columns)} AS c, " + f"{agg_expr} AS v, " + f"COUNT(*) AS n " + f"FROM {_quote_ident(table)} " + f"GROUP BY {_quote_ident(index)}, {_quote_ident(columns)}" + ) + + # max_rows alto: queremos todos los grupos (index x columns) para elegir el + # top con criterio global. sandbox=False igual que correlation_matrix_duckdb, + # porque db_path es una ruta interna de confianza. + result = duckdb_query_readonly( + db_path, sql, max_rows=1_000_000, sandbox=False + ) + if result.get("status") != "ok": + return { + "status": "error", + "error": "pivot query failed: " + + str(result.get("error", "unknown")), + } + + # Paso 2 (en python): contar observaciones por fila y por columna, y guardar + # el valor agregado de cada celda (r, c). 
+ row_obs: dict = defaultdict(int) + col_obs: dict = defaultdict(int) + cell: dict = {} + for row in result.get("rows", []): + r = row.get("r") + c = row.get("c") + n = row.get("n") or 0 + row_obs[r] += n + col_obs[c] += n + cell[(r, c)] = row.get("v") + + def _top(obs: dict, limit: int): + # Orden: mas observaciones primero; desempate estable por etiqueta. + ranked = sorted(obs.items(), key=lambda kv: (-kv[1], str(kv[0]))) + selected = [label for label, _ in ranked[:limit]] + return selected, len(ranked) > limit + + row_labels, truncated_rows = _top(row_obs, top_rows) + col_labels, truncated_cols = _top(col_obs, top_cols) + + # Paso 3: materializar la matriz; None donde la combinacion no existe. + matrix = [ + [cell.get((r, c)) for c in col_labels] for r in row_labels + ] + + note = ( + f"pivot {agg}({value}) reducida a {len(row_labels)}x{len(col_labels)} " + "(top por observaciones) para caber en PDF/slide" + ) + if agg == "count": + note = ( + f"pivot count(*) reducida a {len(row_labels)}x{len(col_labels)} " + "(top por observaciones) para caber en PDF/slide" + ) + + return { + "status": "ok", + "index": index, + "columns": columns, + "value": value, + "agg": agg, + "row_labels": row_labels, + "col_labels": col_labels, + "matrix": matrix, + "truncated_rows": truncated_rows, + "truncated_cols": truncated_cols, + "note": note, + } + except Exception as e: # noqa: BLE001 + return {"status": "error", "error": str(e)} diff --git a/python/functions/datascience/pivot_table_duckdb_test.py b/python/functions/datascience/pivot_table_duckdb_test.py new file mode 100644 index 00000000..53f85e2e --- /dev/null +++ b/python/functions/datascience/pivot_table_duckdb_test.py @@ -0,0 +1,115 @@ +"""Tests para pivot_table_duckdb.""" + +import os +import sys + +import duckdb + +# Permitir importar funciones del registry (from infra import ..., from datascience import ...). 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions")) + +from datascience.pivot_table_duckdb import pivot_table_duckdb + + +def _make_db(tmp_name: str) -> str: + """Crea una DuckDB con dos categoricas (a, b) y un valor numerico conocido. + + Filas: + a='x', b='y', val=10 + a='x', b='y', val=20 -> mean(x,y) = 15, count(x,y) = 2 + a='x', b='z', val=5 -> mean(x,z) = 5 + a='w', b='y', val=100 -> mean(w,y) = 100 + Observaciones por a: x=3, w=1. Por b: y=3, z=1. + La combinacion (w, z) no existe -> celda None. + """ + db = os.path.join("/tmp", tmp_name) + if os.path.exists(db): + os.remove(db) + con = duckdb.connect(db) + con.execute("CREATE TABLE t (a VARCHAR, b VARCHAR, val DOUBLE)") + con.execute( + "INSERT INTO t VALUES " + "('x','y',10),('x','y',20),('x','z',5),('w','y',100)" + ) + con.close() + return db + + +def test_pivot_mean_labels_y_celda_conocida(): + db = _make_db("pivot_test_mean.duckdb") + res = pivot_table_duckdb(db, "t", index="a", columns="b", value="val", agg="mean") + assert res["status"] == "ok", res + # Filas ordenadas por observaciones desc: x (3) antes que w (1). + assert res["row_labels"] == ["x", "w"], res["row_labels"] + # Columnas ordenadas por observaciones desc: y (3) antes que z (1). + assert res["col_labels"] == ["y", "z"], res["col_labels"] + # matrix[0][0] = mean(a='x', b='y') = (10 + 20) / 2 = 15. + assert abs(res["matrix"][0][0] - 15.0) < 1e-9, res["matrix"] + # matrix[0][1] = mean(a='x', b='z') = 5. + assert abs(res["matrix"][0][1] - 5.0) < 1e-9, res["matrix"] + # matrix[1][0] = mean(a='w', b='y') = 100. + assert abs(res["matrix"][1][0] - 100.0) < 1e-9, res["matrix"] + # (w, z) no existe -> None. + assert res["matrix"][1][1] is None, res["matrix"] + # Sin truncado con los defaults (top_rows=10, top_cols=8). 
+ assert res["truncated_rows"] is False + assert res["truncated_cols"] is False + # La matriz es rectangular consistente con las etiquetas. + assert len(res["matrix"]) == len(res["row_labels"]) + for fila in res["matrix"]: + assert len(fila) == len(res["col_labels"]) + + +def test_pivot_trunca_a_top_rows_y_top_cols(): + db = _make_db("pivot_test_trunc.duckdb") + res = pivot_table_duckdb( + db, "t", index="a", columns="b", value="val", agg="mean", + top_rows=1, top_cols=1, + ) + assert res["status"] == "ok", res + # Solo la fila/columna mas frecuente sobrevive. + assert res["row_labels"] == ["x"], res["row_labels"] + assert res["col_labels"] == ["y"], res["col_labels"] + assert res["matrix"] == [[15.0]], res["matrix"] + # Habia mas de 1 fila y mas de 1 columna -> truncado en ambos ejes. + assert res["truncated_rows"] is True + assert res["truncated_cols"] is True + + +def test_pivot_count_no_necesita_value_real(): + db = _make_db("pivot_test_count.duckdb") + # value apunta a una columna real pero count(*) la ignora; tambien valdria un + # nombre cualquiera. Verificamos que count funciona igualmente. + res = pivot_table_duckdb( + db, "t", index="a", columns="b", value="val", agg="count" + ) + assert res["status"] == "ok", res + assert res["row_labels"] == ["x", "w"] + assert res["col_labels"] == ["y", "z"] + # count(a='x', b='y') = 2 observaciones. + assert res["matrix"][0][0] == 2, res["matrix"] + # count(a='x', b='z') = 1. + assert res["matrix"][0][1] == 1, res["matrix"] + # count(a='w', b='y') = 1. + assert res["matrix"][1][0] == 1, res["matrix"] + # (w, z) no existe -> None. 
+ assert res["matrix"][1][1] is None, res["matrix"] + + +def test_pivot_db_inexistente_devuelve_error_sin_lanzar(): + res = pivot_table_duckdb( + "/nonexistent/path/does_not_exist.duckdb", + "t", index="a", columns="b", value="val", agg="mean", + ) + assert res["status"] == "error", res + assert isinstance(res["error"], str) + + +def test_pivot_agg_invalido_devuelve_error(): + db = _make_db("pivot_test_badagg.duckdb") + res = pivot_table_duckdb( + db, "t", index="a", columns="b", value="val", agg="stddev" + ) + assert res["status"] == "error", res + assert "invalid agg" in res["error"] diff --git a/python/functions/datascience/select_groupby_keys.md b/python/functions/datascience/select_groupby_keys.md new file mode 100644 index 00000000..d97332eb --- /dev/null +++ b/python/functions/datascience/select_groupby_keys.md @@ -0,0 +1,158 @@ +--- +id: select_groupby_keys_py_datascience +name: select_groupby_keys +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: pure +signature: "def select_groupby_keys(profile: dict, max_keys: int = 3, max_card: int = 20, max_measures: int = 4) -> dict" +description: "Elige deterministicamente las columnas categoricas mas interesantes para GROUP BY, las numericas medida y pares pivote a partir de un TableProfile del grupo eda. Respaldo cuantitativo para el capitulo de agregacion/OLAP de un EDA. Funcion pura, no muta el input, nunca lanza." 
+tags: [eda, aggregation, groupby, olap, profiling, datascience] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "" +imports: [] +example: | + from datascience import select_groupby_keys + profile = { + "n_rows": 891, + "key_candidates": ["passenger_id"], + "columns": [ + {"name": "sex", "inferred_type": "categorical", "distinct_count": 2, + "unique_pct": 0.002, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 1.8}, "numeric": None}, + {"name": "pclass", "inferred_type": "categorical", "distinct_count": 3, + "unique_pct": 0.003, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 2.5}, "numeric": None}, + {"name": "fare", "inferred_type": "numeric", "distinct_count": 200, + "unique_pct": 0.2, "null_pct": 0.0, "flags": [], + "numeric": {"std": 49.7, "cv": 1.54}, "categorical": None}, + ], + } + select_groupby_keys(profile) + # {"group_keys": [{"col": "sex", ...}, {"col": "pclass", ...}], + # "measures": ["fare"], + # "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}], + # "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s)."} +tested: true +tests: + - "test_titanic_picks_good_cats_excludes_id_and_constant" + - "test_titanic_measures_exclude_id_constant_and_keep_numerics" + - "test_titanic_generates_one_pivot" + - "test_empty_profile_returns_all_empty_and_does_not_crash" + - "test_none_profile_does_not_crash" + - "test_only_numerics_yields_empty_group_keys_and_no_pivots" + - "test_high_cardinality_and_max_card_are_excluded" + - "test_max_keys_limits_group_keys" + - "test_three_keys_cap_pivots_to_two" + - "test_does_not_mutate_input" +test_file_path: "python/functions/datascience/select_groupby_keys_test.py" +file_path: "python/functions/datascience/select_groupby_keys.py" +params: + - name: profile + desc: > + TableProfile dict del grupo eda (p.ej. salida de summarize_table_duckdb). + Se lee de forma defensiva (.get / or [] / isinstance). 
Claves usadas: + columns (list[ColumnProfile]), key_candidates (list de nombres de columna + o dicts {name}), n_rows. Cada ColumnProfile usa: name, inferred_type + ("numeric"|"categorical"|"datetime"|"text"|"boolean"), distinct_count, + unique_pct (0..1), null_pct (0..1), flags (list[str], reconoce + "possible_id"/"high_cardinality"/"constant"), numeric ({std, cv, ...}|None) + y categorical ({imbalance, mode_pct, ...}|None). + - name: max_keys + desc: "Numero maximo de claves de grupo (group_keys) a devolver. Default 3." + - name: max_card + desc: > + Cardinalidad maxima (distinct_count) que una columna categorica puede + tener para seguir siendo candidata a clave de grupo. Default 20. + - name: max_measures + desc: "Numero maximo de columnas medida (nombres) a devolver. Default 4." +output: > + dict con group_keys (list de {col, cardinality, score} ordenada por score + desc), measures (list[str] de nombres de columnas numericas ordenadas por + dispersion), pivots (list de {index, columns, value}, hasta 2 pares + categorica x categorica con la primera measure como valor) y note (str, + resumen corto en espanol de lo elegido). Ante profile vacio/None devuelve + todas las listas vacias y una note descriptiva; nunca lanza. +--- + +## Ejemplo + +```python +from datascience import select_groupby_keys + +# TableProfile estilo titanic: 2 categoricas buenas, 1 numerica medida, +# 1 id secuencial (descartado) y un key_candidate declarado. 
+profile = { + "n_rows": 891, + "key_candidates": ["passenger_id"], + "columns": [ + {"name": "sex", "inferred_type": "categorical", "distinct_count": 2, + "unique_pct": 0.002, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 1.8}, "numeric": None}, + {"name": "pclass", "inferred_type": "categorical", "distinct_count": 3, + "unique_pct": 0.003, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 2.5}, "numeric": None}, + {"name": "fare", "inferred_type": "numeric", "distinct_count": 200, + "unique_pct": 0.2, "null_pct": 0.0, "flags": [], + "numeric": {"std": 49.7, "cv": 1.54}, "categorical": None}, + {"name": "passenger_id", "inferred_type": "numeric", "distinct_count": 891, + "unique_pct": 1.0, "null_pct": 0.0, "flags": ["possible_id"], + "numeric": {"std": 257.4, "cv": 0.58}, "categorical": None}, + ], +} + +select_groupby_keys(profile) +# { +# "group_keys": [ +# {"col": "sex", "cardinality": 2, "score": 0.5556}, +# {"col": "pclass", "cardinality": 3, "score": 0.4}, +# ], +# "measures": ["fare"], # passenger_id excluido (id secuencial) +# "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}], +# "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s).", +# } +``` + +## Cuando usarla + +Cuando hayas perfilado una tabla con el grupo `eda` (p.ej. +`summarize_table_duckdb`) y necesites decidir, sin mirar los datos, por qué +columnas merece la pena agrupar (GROUP BY) y qué métricas numéricas agregar: +el respaldo cuantitativo del capítulo de agregación/OLAP de un AutomaticEDA, o +para proponer pivotes en un dashboard. Es la capa de selección sobre el +TableProfile crudo: lee el perfil, ordena candidatos de forma determinista y +no toca los datos. + +## Notas + +Función pura, sin I/O ni dependencias externas (solo stdlib), no muta +`profile`. Lectura defensiva total (`.get`, `or []`, `isinstance`): un `{}` o +`None` produce `{"group_keys": [], "measures": [], "pivots": [], "note": ...}` +y nunca lanza. 
+ +Criterios de selección (deterministas): + +- **group_keys** — candidatas con `inferred_type` en `("categorical","boolean")`. + Se descartan las que estén en `key_candidates`, con flag + `possible_id`/`high_cardinality`/`constant`, con `distinct_count` fuera de + `[2, max_card]`, o all-null (`null_pct >= 0.999`). `score = card_score * + balance_score`: `card_score` mantiene un plateau para cardinalidad moderada + (2..12) y decae hacia `max_card`; `balance_score = 1/imbalance` usando + `categorical.imbalance` si está, aproximando con `mode_pct` si no, o un valor + neutro (0.5) en último caso. Devuelve hasta `max_keys`, ordenadas por score + desc (empates por orden de columna). +- **measures** — candidatas con `inferred_type` en + `("numeric","integer","float")`. Se descartan id-like (flag `possible_id` y + `unique_pct >= 0.99`) y constantes (`numeric.std` == 0 o None). Se rankean por + dispersión informativa: `abs(cv)` si está, si no `abs(std)`. Devuelve hasta + `max_measures` **nombres** (strings). +- **pivots** — hasta 2 pares `(group_keys[i].col, group_keys[j].col)` con i dict: + """Select GROUP BY keys, measures and pivot pairs from a TableProfile. + + Reads everything defensively (``.get(...)``, ``or []``, ``isinstance``) and + never raises. With an empty/None profile it returns every list empty. + + Selection rules (deterministic): + + - **group_keys** (categorical columns to group by): candidates have + ``inferred_type`` in ``("categorical", "boolean")``. Discarded if they are + in ``profile['key_candidates']``, carry a ``possible_id`` / + ``high_cardinality`` / ``constant`` flag, have ``distinct_count`` outside + ``[2, max_card]``, or are all-null (``null_pct >= 0.999``). 
Each survivor + gets ``score = card_score * balance_score`` where ``card_score`` keeps a + plateau for moderate cardinality (2..12) and decays towards ``max_card``, + and ``balance_score = 1 / imbalance`` (``categorical.imbalance`` if + present, else approximated from ``mode_pct``, else a neutral default). + The top ``max_keys`` by score (desc, ties by column order) are returned. + + - **measures** (numeric columns to aggregate): candidates have + ``inferred_type`` in ``("numeric", "integer", "float")``. Discarded if + id-like (``possible_id`` flag *and* ``unique_pct >= 0.99``) or constant + (``numeric.std`` is ``0`` or ``None``). Ranked by informative dispersion: + ``abs(cv)`` when available, else ``abs(std)``. The top ``max_measures`` + **names** are returned. + + - **pivots**: up to 2 ``(group_keys[i].col, group_keys[j].col)`` pairs with + ``i < j``, using the first measure as the aggregated value. Empty when + fewer than 2 group keys were selected. + + Args: + profile: TableProfile dict of the ``eda`` group. Relevant keys: + ``columns`` (list[ColumnProfile]), ``key_candidates`` (list of + column names or ``{name}`` dicts), ``n_rows``. Each ColumnProfile + uses: ``name``, ``inferred_type``, ``distinct_count``, + ``unique_pct`` (0..1), ``null_pct`` (0..1), ``flags`` (list[str]), + ``numeric`` ({std, cv, ...}|None), ``categorical`` + ({imbalance, mode_pct, ...}|None). + max_keys: Maximum number of group-by keys to return. Default 3. + max_card: Maximum cardinality (``distinct_count``) a categorical column + may have to still qualify as a group key. Default 20. + max_measures: Maximum number of measure names to return. Default 4. + + Returns: + dict with: + group_keys (list[{col, cardinality, score}], ordered by score desc), + measures (list[str], numeric column names ordered by dispersion), + pivots (list[{index, columns, value}], up to 2 pairs), + note (str, short summary of what was chosen). 
+ """ + if not isinstance(profile, dict): + profile = {} + + try: + max_keys = int(max_keys) + except (TypeError, ValueError): + max_keys = 3 + try: + max_card = int(max_card) + except (TypeError, ValueError): + max_card = 20 + try: + max_measures = int(max_measures) + except (TypeError, ValueError): + max_measures = 4 + max_keys = max(max_keys, 0) + max_card = max(max_card, 2) + max_measures = max(max_measures, 0) + + columns = profile.get("columns") or [] + if not isinstance(columns, (list, tuple)): + columns = [] + + key_names = _key_candidate_names(profile.get("key_candidates")) + + group_keys = _select_group_keys(columns, key_names, max_keys, max_card) + measures = _select_measures(columns, max_measures) + pivots = _select_pivots(group_keys, measures) + + return { + "group_keys": group_keys, + "measures": measures, + "pivots": pivots, + "note": _build_note(group_keys, measures, pivots), + } + + +# --------------------------------------------------------------------------- +# group_keys +# --------------------------------------------------------------------------- + +_GROUP_TYPES = ("categorical", "boolean") +_DISQUALIFYING_FLAGS = frozenset({"possible_id", "high_cardinality", "constant"}) +_CARD_PLATEAU_HI = 12 # cardinalities 2..12 are all "moderate" (best). 
def _is_non_nan_number(value) -> bool:
    """Return True when ``value`` is an int/float that is not NaN."""
    # NaN is the only float that compares unequal to itself; the
    # self-comparison keeps this helper free of imports.
    return isinstance(value, (int, float)) and value == value


def _select_group_keys(columns, key_names, max_keys, max_card) -> list:
    """Rank categorical/boolean columns suitable for GROUP BY.

    Args:
        columns: Iterable of ColumnProfile dicts; non-dict entries are ignored.
        key_names: Collection of column names declared as key candidates;
            any column whose name is in it is excluded.
        max_keys: Maximum number of group keys to return.
        max_card: Highest ``distinct_count`` a column may have to qualify.

    Returns:
        List of ``{"col", "cardinality", "score"}`` dicts, best score first,
        ties broken by original column order.
    """
    scored = []
    for idx, col in enumerate(columns):
        if not isinstance(col, dict):
            continue
        if (col.get("inferred_type") or "") not in _GROUP_TYPES:
            continue

        name = col.get("name")
        if name is None:
            continue
        if name in key_names:
            continue

        flags = _as_set(col.get("flags"))
        if flags & _DISQUALIFYING_FLAGS:
            continue

        if _num(col.get("null_pct"), 0.0) >= 0.999:
            continue

        card = _num(col.get("distinct_count"), 0.0)
        # Chained comparison on purpose: a NaN cardinality fails it and is
        # skipped, whereas `card < 2 or card > max_card` lets NaN through and
        # the int() conversion below would then raise ValueError.
        if not (2 <= card <= max_card):
            continue
        card_i = int(card)

        score = _card_score(card_i, max_card) * _balance_score(col.get("categorical"))
        scored.append((round(score, 6), idx, name, card_i))

    # Deterministic: higher score first, ties broken by original column order.
    scored.sort(key=lambda t: (-t[0], t[1]))

    return [
        {"col": name, "cardinality": card_i, "score": score}
        for score, _idx, name, card_i in scored[:max_keys]
    ]


def _card_score(card: int, max_card: int) -> float:
    """Prefer moderate cardinality; plateau at 2..12, decay towards max_card.

    Returns 0.0 for ``card <= 1``, 1.0 up to ``_CARD_PLATEAU_HI``, and beyond
    that a linear decay that never drops below 0.1.
    """
    if card <= 1:
        return 0.0
    if card <= _CARD_PLATEAU_HI:
        return 1.0
    # max(..., 1) guards the division when max_card <= _CARD_PLATEAU_HI.
    denom = max(max_card - _CARD_PLATEAU_HI, 1)
    over = card - _CARD_PLATEAU_HI
    return max(0.1, 1.0 - over / denom)


def _balance_score(categorical) -> float:
    """1.0 for a perfectly balanced category, decaying as imbalance grows.

    Uses ``categorical.imbalance`` (max_count/min_count, >= 1) when available;
    otherwise approximates from ``mode_pct`` (top-class dominance); otherwise a
    neutral default so the column is still selectable. NaN values are treated
    as missing so the returned score is never NaN.
    """
    if isinstance(categorical, dict):
        imbalance = categorical.get("imbalance")
        # A NaN imbalance already fails the `>= 1.0` test and falls through.
        if isinstance(imbalance, (int, float)) and imbalance >= 1.0:
            return 1.0 / float(imbalance)
        mode_pct = categorical.get("mode_pct")
        # A NaN mode_pct would pass through _clamp untouched and poison the
        # ranking sort; fall back to the neutral default instead.
        if _is_non_nan_number(mode_pct):
            return _clamp(1.0 - float(mode_pct), 0.0, 1.0)
    return 0.5


# ---------------------------------------------------------------------------
# measures
# ---------------------------------------------------------------------------

_NUMERIC_TYPES = ("numeric", "integer", "float")


def _select_measures(columns, max_measures) -> list:
    """Rank numeric columns by informative dispersion (cv, else std).

    Args:
        columns: Iterable of ColumnProfile dicts; non-dict entries are ignored.
        max_measures: Maximum number of measure names to return.

    Returns:
        Column names ordered by dispersion (highest first), ties broken by
        original column order. Columns flagged ``possible_id`` with
        ``unique_pct >= 0.99`` and columns whose ``numeric.std`` is missing,
        zero or NaN are left out.
    """
    scored = []
    for idx, col in enumerate(columns):
        if not isinstance(col, dict):
            continue
        if (col.get("inferred_type") or "") not in _NUMERIC_TYPES:
            continue

        name = col.get("name")
        if name is None:
            continue

        flags = _as_set(col.get("flags"))
        unique_pct = _num(col.get("unique_pct"), 0.0)
        if "possible_id" in flags and unique_pct >= 0.99:
            continue  # sequential id, not a measure.

        numeric = col.get("numeric")
        if not isinstance(numeric, dict):
            continue  # no numeric block -> spread unknown.

        std = numeric.get("std")
        if not _is_non_nan_number(std) or std == 0:
            continue  # constant or unknown (None/NaN) spread -> not informative.

        cv = numeric.get("cv")
        # A NaN cv is treated like a missing one, so the sort key below is
        # never NaN and the ordering stays well defined.
        if _is_non_nan_number(cv):
            dispersion = abs(float(cv))
        else:
            dispersion = abs(float(std))

        scored.append((dispersion, idx, name))

    # Higher dispersion first, ties broken by original column order.
    scored.sort(key=lambda t: (-t[0], t[1]))
    return [name for _disp, _idx, name in scored[:max_measures]]


# ---------------------------------------------------------------------------
# pivots
# ---------------------------------------------------------------------------


def _select_pivots(group_keys, measures) -> list:
    """Up to 2 (cat_a, cat_b) pairs from the chosen group keys.

    Pairs are generated in ``i < j`` order over ``group_keys``; every pair
    carries the first measure as ``value`` (``None`` when there are no
    measures). Returns an empty list with fewer than 2 group keys.
    """
    if not isinstance(group_keys, list) or len(group_keys) < 2:
        return []
    value = measures[0] if measures else None
    pairs = []
    n = len(group_keys)
    for i in range(n):
        for j in range(i + 1, n):
            pairs.append({
                "index": group_keys[i].get("col"),
                "columns": group_keys[j].get("col"),
                "value": value,
            })
            if len(pairs) >= 2:
                return pairs
    return pairs


# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------


def _build_note(group_keys, measures, pivots) -> str:
    """One-line Spanish summary of the selection.

    Always mentions group keys and measures (or their absence); the pivot
    count is appended only when at least one pivot exists.
    """
    parts = []
    if group_keys:
        cols = ", ".join(str(g.get("col")) for g in group_keys)
        parts.append(f"{len(group_keys)} clave(s) de grupo: {cols}")
    else:
        parts.append("sin categóricas agrupables")
    if measures:
        parts.append(f"{len(measures)} medida(s): " + ", ".join(str(m) for m in measures))
    else:
        parts.append("sin medidas numéricas")
    if pivots:
        parts.append(f"{len(pivots)} pivot(s)")
    return "; ".join(parts) + "."
+ + +def _key_candidate_names(key_candidates) -> set: + """Normalize ``key_candidates`` (strings or ``{name}`` dicts) to a name set.""" + names = set() + if not isinstance(key_candidates, (list, tuple)): + return names + for entry in key_candidates: + if isinstance(entry, str): + names.add(entry) + elif isinstance(entry, dict): + nm = entry.get("name") or entry.get("col") + if nm is not None: + names.add(nm) + return names + + +def _as_set(flags) -> set: + """Coerce a flags value into a set, tolerating None / non-iterables.""" + if isinstance(flags, (list, tuple, set)): + return set(flags) + return set() + + +def _num(value, default: float) -> float: + """Best-effort float conversion with a fallback default.""" + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _clamp(x: float, lo: float, hi: float) -> float: + """Recorta x al rango [lo, hi].""" + if x < lo: + return lo + if x > hi: + return hi + return x diff --git a/python/functions/datascience/select_groupby_keys_test.py b/python/functions/datascience/select_groupby_keys_test.py new file mode 100644 index 00000000..0aaf167d --- /dev/null +++ b/python/functions/datascience/select_groupby_keys_test.py @@ -0,0 +1,213 @@ +"""Tests para select_groupby_keys (grupo eda, dominio datascience).""" + +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +from select_groupby_keys import select_groupby_keys + + +def _cat_col(name, card, *, imbalance=2.0, flags=None, null_pct=0.0): + """ColumnProfile categorico minimo con bloque categorical.""" + return { + "name": name, + "inferred_type": "categorical", + "distinct_count": card, + "unique_pct": card / 1000.0, + "null_pct": null_pct, + "flags": flags or [], + "numeric": None, + "categorical": {"imbalance": imbalance, "mode_pct": 0.5, "n_distinct": card}, + } + + +def _num_col(name, *, std, cv, flags=None, unique_pct=0.1): + """ColumnProfile numerico minimo con bloque numeric.""" + 
return { + "name": name, + "inferred_type": "numeric", + "distinct_count": 200, + "unique_pct": unique_pct, + "null_pct": 0.0, + "flags": flags or [], + "numeric": {"std": std, "cv": cv}, + "categorical": None, + } + + +def _titanic_like_profile() -> dict: + """Perfil estilo titanic: 2 categoricas buenas, 2 numericas, 1 id, 1 constante.""" + return { + "n_rows": 891, + "key_candidates": ["passenger_id"], + "columns": [ + _cat_col("sex", 2, imbalance=1.8), + _cat_col("pclass", 3, imbalance=2.5), + _num_col("age", std=14.5, cv=0.49), + _num_col("fare", std=49.7, cv=1.54), + # id secuencial: flag possible_id + unique_pct alto. + { + "name": "passenger_id", + "inferred_type": "numeric", + "distinct_count": 891, + "unique_pct": 1.0, + "null_pct": 0.0, + "flags": ["possible_id"], + "numeric": {"std": 257.4, "cv": 0.58}, + "categorical": None, + }, + # columna constante: flag constant + std 0. + { + "name": "embarked_const", + "inferred_type": "categorical", + "distinct_count": 1, + "unique_pct": 0.001, + "null_pct": 0.0, + "flags": ["constant"], + "numeric": None, + "categorical": {"imbalance": 1.0}, + }, + ], + } + + +def test_titanic_picks_good_cats_excludes_id_and_constant(): + out = select_groupby_keys(_titanic_like_profile()) + + # Elige las dos categoricas buenas. + chosen_cols = {g["col"] for g in out["group_keys"]} + assert chosen_cols == {"sex", "pclass"} + + # Excluye la constante y el key_candidate. + assert "embarked_const" not in chosen_cols + assert "passenger_id" not in chosen_cols + + # Cada group key trae col, cardinality y score. + for g in out["group_keys"]: + assert set(g.keys()) == {"col", "cardinality", "score"} + assert isinstance(g["score"], float) + by_col = {g["col"]: g for g in out["group_keys"]} + assert by_col["sex"]["cardinality"] == 2 + assert by_col["pclass"]["cardinality"] == 3 + + # Ordenadas por score descendente. 
+ scores = [g["score"] for g in out["group_keys"]] + assert scores == sorted(scores, reverse=True) + + +def test_titanic_measures_exclude_id_constant_and_keep_numerics(): + out = select_groupby_keys(_titanic_like_profile()) + + # Solo nombres (strings) de numericas informativas, sin el id secuencial. + assert all(isinstance(m, str) for m in out["measures"]) + assert "passenger_id" not in out["measures"] + assert set(out["measures"]) == {"age", "fare"} + + # fare tiene mayor cv (1.54 > 0.49) -> primero. + assert out["measures"][0] == "fare" + + +def test_titanic_generates_one_pivot(): + out = select_groupby_keys(_titanic_like_profile()) + + # Con 2 group keys -> exactamente 1 pivot. + assert len(out["pivots"]) == 1 + pivot = out["pivots"][0] + assert set(pivot.keys()) == {"index", "columns", "value"} + assert {pivot["index"], pivot["columns"]} == {"sex", "pclass"} + # El valor es la primera measure (fare). + assert pivot["value"] == "fare" + + +def test_empty_profile_returns_all_empty_and_does_not_crash(): + out = select_groupby_keys({}) + assert out["group_keys"] == [] + assert out["measures"] == [] + assert out["pivots"] == [] + assert isinstance(out["note"], str) + + +def test_none_profile_does_not_crash(): + out = select_groupby_keys(None) + assert out == { + "group_keys": [], + "measures": [], + "pivots": [], + "note": out["note"], + } + assert isinstance(out["note"], str) + + +def test_only_numerics_yields_empty_group_keys_and_no_pivots(): + profile = { + "n_rows": 500, + "key_candidates": [], + "columns": [ + _num_col("price", std=12.0, cv=0.6), + _num_col("weight", std=3.0, cv=0.2), + ], + } + out = select_groupby_keys(profile) + assert out["group_keys"] == [] + assert out["pivots"] == [] + # Las numericas si se eligen como measures. + assert set(out["measures"]) == {"price", "weight"} + assert out["measures"][0] == "price" # mayor cv. 
+ + +def test_high_cardinality_and_max_card_are_excluded(): + profile = { + "n_rows": 1000, + "key_candidates": [], + "columns": [ + _cat_col("city", 50, flags=["high_cardinality"]), # flag -> fuera. + _cat_col("zone", 35), # card 35 > max_card 20 -> fuera. + _cat_col("region", 5), # valida. + ], + } + out = select_groupby_keys(profile, max_card=20) + assert {g["col"] for g in out["group_keys"]} == {"region"} + + +def test_max_keys_limits_group_keys(): + profile = { + "n_rows": 1000, + "key_candidates": [], + "columns": [ + _cat_col("a", 4, imbalance=1.0), + _cat_col("b", 5, imbalance=1.2), + _cat_col("c", 6, imbalance=1.5), + _cat_col("d", 7, imbalance=2.0), + ], + } + out = select_groupby_keys(profile, max_keys=2) + assert len(out["group_keys"]) == 2 + # Hasta 2 pivots con >=2 keys (aqui exactamente 1 par posible entre 2 keys). + assert len(out["pivots"]) == 1 + + +def test_three_keys_cap_pivots_to_two(): + profile = { + "n_rows": 1000, + "key_candidates": [], + "columns": [ + _cat_col("a", 4, imbalance=1.0), + _cat_col("b", 5, imbalance=1.1), + _cat_col("c", 6, imbalance=1.2), + _num_col("m", std=10.0, cv=0.5), + ], + } + out = select_groupby_keys(profile, max_keys=3) + assert len(out["group_keys"]) == 3 + # 3 keys -> 3 pares posibles, capado a 2. 
+ assert len(out["pivots"]) == 2 + for p in out["pivots"]: + assert p["value"] == "m" + + +def test_does_not_mutate_input(): + profile = _titanic_like_profile() + before = repr(profile) + select_groupby_keys(profile) + assert repr(profile) == before diff --git a/python/functions/datascience/suggest_aggregations_llm.md b/python/functions/datascience/suggest_aggregations_llm.md new file mode 100644 index 00000000..2b4a79fd --- /dev/null +++ b/python/functions/datascience/suggest_aggregations_llm.md @@ -0,0 +1,96 @@ +--- +name: suggest_aggregations_llm +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def suggest_aggregations_llm(profile: dict, candidates: dict, max_aggs: int = 4, model: str = \"claude-haiku-4-5-20251001\") -> dict" +description: "MUST-11.1 del capitulo AGREGACION del AutomaticEDA (grupo eda). Dado el TableProfile de una tabla y los candidatos cuantitativos de select_groupby_keys ({group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}), con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una razon corta cada uno, evitando la explosion combinatoria (no todo contra todo). Privacidad/coste: NO envia filas crudas, solo el resumen AGREGADO de los candidatos (tabla, columnas categoricas con cardinalidad/score, medidas, pivots). Reusa ask_llm del grupo claude-direct (API directa con token OAuth de Claude). Impura, dict-no-throw: NUNCA lanza y SIEMPRE devuelve algo usable; si el LLM falla, el JSON no parsea o no hay seleccion valida, cae a un fallback determinista construido desde los candidatos (source='fallback'). Toda columna que el LLM invente se descarta." +tags: [eda, claude-direct, llm, aggregation, groupby, pivot, datascience, automatic-eda] +params: + - name: profile + desc: "TableProfile del grupo eda. 
Solo se usa profile['table'] para nombrar la tabla en el prompt; puede ir vacio o sin esa clave (se usa '(tabla sin nombre)')." + - name: candidates + desc: "Salida de select_groupby_keys: {group_keys:[{col, cardinality, score}], measures:[str], pivots:[{index, columns, value}]}. group_keys = columnas categoricas candidatas para GROUP BY; measures = columnas numericas a agregar (sum/avg); pivots = cruces index x columns -> value sugeridos. Cualquier columna que el LLM elija debe existir aqui o se descarta. None o no-dict se trata como vacio." + - name: max_aggs + desc: "Tope de agregaciones a devolver. Default 4. Valores <1 o no-int se normalizan a 4. Limita tanto la seleccion del LLM como el fallback determinista, para evitar la explosion combinatoria." + - name: model + desc: "id del modelo Anthropic a usar en la unica llamada. Default 'claude-haiku-4-5-20251001' (haiku, coste bajo, ~2-3s). Para razones mas finas, pasar p.ej. 'claude-opus-4-8'." +output: "dict dict-no-throw: {status:'ok', source:'llm'|'fallback', aggregations:[{group_by:str, measures:[str], why:str}], pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}. source=='llm' si el LLM produjo al menos una agregacion valida (columnas existentes en candidates); en cualquier otro caso (LLM caido, JSON invalido, seleccion vacia, sin candidatos) source=='fallback' y aggregations/pivots se derivan de candidates con why='selección cuantitativa (sin LLM)'. NUNCA lanza." 
+uses_functions: [ask_llm_py_core, select_groupby_keys_py_datascience] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +tested: true +tests: ["test_extract_json_object", "test_extract_json_wrapped_in_fences_and_junk", "test_extract_json_non_json_returns_none", "test_validate_aggregations_drops_invalid_columns", "test_llm_path_uses_selection", "test_llm_path_respects_max_aggs", "test_llm_invented_column_is_discarded", "test_fallback_on_empty_llm_response", "test_fallback_on_unparseable_response", "test_fallback_respects_max_aggs", "test_fallback_when_llm_raises", "test_no_candidates_returns_empty_fallback", "test_non_dict_candidates_does_not_raise"] +test_file_path: "python/functions/datascience/suggest_aggregations_llm_test.py" +file_path: "python/functions/datascience/suggest_aggregations_llm.py" +--- + +## Ejemplo + +```python +import sys, os +sys.path.insert(0, os.path.join("python", "functions")) + +from datascience.suggest_aggregations_llm import suggest_aggregations_llm + +profile = {"table": "ventas"} + +# candidates = salida de select_groupby_keys (aqui literal de ejemplo). 
+candidates = { + "group_keys": [ + {"col": "categoria", "cardinality": 8, "score": 0.91}, + {"col": "region", "cardinality": 5, "score": 0.74}, + {"col": "canal", "cardinality": 3, "score": 0.60}, + ], + "measures": ["importe", "unidades"], + "pivots": [ + {"index": "categoria", "columns": "region", "value": "importe"}, + ], +} + +out = suggest_aggregations_llm(profile, candidates, max_aggs=4) # haiku por defecto + +print("fuente:", out["source"]) # "llm" o "fallback" si no hay red +for agg in out["aggregations"]: + print(f"GROUP BY {agg['group_by']} -> {agg['measures']} ({agg['why']})") +for piv in out["pivots"]: + print(f"pivot {piv['index']} x {piv['columns']} = {piv['value']} ({piv['why']})") +``` + +## Cuando usarla + +Justo despues de `select_groupby_keys` en el capitulo AGREGACION del AutomaticEDA: +cuando ya tienes los candidatos cuantitativos (columnas categoricas con cardinalidad, +medidas numericas y pivots posibles) y quieres que un LLM se quede con las K +agregaciones y pivots MAS INFORMATIVOS en vez de generar "todo contra todo". Usala para +priorizar el plan de analisis de grupos antes de materializar las tablas con +`aggregate_by_group` / pivots, manteniendo el coste y el ruido bajos. Si no hay red o +credenciales, sigue funcionando con un fallback determinista, asi que es seguro +ponerla en un pipeline. + +## Gotchas + +- **Impura: hace 1 llamada de red al LLM.** No es determinista ni gratis. Latencia + tipica ~2-3s con haiku. Una sola llamada cubre toda la seleccion. +- **Requiere token OAuth de Claude** en `~/.claude/.credentials.json` (via `ask_llm` / + grupo `claude-direct`). Sin token / sin red NO lanza: cae al **fallback + determinista** (`source="fallback"`) construido desde `candidates` + (group_keys x measures hasta `max_aggs`, pivots tal cual) con + `why="selección cuantitativa (sin LLM)"`. Comprueba `out["source"]` para saber si la + seleccion vino del LLM o del fallback. 
+- **NO envia filas crudas al LLM**, solo el resumen AGREGADO de los candidatos. Esto + exige que `candidates` venga ya calculado por `select_groupby_keys` (cardinalidades, + scores, medidas, pivots). +- **Valida columnas inventadas**: si el LLM propone un `group_by`/`measure`/`index`/ + `columns` que no esta en `candidates`, esa entrada se descarta (las medidas se + recortan a las validas). Si tras validar no queda ninguna agregacion, cae al + fallback completo. +- **`max_aggs` acota la explosion combinatoria** tanto en el camino LLM como en el + fallback. Subirlo demasiado reintroduce el ruido que esta funcion evita. +- **Modelo `haiku` por defecto** para coste bajo; sube a `claude-opus-4-8` si necesitas + razones (`why`) mas finas (mas caro y lento). diff --git a/python/functions/datascience/suggest_aggregations_llm.py b/python/functions/datascience/suggest_aggregations_llm.py new file mode 100644 index 00000000..b7fc4ac5 --- /dev/null +++ b/python/functions/datascience/suggest_aggregations_llm.py @@ -0,0 +1,405 @@ +"""suggest_aggregations_llm — el LLM elige las agregaciones mas informativas (grupo `eda`). + +MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Dado el `TableProfile` de una +tabla y los CANDIDATOS cuantitativos que produce `select_groupby_keys` +(`{group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}`), +con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x +medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una +razon corta cada uno. El objetivo es evitar la explosion combinatoria: en vez de +"todo contra todo", el LLM se queda con lo que mas informa. + +Privacidad y coste: NO se envian filas crudas al LLM. El prompt solo lleva el resumen +AGREGADO de los candidatos (nombre de la tabla, columnas categoricas con su +cardinalidad/score, medidas y pivots posibles). Una sola llamada barata. 
+ +Reusa `ask_llm` del registry (grupo claude-direct, API directa con el token OAuth de +Claude en ~/.claude/.credentials.json, arranque 0). Impura: una llamada de red. + +Estilo dict-no-throw con FALLBACK DETERMINISTA: la funcion NUNCA lanza y SIEMPRE +devuelve algo usable. Si `ask_llm` falla (devuelve ""), el JSON no parsea, o el LLM no +produce ninguna seleccion valida, se construye la respuesta directamente desde los +candidatos (group_keys x measures hasta max_aggs, pivots tal cual) con +`source="fallback"`. Ademas, toda columna que el LLM invente (no presente en los +candidatos) se descarta. +""" + +import json + +from core.ask_llm import ask_llm + +_SYSTEM = ( + "Eres un analista de datos conciso. Te dan los CANDIDATOS AGREGADOS de una tabla " + "(columnas categoricas para GROUP BY con su cardinalidad, medidas numericas y " + "pivots posibles) y eliges las agregaciones y pivots MAS INFORMATIVOS para " + "entender los grupos, evitando la explosion combinatoria (no todo contra todo). " + "No recibes filas crudas. Responde en espanol. Responde SIEMPRE y SOLO con un " + "unico objeto JSON valido, sin texto alrededor ni fences de markdown, con la forma " + '{"aggregations": [{"group_by": "", "measures": ["", ...], ' + '"why": ""}], "pivots": [{"index": "", "columns": "", ' + '"value": "", "why": ""}]}. Usa SOLO nombres de columna ' + "que aparezcan en los candidatos; no inventes nombres." +) + + +def _fmt_num(value) -> str: + """Formatea un numero de forma compacta para el prompt (None -> '?').""" + if value is None: + return "?" + if isinstance(value, bool): + return str(value) + if isinstance(value, float): + if value == int(value): + return str(int(value)) + return f"{value:.4g}" + return str(value) + + +def _candidate_view(candidates: dict): + """Extrae las vistas utiles de los candidatos. Funcion interna PURA. 
+ + Devuelve la tupla (group_cols, measures, measure_set, pivots, group_keys): + - group_cols: set de nombres de columna categorica validas (de group_keys[].col). + - measures: lista de medidas numericas (str) preservando orden. + - measure_set: set de las medidas para validar pertenencia rapido. + - pivots: lista de pivots candidatos (dicts) tal cual vienen. + - group_keys: lista de dicts {col, cardinality, score} ya filtrada a entradas validas. + + Tolera estructuras incompletas o de tipo incorrecto sin lanzar. + """ + candidates = candidates if isinstance(candidates, dict) else {} + + gk_raw = candidates.get("group_keys") + group_keys = [] + if isinstance(gk_raw, list): + for gk in gk_raw: + if isinstance(gk, dict) and isinstance(gk.get("col"), str): + group_keys.append(gk) + group_cols = {gk["col"] for gk in group_keys} + + m_raw = candidates.get("measures") + measures = [m for m in m_raw if isinstance(m, str)] if isinstance(m_raw, list) else [] + measure_set = set(measures) + + p_raw = candidates.get("pivots") + pivots = p_raw if isinstance(p_raw, list) else [] + + return group_cols, measures, measure_set, pivots, group_keys + + +def _sorted_group_cols(group_keys: list) -> list: + """Nombres de columna categorica ordenados por score descendente. PURA.""" + + def _score(gk): + s = gk.get("score") + if isinstance(s, (int, float)) and not isinstance(s, bool): + return s + return 0.0 + + return [gk["col"] for gk in sorted(group_keys, key=_score, reverse=True)] + + +def _build_prompt(profile: dict, candidates: dict, max_aggs: int) -> str: + """Construye el prompt compacto SOLO con agregados. Funcion interna PURA. + + No toca red ni disco: testeable sin credenciales. Incluye el nombre de la tabla, + las columnas categoricas candidatas (con cardinalidad y score), las medidas + numericas y los pivots candidatos. Nunca filas crudas. + + Args: + profile: TableProfile (se usa solo profile['table'] para nombrar la tabla). + candidates: salida de select_groupby_keys. 
+ max_aggs: tope de agregaciones a pedir. + + Returns: + El texto del prompt. + """ + profile = profile if isinstance(profile, dict) else {} + candidates = candidates if isinstance(candidates, dict) else {} + + table = profile.get("table") + table = str(table) if table is not None else "(tabla sin nombre)" + + lines = [ + f"Tabla: {table}", + ( + "Tarea: elegir las agregaciones (GROUP BY categorica x medidas numericas) y " + "los pivots MAS INFORMATIVOS para un analisis de grupos. Evita la explosion " + "combinatoria: NO combines todo contra todo, prioriza lo que mas informa." + ), + f"Devuelve a lo sumo {max_aggs} agregaciones.", + "", + "Columnas categoricas candidatas para GROUP BY (col: cardinalidad, score):", + ] + + group_keys = candidates.get("group_keys") or [] + for gk in group_keys: + if not isinstance(gk, dict) or not isinstance(gk.get("col"), str): + continue + lines.append( + f" - {gk['col']}: cardinalidad={_fmt_num(gk.get('cardinality'))}, " + f"score={_fmt_num(gk.get('score'))}" + ) + + measures = candidates.get("measures") or [] + lines.append("") + lines.append("Medidas numericas disponibles (para sum/avg por grupo):") + lines.append(" " + ", ".join(str(m) for m in measures if isinstance(m, str))) + + pivots = candidates.get("pivots") or [] + if pivots: + lines.append("") + lines.append("Pivots candidatos (index x columns -> value):") + for p in pivots: + if not isinstance(p, dict): + continue + lines.append( + f" - index={p.get('index')}, columns={p.get('columns')}, " + f"value={p.get('value')}" + ) + + lines.append("") + lines.append( + "Usa SOLO columnas de las listas anteriores; no inventes nombres. Responde " + "SOLO con el JSON descrito en las instrucciones del sistema." + ) + return "\n".join(lines) + + +def _extract_json(text: str): + """Extrae el primer bloque JSON (objeto o array) de la respuesta. PURA. 
+ + Localiza el bloque que empieza antes (el primer '{' o el primer '[') y, para ese + delimitador, hace json.loads del rango hasta su ultimo cierre. Tolera texto basura + alrededor y fences ```json. NUNCA lanza: ante cualquier fallo devuelve None. + + Args: + text: respuesta cruda del LLM. + + Returns: + El objeto/lista deserializado, o None si no se pudo parsear. + """ + if not text or not isinstance(text, str): + return None + + opens = [] + i_obj = text.find("{") + if i_obj != -1: + opens.append((i_obj, "{", "}")) + i_arr = text.find("[") + if i_arr != -1: + opens.append((i_arr, "[", "]")) + opens.sort() + + for _, open_c, close_c in opens: + start = text.find(open_c) + end = text.rfind(close_c) + if start != -1 and end != -1 and end > start: + try: + return json.loads(text[start : end + 1]) + except (ValueError, TypeError): + continue + return None + + +def _validate_aggregations(raw_aggs, group_cols: set, measure_set: set, max_aggs: int) -> list: + """Filtra las agregaciones del LLM a las que usan SOLO columnas candidatas. PURA. + + Descarta cualquier agregacion cuyo group_by no este en group_cols o que no tenga + al menos una medida valida. Recorta las medidas a las presentes en measure_set. + Limita el resultado a max_aggs entradas. 
+ """ + out = [] + if not isinstance(raw_aggs, list): + return out + for item in raw_aggs: + if not isinstance(item, dict): + continue + gb = item.get("group_by") + if not isinstance(gb, str) or gb not in group_cols: + continue # columna inventada -> se descarta + raw_measures = item.get("measures") + if isinstance(raw_measures, str): + raw_measures = [raw_measures] + if not isinstance(raw_measures, list): + continue + measures = [m for m in raw_measures if isinstance(m, str) and m in measure_set] + if not measures: + continue # sin medidas validas -> agregacion inutil + why = item.get("why") + why = str(why) if why is not None else "" + out.append({"group_by": gb, "measures": measures, "why": why}) + if len(out) >= max_aggs: + break + return out + + +def _validate_pivots(raw_pivots, group_cols: set, measure_set: set) -> list: + """Filtra los pivots del LLM a los que usan SOLO columnas candidatas. PURA. + + Descarta el pivot si index o columns no son columnas categoricas validas. Si el + value no es una medida valida, lo deja en None (un pivot de conteo sigue siendo util). + """ + out = [] + if not isinstance(raw_pivots, list): + return out + for item in raw_pivots: + if not isinstance(item, dict): + continue + idx = item.get("index") + cols = item.get("columns") + if not (isinstance(idx, str) and idx in group_cols): + continue + if not (isinstance(cols, str) and cols in group_cols): + continue + val = item.get("value") + if not (isinstance(val, str) and val in measure_set): + val = None + why = item.get("why") + why = str(why) if why is not None else "" + out.append({"index": idx, "columns": cols, "value": val, "why": why}) + return out + + +def _fallback_aggregations(group_cols_sorted: list, measures: list, max_aggs: int) -> list: + """Agregaciones deterministas: cada columna categorica x todas las medidas. 
PURA.""" + out = [] + for col in group_cols_sorted: + out.append( + { + "group_by": col, + "measures": list(measures), + "why": "selección cuantitativa (sin LLM)", + } + ) + if len(out) >= max_aggs: + break + return out + + +def _fallback_pivots(cand_pivots: list) -> list: + """Normaliza los pivots candidatos a la forma de salida (tal cual + why). PURA.""" + out = [] + if not isinstance(cand_pivots, list): + return out + for p in cand_pivots: + if not isinstance(p, dict): + continue + idx = p.get("index") + cols = p.get("columns") + if not (isinstance(idx, str) and isinstance(cols, str)): + continue + val = p.get("value") + if not isinstance(val, str): + val = None + out.append( + { + "index": idx, + "columns": cols, + "value": val, + "why": "selección cuantitativa (sin LLM)", + } + ) + return out + + +def suggest_aggregations_llm( + profile: dict, + candidates: dict, + max_aggs: int = 4, + model: str = "claude-haiku-4-5-20251001", +) -> dict: + """Elige las agregaciones y pivots mas informativos con UNA llamada al LLM. + + MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Toma el perfil de la tabla y + los candidatos cuantitativos (salida de select_groupby_keys) y deja que el LLM + seleccione/ordene las K agregaciones (GROUP BY categorica x medidas) y los pivots + mas utiles, con una razon corta cada uno, evitando la explosion combinatoria. + + Privacidad/coste: solo viaja al LLM el resumen AGREGADO de los candidatos, nunca + filas crudas. Una sola llamada barata. + + dict-no-throw con fallback determinista: NUNCA lanza. Si el LLM falla, el JSON no + parsea, o no produce seleccion valida -> construye la respuesta desde los candidatos + (group_keys x measures hasta max_aggs, pivots tal cual) con source="fallback". Las + columnas que el LLM invente (no presentes en los candidatos) se descartan. + + Args: + profile: TableProfile del grupo eda. Solo se usa profile['table'] para nombrar + la tabla en el prompt; puede ir vacio. 
+ candidates: salida de select_groupby_keys, con la forma + {group_keys:[{col,cardinality,score}], measures:[str], + pivots:[{index,columns,value}]}. + max_aggs: tope de agregaciones a devolver. Default 4. Valores <1 o no-int se + normalizan a 4. + model: id del modelo Anthropic. Default 'claude-haiku-4-5-20251001' (haiku, + coste bajo, ~2-3s). + + Returns: + dict {status:"ok", source:"llm"|"fallback", + aggregations:[{group_by:str, measures:[str], why:str}], + pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}. + source=="llm" si el LLM produjo al menos una agregacion valida; en cualquier + otro caso "fallback". NUNCA lanza. + """ + if not isinstance(candidates, dict): + candidates = {} + if isinstance(max_aggs, bool) or not isinstance(max_aggs, int) or max_aggs < 1: + max_aggs = 4 + + group_cols, measures, measure_set, cand_pivots, group_keys = _candidate_view(candidates) + group_cols_sorted = _sorted_group_cols(group_keys) + + # Sin material suficiente para agregar: no merece la pena llamar al LLM. + if not group_cols or not measures: + return { + "status": "ok", + "source": "fallback", + "aggregations": [], + "pivots": _fallback_pivots(cand_pivots), + "note": "sin candidatos suficientes para agregar", + } + + prompt = _build_prompt(profile, candidates, max_aggs) + try: + text = ask_llm(prompt, model=model, system=_SYSTEM, echo=False) + except Exception: # noqa: BLE001 — degradacion: cualquier fallo de red/LLM. 
+ text = "" + + parsed = _extract_json(text) + if parsed is not None: + if isinstance(parsed, dict): + raw_aggs = parsed.get("aggregations") + raw_pivots = parsed.get("pivots") + elif isinstance(parsed, list): + raw_aggs = parsed + raw_pivots = None + else: + raw_aggs = None + raw_pivots = None + + aggs = _validate_aggregations(raw_aggs, group_cols, measure_set, max_aggs) + if aggs: + pivots = _validate_pivots(raw_pivots, group_cols, measure_set) + if not pivots: + pivots = _fallback_pivots(cand_pivots) + return { + "status": "ok", + "source": "llm", + "aggregations": aggs, + "pivots": pivots, + "note": f"{len(aggs)} agregaciones y {len(pivots)} pivots seleccionados por el LLM", + } + + # Fallback determinista. + note = ( + "LLM no disponible; selección cuantitativa determinista" + if not text + else "LLM sin selección válida; selección cuantitativa determinista" + ) + return { + "status": "ok", + "source": "fallback", + "aggregations": _fallback_aggregations(group_cols_sorted, measures, max_aggs), + "pivots": _fallback_pivots(cand_pivots), + "note": note, + } diff --git a/python/functions/datascience/suggest_aggregations_llm_test.py b/python/functions/datascience/suggest_aggregations_llm_test.py new file mode 100644 index 00000000..29a4f2a9 --- /dev/null +++ b/python/functions/datascience/suggest_aggregations_llm_test.py @@ -0,0 +1,198 @@ +"""Tests para suggest_aggregations_llm. + +NO acceden a red ni a credenciales: las funciones internas (_build_prompt, +_extract_json, _validate_*, _fallback_*) son puras y testeables aisladas; la unica +via que llamaria al LLM (suggest_aggregations_llm) se prueba reemplazando el simbolo +`ask_llm` del modulo bajo prueba con una funcion simulada. Los candidatos van +literales en el test: NO se importa select_groupby_keys. + +Cubre golden (LLM ok con columnas validas), edge (max_aggs respetado, sin candidatos) +y error (LLM caido -> fallback, JSON invalido -> fallback, columna inventada -> se +descarta). Todos sin tocar la red. 
+""" + +import json + +import datascience.suggest_aggregations_llm as M +from datascience.suggest_aggregations_llm import ( + _extract_json, + _validate_aggregations, + suggest_aggregations_llm, +) + +# Candidatos de ejemplo con la forma que produce select_groupby_keys (literales). +_CANDIDATES = { + "group_keys": [ + {"col": "categoria", "cardinality": 8, "score": 0.91}, + {"col": "region", "cardinality": 5, "score": 0.74}, + {"col": "canal", "cardinality": 3, "score": 0.60}, + ], + "measures": ["importe", "unidades"], + "pivots": [ + {"index": "categoria", "columns": "region", "value": "importe"}, + ], +} +_PROFILE = {"table": "ventas"} + + +def _fake_returner(text): + """Devuelve un ask_llm simulado que ignora args y retorna `text`.""" + + def _fake(prompt, model="x", system="", echo=True, **kwargs): + return text + + return _fake + + +# --- _extract_json (parser puro, sin red) --- + + +def test_extract_json_object(): + obj = {"aggregations": [{"group_by": "categoria", "measures": ["importe"], "why": "x"}]} + assert _extract_json(json.dumps(obj)) == obj + + +def test_extract_json_wrapped_in_fences_and_junk(): + obj = {"aggregations": [], "pivots": []} + text = "Claro, aqui tienes:\n```json\n" + json.dumps(obj) + "\n```\nFin." 
+ assert _extract_json(text) == obj + + +def test_extract_json_non_json_returns_none(): + assert _extract_json("no hay json aqui") is None + assert _extract_json("") is None + assert _extract_json(None) is None + + +# --- _validate_aggregations (puro) --- + + +def test_validate_aggregations_drops_invalid_columns(): + group_cols = {"categoria", "region"} + measure_set = {"importe", "unidades"} + raw = [ + {"group_by": "categoria", "measures": ["importe", "inventada"], "why": "ok"}, + {"group_by": "no_existe", "measures": ["importe"], "why": "mala"}, + {"group_by": "region", "measures": ["solo_inventada"], "why": "sin medidas"}, + ] + out = _validate_aggregations(raw, group_cols, measure_set, max_aggs=4) + # Solo sobrevive la primera, con las medidas recortadas a las validas. + assert out == [{"group_by": "categoria", "measures": ["importe"], "why": "ok"}] + + +# --- suggest_aggregations_llm: camino LLM (golden) --- + + +def test_llm_path_uses_selection(monkeypatch): + llm_obj = { + "aggregations": [ + {"group_by": "categoria", "measures": ["importe"], "why": "ventas por familia"}, + {"group_by": "region", "measures": ["importe", "unidades"], "why": "reparto geografico"}, + ], + "pivots": [ + {"index": "categoria", "columns": "region", "value": "importe", "why": "cruce clave"}, + ], + } + monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj))) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["status"] == "ok" + assert out["source"] == "llm" + assert out["aggregations"] == llm_obj["aggregations"] + assert out["pivots"][0]["index"] == "categoria" + assert out["pivots"][0]["why"] == "cruce clave" + + +def test_llm_path_respects_max_aggs(monkeypatch): + llm_obj = { + "aggregations": [ + {"group_by": "categoria", "measures": ["importe"], "why": "a"}, + {"group_by": "region", "measures": ["importe"], "why": "b"}, + {"group_by": "canal", "measures": ["unidades"], "why": "c"}, + ], + "pivots": [], + } + monkeypatch.setattr(M, "ask_llm", 
_fake_returner(json.dumps(llm_obj))) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2) + assert out["source"] == "llm" + assert len(out["aggregations"]) == 2 + + +def test_llm_invented_column_is_discarded(monkeypatch): + # El LLM mezcla una agregacion valida con otra de columna inexistente. + llm_obj = { + "aggregations": [ + {"group_by": "categoria", "measures": ["importe"], "why": "valida"}, + {"group_by": "columna_fantasma", "measures": ["importe"], "why": "inventada"}, + ], + "pivots": [ + {"index": "fantasma", "columns": "region", "value": "importe", "why": "mala"}, + ], + } + monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj))) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["source"] == "llm" + # La agregacion inventada se descarta; queda solo la valida. + assert [a["group_by"] for a in out["aggregations"]] == ["categoria"] + # El pivot con index fantasma se descarta -> cae a los pivots de candidates. + assert all(p["index"] in {"categoria", "region", "canal"} for p in out["pivots"]) + + +# --- suggest_aggregations_llm: fallback determinista (error paths) --- + + +def test_fallback_on_empty_llm_response(monkeypatch): + monkeypatch.setattr(M, "ask_llm", _fake_returner("")) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=4) + assert out["status"] == "ok" + assert out["source"] == "fallback" + # Las agregaciones se derivan de candidates (una por group_key, con todas las medidas). + assert out["aggregations"][0]["group_by"] in {"categoria", "region", "canal"} + assert out["aggregations"][0]["measures"] == ["importe", "unidades"] + assert out["aggregations"][0]["why"] == "selección cuantitativa (sin LLM)" + # Pivots tal cual de candidates. 
+ assert out["pivots"][0]["index"] == "categoria" + + +def test_fallback_on_unparseable_response(monkeypatch): + monkeypatch.setattr(M, "ask_llm", _fake_returner("esto no es JSON {roto")) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["source"] == "fallback" + assert len(out["aggregations"]) >= 1 + + +def test_fallback_respects_max_aggs(monkeypatch): + monkeypatch.setattr(M, "ask_llm", _fake_returner("")) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2) + assert out["source"] == "fallback" + assert len(out["aggregations"]) == 2 + + +def test_fallback_when_llm_raises(monkeypatch): + def _boom(*args, **kwargs): + raise RuntimeError("sin red") + + monkeypatch.setattr(M, "ask_llm", _boom) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["source"] == "fallback" + assert out["aggregations"] # no vacio, no lanza + + +def test_no_candidates_returns_empty_fallback(): + # Sin red porque ni siquiera se llama al LLM (no hay material). + out = suggest_aggregations_llm(_PROFILE, {"group_keys": [], "measures": [], "pivots": []}) + assert out["status"] == "ok" + assert out["source"] == "fallback" + assert out["aggregations"] == [] + + +def test_non_dict_candidates_does_not_raise(): + out = suggest_aggregations_llm(_PROFILE, None) + assert out["status"] == "ok" + assert out["aggregations"] == [] From fd59530751298abc0e2eb569c542850d453aea62 Mon Sep 17 00:00:00 2001 From: Egutierrez Date: Tue, 30 Jun 2026 15:33:55 +0200 Subject: [PATCH 2/2] =?UTF-8?q?feat(eda):=20cap=C3=ADtulo=20AGREGACION=20d?= =?UTF-8?q?el=20AutomaticEDA=20(groupby=20+=20pivot=20+=20barras)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Capítulo nuevo (siempre presente cuando hay categóricas agrupables) que analiza la tabla por grupos: stats de numéricas por grupo, tablas dinámicas (pivot) y gráficos de barras desde cero. 
Obtiene los datos por ctx['aggregations'] precomputado o en vivo vía push-down (ctx['db_path']+table), siguiendo el patrón de chapters/modelos.py. Degrada a None cuando no hay categóricas; emite los bloques del modelo (DataTable, Markdown, Figure) para que el paginador del núcleo no corte nada en PDF ni PPTX. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../automatic_eda/chapters/agregacion.py | 592 ++++++++++++++++++ .../automatic_eda/chapters/agregacion_test.py | 256 ++++++++ 2 files changed, 848 insertions(+) create mode 100644 python/functions/datascience/automatic_eda/chapters/agregacion.py create mode 100644 python/functions/datascience/automatic_eda/chapters/agregacion_test.py diff --git a/python/functions/datascience/automatic_eda/chapters/agregacion.py b/python/functions/datascience/automatic_eda/chapters/agregacion.py new file mode 100644 index 00000000..7b5e03e6 --- /dev/null +++ b/python/functions/datascience/automatic_eda/chapters/agregacion.py @@ -0,0 +1,592 @@ +"""Aggregation chapter (AGREGACION) — group analysis / OLAP of the EDA. + +This chapter is the group-by / pivot ("OLAP") section of an AutomaticEDA report +and is meant to be present **whenever the dataset has at least one low-cardinality +categorical column to group by**. For the most interesting categoricals (chosen +by their cardinality/relevance, optionally with an LLM) it renders, as blocks the +core paginator never cuts: + +1. **Per-group statistics** (split-apply-combine) — for each interesting + categorical key, the count of rows per group and, for each numeric measure, + its mean/median/std/min/max. One compact summary table (mean of every measure + per group) plus a per-measure detail table. +2. **Bar charts** — a vertical bar chart of a measure's mean per group, bars from + zero (Tufte Lie-Factor = 1). +3. 
**Pivot tables** — categorical A x categorical B -> aggregate of a measure, + limited to the top rows/cols so it fits a mobile page/slide, with a grouped + bar chart of the same pivot. + +The raw data needed to aggregate is **not** in the TableProfile, so — exactly +like ``modelos`` reads its cluster projection from ``ctx`` — this chapter gets +the aggregation results in one of two ways and degrades honestly when neither is +available: + +ctx keys this chapter consumes (all optional): + aggregations : dict — pre-computed results, used directly (offline / tests / + forward-compatible with a calculation phase). Shape:: + + {"groupby": [{"group_by": str, "measures": [str], "why": str, + "result": }], + "pivots": [{"index": str, "columns": str, "value": str, "agg": str, + "why": str, "result": }]} + + db_path, table : str — when ``aggregations`` is absent, the chapter selects + the interesting keys (``select_groupby_keys``), optionally asks an LLM + which to show (``suggest_aggregations_llm`` when ``run_agg_llm`` is True) + and computes the group-by/pivot results live via the push-down registry + functions ``groupby_stats_duckdb`` / ``pivot_table_duckdb``. + run_agg_llm : bool — when True (and ``db_path``/``table`` present), let the + LLM pick the interesting aggregations; otherwise the deterministic + quantitative selection is used. + agg_llm_model : str — model id for the optional LLM selection. + agg_max_keys, agg_max_card, agg_max_measures, agg_top_n : int — limits. + agg_insights : list — optional pre-computed micro-analysis entries + (``[{"title": str, "text": str}]``) rendered as an interpretation section. + +Contract: build_(profile, ctx) -> Chapter | None ; CHAPTER_VERSION = "x.y.z". +Reads everything defensively (``.get``) and never raises: anything missing +degrades to a note instead of aborting the chapter; the chapter returns ``None`` +only when the dataset has no categorical column to group by. +""" + +from __future__ import annotations + +from .. 
import model + +# Pure/impure registry functions (group ``eda``) this chapter composes. Imported +# defensively so the chapter still builds (degrading the affected part to a note) +# if a function is somehow unavailable / not indexed yet. +try: + from datascience.select_groupby_keys import select_groupby_keys +except Exception: # noqa: BLE001 — keep the chapter importable no matter what. + select_groupby_keys = None # type: ignore[assignment] +try: + from datascience.groupby_stats_duckdb import groupby_stats_duckdb +except Exception: # noqa: BLE001 + groupby_stats_duckdb = None # type: ignore[assignment] +try: + from datascience.pivot_table_duckdb import pivot_table_duckdb +except Exception: # noqa: BLE001 + pivot_table_duckdb = None # type: ignore[assignment] +try: + from datascience.suggest_aggregations_llm import suggest_aggregations_llm +except Exception: # noqa: BLE001 + suggest_aggregations_llm = None # type: ignore[assignment] + +CHAPTER_VERSION = "1.0.0" +CHAPTER_ID = "agregacion" +CHAPTER_TITLE = "Agregación por grupos" + +# Tableau-10 palette — stable colours for the pivot's grouped-bar series. +_SERIES_COLORS = [ + "#4e79a7", "#f28e2b", "#e15759", "#76b7b2", "#59a14f", + "#edc948", "#b07aa1", "#ff9da7", "#9c755f", "#bab0ac", +] + +# Defaults for the live selection/aggregation (overridable via ctx). +_DEF_MAX_KEYS = 3 +_DEF_MAX_CARD = 20 +_DEF_MAX_MEASURES = 4 +_DEF_TOP_N = 12 + + +# --------------------------------------------------------------------------- # +# Formatting helpers (mirror the other chapters' defensive style). 
+# --------------------------------------------------------------------------- # +def _fmt_num(value, decimals: int = 3) -> str: + if value is None: + return "—" + if isinstance(value, bool): + return "sí" if value else "no" + if isinstance(value, int): + return f"{value:,}".replace(",", ".") + if isinstance(value, float): + if value != value: # NaN + return "NaN" + if value in (float("inf"), float("-inf")): + return str(value) + text = f"{value:.{decimals}f}".rstrip("0").rstrip(".") + return text if text else "0" + return model._safe_str(value) + + +def _is_dict(v) -> bool: + return isinstance(v, dict) + + +def _measure_mean(group: dict, measure: str): + """Pull the mean of one measure out of a groupby-result group entry.""" + stats = group.get("stats") if _is_dict(group.get("stats")) else {} + ms = stats.get(measure) if _is_dict(stats.get(measure)) else {} + return ms.get("mean") + + +# --------------------------------------------------------------------------- # +# Plan + data resolution. Either a pre-computed ctx['aggregations'] is used +# verbatim, or the plan is selected and the results are computed live. +# --------------------------------------------------------------------------- # +def _resolve_candidates(profile: dict, ctx: dict) -> dict: + """Return {group_keys, measures, pivots, note} of interesting columns.""" + pre = ctx.get("agg_candidates") + if _is_dict(pre) and pre.get("group_keys") is not None: + return pre + if select_groupby_keys is not None: + try: + out = select_groupby_keys( + profile, + max_keys=int(ctx.get("agg_max_keys", _DEF_MAX_KEYS)), + max_card=int(ctx.get("agg_max_card", _DEF_MAX_CARD)), + max_measures=int(ctx.get("agg_max_measures", _DEF_MAX_MEASURES)), + ) + if _is_dict(out): + return out + except Exception: # noqa: BLE001 — fall through to the inline fallback. 
+ pass + return _inline_candidates(profile, ctx) + + +def _inline_candidates(profile: dict, ctx: dict) -> dict: + """Minimal defensive selection when select_groupby_keys is unavailable.""" + max_card = int(ctx.get("agg_max_card", _DEF_MAX_CARD)) + max_keys = int(ctx.get("agg_max_keys", _DEF_MAX_KEYS)) + max_measures = int(ctx.get("agg_max_measures", _DEF_MAX_MEASURES)) + keys = profile.get("key_candidates") or [] + group_keys, measures = [], [] + for col in profile.get("columns") or []: + if not _is_dict(col): + continue + name = col.get("name") + it = col.get("inferred_type") + flags = col.get("flags") or [] + dc = col.get("distinct_count") + if it in ("categorical", "boolean") and name not in keys: + if ("possible_id" not in flags and "high_cardinality" not in flags + and "constant" not in flags + and isinstance(dc, int) and 2 <= dc <= max_card): + group_keys.append({"col": name, "cardinality": dc, "score": 0.0}) + elif it == "numeric": + num = col.get("numeric") or {} + if num.get("std") not in (None, 0) and not ( + "possible_id" in flags and (col.get("unique_pct") or 0) >= 0.99): + measures.append(name) + group_keys = group_keys[:max_keys] + measures = measures[:max_measures] + pivots = [] + if len(group_keys) >= 2: + pivots.append({"index": group_keys[0]["col"], + "columns": group_keys[1]["col"], + "value": measures[0] if measures else None}) + return {"group_keys": group_keys, "measures": measures, "pivots": pivots, + "note": "selección cuantitativa básica"} + + +def _resolve_plan(profile: dict, ctx: dict, candidates: dict) -> dict: + """Return {aggregations:[{group_by,measures,why}], pivots:[...], source}.""" + group_keys = candidates.get("group_keys") or [] + measures = candidates.get("measures") or [] + + if ctx.get("run_agg_llm") and suggest_aggregations_llm is not None: + try: + plan = suggest_aggregations_llm( + profile, candidates, + max_aggs=int(ctx.get("agg_max_keys", _DEF_MAX_KEYS)), + model=ctx.get("agg_llm_model", "claude-haiku-4-5-20251001")) + 
if _is_dict(plan) and plan.get("aggregations"): + return {"aggregations": plan.get("aggregations") or [], + "pivots": plan.get("pivots") or [], + "source": plan.get("source", "llm")} + except Exception: # noqa: BLE001 — fall back to the quantitative plan. + pass + + aggregations = [{ + "group_by": gk.get("col"), + "measures": measures, + "why": f"categórica de {_fmt_num(gk.get('cardinality'))} niveles", + } for gk in group_keys if _is_dict(gk) and gk.get("col")] + pivots = [] + for pv in candidates.get("pivots") or []: + if _is_dict(pv) and pv.get("index") and pv.get("columns"): + pivots.append({"index": pv.get("index"), "columns": pv.get("columns"), + "value": pv.get("value") or (measures[0] if measures else None), + "agg": "mean", "why": "cruce de dos categóricas"}) + return {"aggregations": aggregations, "pivots": pivots, "source": "quantitative"} + + +def _live_groupby(ctx: dict, group_by: str, measures: list, top_n: int): + """Compute one group-by result live via the push-down registry function.""" + db_path = ctx.get("db_path") + table = ctx.get("table") + if not db_path or not table or groupby_stats_duckdb is None: + return None + try: + out = groupby_stats_duckdb(db_path, table, group_by, list(measures or []), + top_n=top_n) + if _is_dict(out) and out.get("status") == "ok": + return out + except Exception: # noqa: BLE001 + return None + return None + + +def _live_pivot(ctx: dict, index: str, columns: str, value, agg: str): + """Compute one pivot live via the push-down registry function.""" + db_path = ctx.get("db_path") + table = ctx.get("table") + if not db_path or not table or pivot_table_duckdb is None or not value: + return None + try: + out = pivot_table_duckdb(db_path, table, index, columns, value, + agg=agg or "mean") + if _is_dict(out) and out.get("status") == "ok": + return out + except Exception: # noqa: BLE001 + return None + return None + + +# --------------------------------------------------------------------------- # +# Figure builders (lazy: 
matplotlib only imported when the renderer draws them). +# --------------------------------------------------------------------------- # +def _make_group_bars(group_by: str, measure: str, groups: list): + """Vertical bars: mean of ``measure`` per group, bars from zero.""" + labels, values = [], [] + for g in groups: + if not _is_dict(g): + continue + mean = _measure_mean(g, measure) + if mean is None: + continue + labels.append(model._safe_str(g.get("key"))) + values.append(float(mean)) + if not labels: + return None + + def _draw(): + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(6.6, 3.6)) + xs = list(range(len(labels))) + ax.bar(xs, values, color="#4e79a7", alpha=0.9, edgecolor="#2f4d6e", + linewidth=0.4) + ax.set_xticks(xs) + short = [(s[:18] + "…") if len(s) > 19 else s for s in labels] + rot = 30 if max((len(s) for s in short), default=0) > 6 else 0 + ax.set_xticklabels(short, rotation=rot, ha="right" if rot else "center", + fontsize=7) + ax.set_ylabel(f"media de {measure}", fontsize=8) + ax.set_xlabel(group_by, fontsize=8) + ax.set_title(f"Media de «{measure}» por «{group_by}»", fontsize=10) + ax.grid(axis="y", color="#dddddd", linewidth=0.6) + for spine in ("top", "right"): + ax.spines[spine].set_visible(False) + # Value labels above each bar. 
+ vmax = max(values) if values else 0 + for x, v in zip(xs, values): + ax.text(x, v + (abs(vmax) * 0.01 if vmax else 0.01), + _fmt_num(v, 2), ha="center", va="bottom", fontsize=6.5) + fig.tight_layout() + return fig + + return _draw + + +def _make_pivot_bars(pivot: dict): + """Grouped bars of a pivot: x = row_labels, one series per col_label.""" + row_labels = pivot.get("row_labels") or [] + col_labels = pivot.get("col_labels") or [] + matrix = pivot.get("matrix") or [] + if not row_labels or not col_labels or not matrix: + return None + + def _draw(): + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + n_rows = len(row_labels) + n_cols = len(col_labels) + fig, ax = plt.subplots(figsize=(6.8, 3.8)) + total_w = 0.8 + bar_w = total_w / max(n_cols, 1) + base = list(range(n_rows)) + for j, clabel in enumerate(col_labels): + offs = [b - total_w / 2 + bar_w * (j + 0.5) for b in base] + vals = [] + for i in range(n_rows): + cell = matrix[i][j] if (i < len(matrix) and j < len(matrix[i])) else None + vals.append(float(cell) if isinstance(cell, (int, float)) else 0.0) + color = _SERIES_COLORS[j % len(_SERIES_COLORS)] + ax.bar(offs, vals, width=bar_w, color=color, alpha=0.9, + label=model._safe_str(clabel)) + ax.set_xticks(base) + short = [(s[:16] + "…") if len(s) > 17 else s + for s in (model._safe_str(r) for r in row_labels)] + rot = 30 if max((len(s) for s in short), default=0) > 6 else 0 + ax.set_xticklabels(short, rotation=rot, ha="right" if rot else "center", + fontsize=7) + ax.set_xlabel(model._safe_str(pivot.get("index")), fontsize=8) + ax.set_ylabel(f"{pivot.get('agg','mean')} de {pivot.get('value')}", + fontsize=8) + ax.set_title(f"{pivot.get('index')} × {pivot.get('columns')}", fontsize=10) + ax.grid(axis="y", color="#dddddd", linewidth=0.6) + ax.legend(title=model._safe_str(pivot.get("columns")), fontsize=6.5, + title_fontsize=7, frameon=True, framealpha=0.9, loc="best") + for spine in ("top", "right"): + 
ax.spines[spine].set_visible(False) + fig.tight_layout() + return fig + + return _draw + + +def _group_bars_maker(group_by: str, measure: str, groups: list): + """Bind per-aggregation args so the lazy closure is loop-safe.""" + def _make(): + return _make_group_bars(group_by, measure, groups)() + return _make + + +def _pivot_bars_maker(pivot: dict): + def _make(): + return _make_pivot_bars(pivot)() + return _make + + +# --------------------------------------------------------------------------- # +# Section builders. Each returns a list of blocks (possibly empty). +# --------------------------------------------------------------------------- # +def _groupby_section(group_by: str, measures: list, result: dict, why: str) -> list: + """Build the blocks for one group-by aggregation, or [] if unusable.""" + if not _is_dict(result) or not result.get("groups"): + return [] + groups = [g for g in result.get("groups") or [] if _is_dict(g)] + if not groups: + return [] + eff_measures = result.get("measures") or measures or [] + + blocks = [model.Heading(text=f"Agrupado por «{group_by}»", level=2)] + intro = f"**{why}.** " if why else "" + intro += (f"{_fmt_num(result.get('n_groups') or len(groups))} grupos" + f"{' (top por tamaño)' if result.get('truncated') else ''}.") + blocks.append(model.Markdown(text=intro)) + + # Summary table: one row per group, count + mean of every measure. + header = ["Grupo", "n"] + [f"{m} (media)" for m in eff_measures] + rows = [] + for g in groups: + row = [model._safe_str(g.get("key")), _fmt_num(g.get("n"))] + for m in eff_measures: + row.append(_fmt_num(_measure_mean(g, m), 2)) + rows.append(row) + blocks.append(model.DataTable( + header=header, rows=rows, title=f"Resumen por «{group_by}»", + note="Conteo de filas y media de cada medida por grupo.")) + + if not eff_measures: + return blocks + + # Primary measure: a bar chart + a detail table (mean/median/std/min/max). 
+ primary = eff_measures[0] + bars = _make_group_bars(group_by, primary, groups) + if bars is not None: + blocks.append(model.Figure( + make=_group_bars_maker(group_by, primary, groups), + caption=f"Media de «{primary}» por «{group_by}» (barras desde cero).")) + + det_header = ["Grupo", "n", "media", "mediana", "σ", "mín", "máx"] + det_rows = [] + for g in groups: + stats = g.get("stats") if _is_dict(g.get("stats")) else {} + ms = stats.get(primary) if _is_dict(stats.get(primary)) else {} + det_rows.append([ + model._safe_str(g.get("key")), _fmt_num(g.get("n")), + _fmt_num(ms.get("mean"), 2), _fmt_num(ms.get("median"), 2), + _fmt_num(ms.get("std"), 2), _fmt_num(ms.get("min"), 2), + _fmt_num(ms.get("max"), 2), + ]) + blocks.append(model.DataTable( + header=det_header, rows=det_rows, + title=f"Detalle de «{primary}» por «{group_by}»")) + return blocks + + +def _pivot_section(pivot_spec: dict, result: dict) -> list: + """Build the blocks for one pivot table, or [] if unusable.""" + if not _is_dict(result) or not result.get("row_labels"): + return [] + row_labels = result.get("row_labels") or [] + col_labels = result.get("col_labels") or [] + matrix = result.get("matrix") or [] + if not row_labels or not col_labels or not matrix: + return [] + + index = result.get("index") or pivot_spec.get("index") + columns = result.get("columns") or pivot_spec.get("columns") + value = result.get("value") or pivot_spec.get("value") + agg = result.get("agg") or pivot_spec.get("agg") or "mean" + why = pivot_spec.get("why") or "" + + blocks = [model.Heading(text=f"Pivot: «{index}» × «{columns}»", level=2)] + intro = f"**{why}.** " if why else "" + intro += (f"{agg} de «{value}» cruzando «{index}» (filas) y «{columns}» " + f"(columnas).") + if result.get("truncated_rows") or result.get("truncated_cols"): + intro += " Limitado a las filas/columnas más frecuentes." 
+ blocks.append(model.Markdown(text=intro)) + + header = [model._safe_str(index)] + [model._safe_str(c) for c in col_labels] + rows = [] + for i, rlabel in enumerate(row_labels): + row = [model._safe_str(rlabel)] + cells = matrix[i] if i < len(matrix) else [] + for j in range(len(col_labels)): + cell = cells[j] if j < len(cells) else None + row.append(_fmt_num(cell, 2)) + rows.append(row) + blocks.append(model.DataTable( + header=header, rows=rows, + title=f"{agg} de «{value}»", + note=f"Cada celda es {agg} de «{value}» para esa combinación.")) + + fig_pivot = {"row_labels": row_labels, "col_labels": col_labels, + "matrix": matrix, "index": index, "columns": columns, + "value": value, "agg": agg} + if _make_pivot_bars(fig_pivot) is not None: + blocks.append(model.Figure( + make=_pivot_bars_maker(fig_pivot), + caption=f"{agg} de «{value}» por «{index}» y «{columns}» " + f"(barras agrupadas).")) + return blocks + + +def _insights_section(ctx: dict) -> list: + """Optional pre-computed micro-analysis of the aggregations (SHOULD-11.4).""" + entries = ctx.get("agg_insights") + if not isinstance(entries, list) or not entries: + return [] + blocks = [model.Heading(text="Interpretación de los grupos", level=2)] + for e in entries: + if not _is_dict(e): + continue + title = model._safe_str(e.get("title")) + text = model._safe_str(e.get("text")) + line = (f"**{title}.** " if title else "") + text + if line.strip(): + blocks.append(model.Markdown(text=line)) + return blocks if len(blocks) > 1 else [] + + +# --------------------------------------------------------------------------- # +# Pre-computed path: ctx['aggregations'] already carries the results. 
+# --------------------------------------------------------------------------- # +def _sections_from_precomputed(agg: dict) -> list: + sections = [] + for entry in agg.get("groupby") or []: + if not _is_dict(entry): + continue + sections += _groupby_section( + entry.get("group_by"), entry.get("measures") or [], + entry.get("result") or {}, entry.get("why") or "") + for entry in agg.get("pivots") or []: + if not _is_dict(entry): + continue + sections += _pivot_section(entry, entry.get("result") or {}) + return sections + + +# --------------------------------------------------------------------------- # +# Live path: select keys, pick a plan, compute results via push-down functions. +# --------------------------------------------------------------------------- # +def _sections_live(profile: dict, ctx: dict, candidates: dict) -> list: + top_n = int(ctx.get("agg_top_n", _DEF_TOP_N)) + plan = _resolve_plan(profile, ctx, candidates) + sections = [] + for agg in plan.get("aggregations") or []: + if not _is_dict(agg) or not agg.get("group_by"): + continue + result = _live_groupby(ctx, agg.get("group_by"), + agg.get("measures") or [], top_n) + if result is not None: + sections += _groupby_section(agg.get("group_by"), + agg.get("measures") or [], result, + agg.get("why") or "") + for pv in plan.get("pivots") or []: + if not _is_dict(pv) or not pv.get("index") or not pv.get("columns"): + continue + result = _live_pivot(ctx, pv.get("index"), pv.get("columns"), + pv.get("value"), pv.get("agg") or "mean") + if result is not None: + sections += _pivot_section(pv, result) + return sections + + +# --------------------------------------------------------------------------- # +# Entry point. 
+# --------------------------------------------------------------------------- # +def _intro_blocks() -> list: + text = ( + "Este capítulo analiza la tabla **por grupos** (split-apply-combine): " + "elige las columnas categóricas más informativas — por su cardinalidad " + "y relevancia, no todas contra todas, para no inflar comparaciones " + "espurias — y resume las variables numéricas dentro de cada grupo " + "(conteo, media, mediana, desviación). Las **tablas dinámicas** (pivot) " + "cruzan dos categóricas sobre una medida, y los **gráficos de barras** " + "(siempre desde cero) comparan los grupos de un vistazo." + ) + return [model.Heading(text=CHAPTER_TITLE, level=1), + model.Markdown(text=text)] + + +def build_agregacion(profile: dict, ctx: dict): + """Build the AGREGACION Chapter, or None if the dataset can't be grouped. + + Args: + profile: the ``eda`` group TableProfile dict. + ctx: presentation context (see module docstring for the keys consumed). + + Returns: + A ``model.Chapter`` with per-group stats, pivots and bar charts; or + ``None`` when the dataset has no low-cardinality categorical column to + group by (the chapter does not apply). + """ + profile = profile or {} + ctx = ctx or {} + if not isinstance(profile, dict): + return None + + # Pre-computed results take precedence (offline / tests / forward-compat). + pre = ctx.get("aggregations") + if _is_dict(pre) and (pre.get("groupby") or pre.get("pivots")): + sections = _sections_from_precomputed(pre) + if not sections: + return None + blocks = _intro_blocks() + sections + _insights_section(ctx) + return model.Chapter(id=CHAPTER_ID, title=CHAPTER_TITLE, + version=CHAPTER_VERSION, blocks=blocks) + + # Live path: needs at least one categorical key to group by. + candidates = _resolve_candidates(profile, ctx) + if not _is_dict(candidates) or not (candidates.get("group_keys")): + return None # chapter does not apply: nothing to group by. 
+ + sections = _sections_live(profile, ctx, candidates) + if not sections: + # Applies (there are categorical keys) but no aggregation data is + # reachable: emit an honest note instead of fabricating numbers. + keys = ", ".join(model._safe_str((k or {}).get("col")) + for k in candidates.get("group_keys") or [] + if _is_dict(k)) + note = model.Note( + "No se pudo calcular la agregación: el capítulo necesita los datos " + "crudos. Pasa ctx['db_path'] + ctx['table'] (para el cálculo " + "push-down en DuckDB) o ctx['aggregations'] ya precalculado. " + f"Columnas categóricas candidatas: {keys or '—'}.") + blocks = _intro_blocks() + [note] + _insights_section(ctx) + return model.Chapter(id=CHAPTER_ID, title=CHAPTER_TITLE, + version=CHAPTER_VERSION, blocks=blocks) + + blocks = _intro_blocks() + sections + _insights_section(ctx) + return model.Chapter(id=CHAPTER_ID, title=CHAPTER_TITLE, + version=CHAPTER_VERSION, blocks=blocks) diff --git a/python/functions/datascience/automatic_eda/chapters/agregacion_test.py b/python/functions/datascience/automatic_eda/chapters/agregacion_test.py new file mode 100644 index 00000000..e35005be --- /dev/null +++ b/python/functions/datascience/automatic_eda/chapters/agregacion_test.py @@ -0,0 +1,256 @@ +"""Tests for the AGREGACION chapter — DoD: golden + edges + error/no-cut path. + +Self-contained and deterministic: no DuckDB and no LLM. The aggregation results +are passed pre-computed via ``ctx['aggregations']`` (the same shape the push-down +registry functions ``groupby_stats_duckdb`` / ``pivot_table_duckdb`` produce), so +the chapter's rendering logic is exercised without touching disk or the network. +Live push-down + LLM selection are covered separately by the golden script. + +Verifies: +- Golden: a profile with categoricals + numerics builds a Chapter with per-group + stats tables, a pivot table and bar-chart figures, and it renders to PDF AND + PPTX showing the group keys, values and pivot — nothing cut. 
+- Edges: a dataset with no low-cardinality categorical returns None; an empty + profile returns None; a profile that *could* be grouped but has no reachable + data degrades to an honest note instead of raising. +- No-cut: many groups (30) + a long interpretation paragraph survive intact in + the rendered PDF (table split by rows, text wrapped whole). +""" + +import os +import re +import tempfile + +from pptx import Presentation +from pypdf import PdfReader + +from datascience.automatic_eda.chapters.agregacion import build_agregacion +from datascience.automatic_eda.model import Chapter +from datascience.render_automatic_eda_pdf import render_automatic_eda_pdf +from datascience.render_automatic_eda_pptx import render_automatic_eda_pptx + + +# --------------------------------------------------------------------------- # +# Synthetic fixtures. +# --------------------------------------------------------------------------- # +def _profile() -> dict: + """A titanic-like profile: 2 categoricals + 2 numeric measures + 1 id.""" + return { + "table": "titanic", + "source": "/data/titanic.csv", + "n_rows": 891, + "n_cols": 5, + "key_candidates": ["passenger_id"], + "columns": [ + {"name": "passenger_id", "inferred_type": "numeric", + "unique_pct": 1.0, "flags": ["possible_id"], + "numeric": {"mean": 446.0, "std": 257.0}}, + {"name": "sex", "inferred_type": "categorical", "distinct_count": 2, + "flags": [], "categorical": {"n_distinct": 2, "imbalance": 0.1, + "top": [{"value": "male", "count": 577}]}}, + {"name": "pclass", "inferred_type": "categorical", "distinct_count": 3, + "flags": [], "categorical": {"n_distinct": 3, "imbalance": 0.2}}, + {"name": "fare", "inferred_type": "numeric", "flags": [], + "numeric": {"mean": 32.2, "std": 49.7, "cv": 1.54}}, + {"name": "age", "inferred_type": "numeric", "flags": [], + "numeric": {"mean": 29.7, "std": 14.5, "cv": 0.49}}, + ], + } + + +def _groupby_result(group_by: str, keys_n: list) -> dict: + """A groupby_stats_duckdb-shaped 
result for `fare` and `age`.""" + groups = [] + for i, (key, n) in enumerate(keys_n): + groups.append({ + "key": key, "n": n, + "stats": { + "fare": {"mean": 20.0 + i * 15, "median": 10.0 + i * 8, + "std": 40.0 + i, "min": 0.0, "max": 512.3}, + "age": {"mean": 28.0 + i, "median": 27.0 + i, "std": 14.0, + "min": 0.42, "max": 80.0}, + }, + }) + return {"status": "ok", "group_by": group_by, "measures": ["fare", "age"], + "aggs": ["count", "mean", "median", "std", "min", "max"], + "n_groups": len(groups), "truncated": False, "groups": groups} + + +def _pivot_result() -> dict: + return {"status": "ok", "index": "sex", "columns": "pclass", "value": "fare", + "agg": "mean", "row_labels": ["male", "female"], + "col_labels": ["1", "2", "3"], + "matrix": [[62.0, 19.0, 12.0], [110.0, 22.0, 15.0]], + "truncated_rows": False, "truncated_cols": False} + + +def _ctx_precomputed() -> dict: + return { + "aggregations": { + "groupby": [ + {"group_by": "sex", "measures": ["fare", "age"], + "why": "sexo del pasajero", + "result": _groupby_result("sex", [("male", 577), ("female", 314)])}, + {"group_by": "pclass", "measures": ["fare", "age"], + "why": "clase del billete", + "result": _groupby_result( + "pclass", [("3", 491), ("1", 216), ("2", 184)])}, + ], + "pivots": [ + {"index": "sex", "columns": "pclass", "value": "fare", + "agg": "mean", "why": "tarifa por sexo y clase", + "result": _pivot_result()}, + ], + }, + "agg_insights": [ + {"title": "Tarifa por sexo", + "text": "Las mujeres pagaron de media casi el doble que los hombres."}, + ], + } + + +def _pdf_text(path: str) -> str: + txt = "".join((pg.extract_text() or "") for pg in PdfReader(path).pages) + return re.sub(r"\s+", " ", txt) + + +def _pptx_text(path: str) -> str: + prs = Presentation(path) + parts = [] + for sl in prs.slides: + for sh in sl.shapes: + if sh.has_text_frame: + parts.append(sh.text_frame.text) + if sh.has_table: + tb = sh.table + for r in range(len(tb.rows)): + for c in range(len(tb.columns)): + 
parts.append(tb.cell(r, c).text) + return re.sub(r"\s+", " ", " ".join(parts)) + + +# --------------------------------------------------------------------------- # +# Golden: builds a Chapter and renders to both formats. +# --------------------------------------------------------------------------- # +def test_golden_chapter_blocks_present(): + ch = build_agregacion(_profile(), _ctx_precomputed()) + assert isinstance(ch, Chapter) + assert ch.id == "agregacion" + kinds = [b.kind for b in ch.blocks] + assert "heading" in kinds + assert kinds.count("data_table") >= 3 # 2 group summaries + pivot (+details) + assert "figure" in kinds # at least one bar chart. + # Headings mention the group keys and the pivot. + htext = " ".join(b.text for b in ch.blocks if b.kind == "heading") + assert "sex" in htext and "pclass" in htext and "Pivot" in htext + + +def test_golden_render_pdf(): + ch = build_agregacion(_profile(), _ctx_precomputed()) + with tempfile.TemporaryDirectory() as d: + out = os.path.join(d, "agg.pdf") + res = render_automatic_eda_pdf([ch], out, {"write_manifest": False}) + assert res["path"] == out and os.path.exists(out) + assert res["n_pages"] >= 1 + txt = _pdf_text(out) + assert "Agregación por grupos" in txt + assert "male" in txt and "female" in txt # group + pivot labels. + assert "Pivot" in txt + assert "mediana" in txt # per-measure detail. + assert "casi el doble" in txt # interpretation kept. + + +def test_golden_render_pptx(): + ch = build_agregacion(_profile(), _ctx_precomputed()) + with tempfile.TemporaryDirectory() as d: + out = os.path.join(d, "agg.pptx") + res = render_automatic_eda_pptx([ch], out, {"write_manifest": False}) + assert res["path"] == out and os.path.exists(out) + assert res["n_slides"] >= 1 + txt = _pptx_text(out) + assert "male" in txt and "pclass" in txt + assert "Pivot" in txt or "sex" in txt + + +# --------------------------------------------------------------------------- # +# Edges. 
+# --------------------------------------------------------------------------- # +def test_edge_no_categorical_returns_none(): + # Only numerics + an id: nothing to group by -> chapter does not apply. + prof = { + "table": "t", "n_rows": 100, "key_candidates": ["id"], + "columns": [ + {"name": "id", "inferred_type": "numeric", "unique_pct": 1.0, + "flags": ["possible_id"], "numeric": {"std": 10.0}}, + {"name": "x", "inferred_type": "numeric", "flags": [], + "numeric": {"mean": 1.0, "std": 2.0}}, + ], + } + assert build_agregacion(prof, {}) is None + + +def test_edge_empty_profile_returns_none(): + assert build_agregacion({}, {}) is None + assert build_agregacion(None, None) is None + + +def test_edge_high_cardinality_only_returns_none(): + # The single categorical is id-like (high cardinality) -> not groupable. + prof = { + "table": "t", "n_rows": 100, "key_candidates": ["uuid"], + "columns": [ + {"name": "uuid", "inferred_type": "categorical", "distinct_count": 100, + "flags": ["high_cardinality", "possible_id"]}, + {"name": "x", "inferred_type": "numeric", "flags": [], + "numeric": {"mean": 1.0, "std": 2.0}}, + ], + } + assert build_agregacion(prof, {}) is None + + +def test_live_without_data_degrades_to_note(): + # Has a categorical to group by but no db_path / no precomputed results: + # must NOT raise and must emit an honest note (chapter still applies). 
+ prof = { + "table": "t", "n_rows": 100, "key_candidates": [], + "columns": [ + {"name": "grp", "inferred_type": "categorical", "distinct_count": 3, + "flags": [], "categorical": {"n_distinct": 3}}, + {"name": "v", "inferred_type": "numeric", "flags": [], + "numeric": {"mean": 1.0, "std": 2.0}}, + ], + } + ch = build_agregacion(prof, {}) + assert isinstance(ch, Chapter) + notes = [b.text for b in ch.blocks if b.kind == "note"] + assert any("datos crudos" in n for n in notes) + + +# --------------------------------------------------------------------------- # +# No-cut: many groups + long text survive intact in the PDF. +# --------------------------------------------------------------------------- # +def test_anti_corte_muchos_grupos_y_texto_largo(): + keys_n = [(f"grupo_{i:02d}", 30 - (i % 5)) for i in range(30)] + long_text = " ".join(f"palabra{i}" for i in range(120)) + ctx = { + "aggregations": { + "groupby": [ + {"group_by": "cat", "measures": ["fare"], "why": "muchos niveles", + "result": _groupby_result("cat", keys_n)}, + ], + "pivots": [], + }, + "agg_insights": [{"title": "Nota larga", "text": long_text}], + } + ch = build_agregacion(_profile(), ctx) + with tempfile.TemporaryDirectory() as d: + out = os.path.join(d, "big.pdf") + res = render_automatic_eda_pdf([ch], out, {"write_manifest": False}) + assert res["path"] == out + assert res["n_pages"] > 1 # 30-row table + figure spill across pages. + txt = _pdf_text(out) + # First and last group labels both survive (table not truncated). + assert "grupo_00" in txt and "grupo_29" in txt + # First, middle and last words of the long paragraph all present. + for i in (0, 60, 119): + assert f"palabra{i}" in txt