From 96da9e3015c3eee89a9ea5fd2dd613d28009e5b2 Mon Sep 17 00:00:00 2001 From: Egutierrez Date: Tue, 30 Jun 2026 15:33:55 +0200 Subject: [PATCH] =?UTF-8?q?feat(eda):=20funciones=20de=20agregaci=C3=B3n/O?= =?UTF-8?q?LAP=20para=20AutomaticEDA=20(groupby/pivot=20push-down=20+=20se?= =?UTF-8?q?lecci=C3=B3n=20LLM)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION: - select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile. - groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo). - pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar. - suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista. Co-Authored-By: Claude Opus 4.8 (1M context) --- python/functions/datascience/__init__.py | 6 + .../datascience/groupby_stats_duckdb.md | 87 ++++ .../datascience/groupby_stats_duckdb.py | 184 ++++++++ .../datascience/groupby_stats_duckdb_test.py | 106 +++++ .../datascience/pivot_table_duckdb.md | 92 ++++ .../datascience/pivot_table_duckdb.py | 176 ++++++++ .../datascience/pivot_table_duckdb_test.py | 115 +++++ .../datascience/select_groupby_keys.md | 158 +++++++ .../datascience/select_groupby_keys.py | 310 ++++++++++++++ .../datascience/select_groupby_keys_test.py | 213 +++++++++ .../datascience/suggest_aggregations_llm.md | 96 +++++ .../datascience/suggest_aggregations_llm.py | 405 ++++++++++++++++++ .../suggest_aggregations_llm_test.py | 198 +++++++++ 13 files changed, 2146 insertions(+) create mode 100644 python/functions/datascience/groupby_stats_duckdb.md create mode 100644 python/functions/datascience/groupby_stats_duckdb.py create mode 100644 python/functions/datascience/groupby_stats_duckdb_test.py create mode 100644 python/functions/datascience/pivot_table_duckdb.md create mode 100644 python/functions/datascience/pivot_table_duckdb.py create mode 100644 python/functions/datascience/pivot_table_duckdb_test.py create mode 100644 python/functions/datascience/select_groupby_keys.md create mode 100644 python/functions/datascience/select_groupby_keys.py create mode 100644 python/functions/datascience/select_groupby_keys_test.py create mode 100644 python/functions/datascience/suggest_aggregations_llm.md create mode 100644 python/functions/datascience/suggest_aggregations_llm.py create mode 100644 python/functions/datascience/suggest_aggregations_llm_test.py diff --git a/python/functions/datascience/__init__.py b/python/functions/datascience/__init__.py index 9fc8c206..f2746d11 100644 --- a/python/functions/datascience/__init__.py +++ b/python/functions/datascience/__init__.py @@ -25,6 +25,7 @@ from .describe_numeric import describe_numeric from .summarize_categorical import summarize_categorical from .infer_semantic_type import infer_semantic_type from .column_quality_score import column_quality_score +from .select_groupby_keys import select_groupby_keys from .render_eda_markdown import render_eda_markdown from .detect_distribution_type import detect_distribution_type from .spearman_corr import spearman_corr @@ -36,6 +37,8 @@ from .infer_fk_containment_duckdb import infer_fk_containment_duckdb from .build_join_graph import build_join_graph from .association_matrix import association_matrix from .correlation_matrix_duckdb import correlation_matrix_duckdb +from .pivot_table_duckdb import pivot_table_duckdb +from .groupby_stats_duckdb import groupby_stats_duckdb from .pca_explained import pca_explained from .kmeans_segments import kmeans_segments from .isolation_forest_outliers import isolation_forest_outliers @@ -82,6 +85,8 @@ __all__ = [ "build_join_graph", "association_matrix", "correlation_matrix_duckdb", + "pivot_table_duckdb", + "groupby_stats_duckdb", "pca_explained", "kmeans_segments", "isolation_forest_outliers", @@ -96,6 +101,7 @@ __all__ = [ "summarize_categorical", "infer_semantic_type", "column_quality_score", + "select_groupby_keys", "render_eda_markdown", "detect_distribution_type", "pull_gsc_search_analytics", diff --git a/python/functions/datascience/groupby_stats_duckdb.md b/python/functions/datascience/groupby_stats_duckdb.md new file mode 100644 index 00000000..faf22222 --- /dev/null +++ b/python/functions/datascience/groupby_stats_duckdb.md @@ -0,0 +1,87 @@ +--- +name: groupby_stats_duckdb +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def groupby_stats_duckdb(db_path: str, table: str, group_by: str, measures: list, aggs: list = None, top_n: int = 15) -> dict" +description: "Agregaciones GROUP BY con push-down SQL en DuckDB: para cada measure numerica calcula mean/median/std/min/max por grupo (split-apply-combine en el motor), trayendo solo una fila por grupo. Nucleo de un capitulo de agregacion/OLAP de un EDA. count = tamanio del grupo, independiente de measures." +tags: [eda, groupby, aggregation, olap, duckdb, datascience, push-down, split-apply-combine] +uses_functions: [duckdb_query_readonly_py_infra] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: db_path + desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base. Path inexistente -> {status:'error'} sin lanzar." + - name: table + desc: "Nombre de la tabla. Se interpola citado con dobles comillas (soporta nombres con espacios; las comillas internas se escapan)." + - name: group_by + desc: "Columna por la que agrupar. Se interpola citada. Sus valores distintos son las claves de los grupos." + - name: measures + desc: "Lista de columnas numericas a agregar. Lista vacia es valida: cada grupo trae solo su tamanio `n` y `stats` vacio." + - name: aggs + desc: "Lista de agregaciones. None (default) = ['count','mean','median','std','min','max']. Validas: count (tamanio del grupo, va a `n`), mean->avg, median, std->stddev_samp, min, max (estas cinco por measure). Agg desconocido -> error." + - name: top_n + desc: "Maximo de grupos a devolver, ordenados por tamanio de grupo descendente (default 15). Internamente se piden top_n+1 para detectar truncado." +output: "dict. En exito {status:'ok', group_by, measures:[...], aggs:[...], n_groups:int, truncated:bool, groups:[{key:, n:int, stats:{:{mean,median,std,min,max}}}], note:str}. Las estadisticas son float o None (p.ej. std de un grupo de 1 fila -> NULL -> None). En error {status:'error', error:str} (no lanza)." +tested: true +tests: ["agrega por grupo con valores conocidos", "db inexistente devuelve error sin lanzar", "measures vacias agrega solo count", "columna con espacio agrupa bien"] +test_file_path: "python/functions/datascience/groupby_stats_duckdb_test.py" +file_path: "python/functions/datascience/groupby_stats_duckdb.py" +--- + +## Ejemplo + +```python +import duckdb +from datascience import groupby_stats_duckdb + +# Cargar el titanic en una tabla DuckDB de prueba. +db = "/tmp/titanic.duckdb" +con = duckdb.connect(db) +con.execute( + "CREATE TABLE titanic AS " + "SELECT * FROM read_csv_auto('https://raw.githubusercontent.com/" + "datasciencedojo/datasets/master/titanic.csv')" +) +con.close() + +# Agrupar por sexo midiendo edad y tarifa. +res = groupby_stats_duckdb(db, "titanic", "Sex", ["Age", "Fare"]) +print(res["status"]) # ok +print(res["n_groups"]) # 2 (male, female) +for g in res["groups"]: + print(g["key"], g["n"], round(g["stats"]["Fare"]["mean"], 2)) +# female 314 44.48 +# male 577 25.52 +``` + +## Cuando usarla + +Cuando en un EDA necesitas el clasico split-apply-combine: "para cada categoria de X, +¿cuanto vale en media/mediana/desviacion/min/max la metrica Y?". Es el nucleo de un +capitulo de agregacion/OLAP. Usala antes de pintar barras o boxplots por grupo, para +detectar segmentos con comportamiento distinto, o para resumir una tabla grande sin +traer las filas a RAM: todo el GROUP BY ocurre push-down en el motor de DuckDB y solo +viaja una fila por grupo. `top_n` te deja quedarte con los grupos mas poblados. + +## Gotchas + +- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica). La + tabla debe existir ya en el `.db` (no carga CSV; para eso crea la tabla antes). +- Identificadores (tabla, group_by, measures) se interpolan citados con dobles comillas + y escapando las internas: soporta nombres con espacios y evita inyeccion. No pases + expresiones SQL como group_by/measure — solo nombres de columna. +- `count` es el tamanio del grupo (`COUNT(*)`), independiente de las measures: se + refleja en el campo `n` de cada grupo, NO como clave dentro de `stats`. Las claves de + `stats[measure]` son las measure-aggs efectivas (mean/median/std/min/max menos count). +- `std` usa `stddev_samp` (muestral, n-1): un grupo con una sola fila da `NULL` -> `None`. + Las measures pueden contener NULLs; cada agregada los ignora segun la semantica de DuckDB. +- `truncated:True` indica que habia mas grupos que `top_n` (se devolvieron los `top_n` + mayores por tamanio). Sube `top_n` si necesitas todos los grupos. +- Si `measures` esta vacio, cada grupo trae solo `n` y `stats == {}` (valido, util para + un simple conteo por categoria). diff --git a/python/functions/datascience/groupby_stats_duckdb.py b/python/functions/datascience/groupby_stats_duckdb.py new file mode 100644 index 00000000..9a2bfd1e --- /dev/null +++ b/python/functions/datascience/groupby_stats_duckdb.py @@ -0,0 +1,184 @@ +"""groupby_stats_duckdb — agregaciones GROUP BY con push-down SQL en DuckDB. + +Funcion impura: lee de disco a traves de DuckDB (via la primitiva read-only +`duckdb_query_readonly` del grupo `duckdb`). Pertenece al grupo de capacidad `eda`. + +Ejecuta un `GROUP BY ` en el motor de DuckDB (split-apply-combine con +push-down) calculando, para cada columna numerica de `measures`, las agregaciones +pedidas (mean/median/std/min/max). Solo trae al cliente una fila por grupo, nunca +las filas crudas: apto para tablas grandes. Es el nucleo de un capitulo de +agregacion/OLAP de un EDA. + +Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y +devuelve {status:'error', error:str}. +""" + +from infra import duckdb_query_readonly + +# Mapeo agg -> funcion agregada SQL de DuckDB. `count` se trata aparte: es +# COUNT(*) (tamanio del grupo), independiente de las measures. +_AGG_SQL = { + "mean": "avg", + "median": "median", + "std": "stddev_samp", + "min": "min", + "max": "max", +} + +# Aggs por defecto cuando aggs=None. count primero (tamanio del grupo) + las +# cinco estadisticas por measure. +_DEFAULT_AGGS = ["count", "mean", "median", "std", "min", "max"] + + +def _quote_ident(ident: str) -> str: + """Cita un identificador SQL con dobles comillas, escapando las internas. + + Soporta nombres con espacios o caracteres especiales y evita inyeccion: dentro + de un identificador entrecomillado el unico caracter peligroso es la propia + comilla doble, que se duplica ("") segun el estandar SQL. DuckDB no admite + parametros posicionales para nombres de tabla/columna, asi que esta es la via + segura de interpolarlos. + """ + return '"' + str(ident).replace('"', '""') + '"' + + +def groupby_stats_duckdb( + db_path: str, + table: str, + group_by: str, + measures: list, + aggs: list = None, + top_n: int = 15, +) -> dict: + """GROUP BY con agregaciones por measure, todo push-down en DuckDB. + + Args: + db_path: ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la + base. Un path inexistente devuelve {status:'error', ...} sin lanzar. + table: nombre de la tabla. Se interpola citado con dobles comillas (soporta + nombres con espacios). + group_by: columna por la que agrupar. Se interpola citada. + measures: lista de columnas numericas a agregar. Lista vacia es valida: + cada grupo trae solo su tamanio `n` y `stats` vacio. + aggs: lista de agregaciones a calcular. None (default) = + ["count", "mean", "median", "std", "min", "max"]. Valores validos: + count (tamanio del grupo, va a `n`), mean, median, std, min, max + (estas cinco se calculan por cada measure). Un agg desconocido devuelve + error. + top_n: numero maximo de grupos a devolver, ordenados por tamanio de grupo + descendente (default 15). Se pide top_n+1 internamente para detectar si + habia mas grupos y marcar `truncated`. + + Returns: + dict. En exito: + {status:'ok', + group_by:str, + measures:[...], + aggs:[...], # las efectivas (incluye count si se pidio) + n_groups:int, # nº de grupos devueltos (<= top_n) + truncated:bool, # True si habia mas de top_n grupos + groups:[{key:, n:int, + stats:{:{mean,median,std,min,max}}}, ...], + note:str} + Las estadisticas son float o None (p.ej. stddev_samp de un grupo de una + sola fila -> NULL -> None). En error (sin lanzar): {status:'error', error:str}. + """ + try: + # 1. Validar entradas. + if not isinstance(table, str) or table == "": + return {"status": "error", "error": "table must be a non-empty string"} + if not isinstance(group_by, str) or group_by == "": + return {"status": "error", "error": "group_by must be a non-empty string"} + + if measures is None: + measures = [] + if not isinstance(measures, list): + return {"status": "error", "error": "measures must be a list"} + for m in measures: + if not isinstance(m, str) or m == "": + return { + "status": "error", + "error": f"invalid measure identifier: {m!r}", + } + + if aggs is None: + aggs = list(_DEFAULT_AGGS) + if not isinstance(aggs, list) or len(aggs) == 0: + return { + "status": "error", + "error": "aggs must be a non-empty list or None", + } + for a in aggs: + if a != "count" and a not in _AGG_SQL: + return { + "status": "error", + "error": f"unknown agg {a!r}; valid: count, " + + ", ".join(_AGG_SQL), + } + + if not isinstance(top_n, int) or isinstance(top_n, bool) or top_n < 1: + return {"status": "error", "error": "top_n must be a positive int"} + + # 2. Aggs por measure = todas menos count (count es el tamanio del grupo, + # se mapea siempre a la columna `n`). + measure_aggs = [a for a in aggs if a != "count"] + + # 3. Construir el SELECT. grp y n primero; luego un termino por measure x agg + # con alias posicional (m{idx}_{agg}) para no chocar con nombres de columna + # que lleven espacios o caracteres raros. + select_terms = [f"{_quote_ident(group_by)} AS grp", "COUNT(*) AS n"] + agg_index = [] # (measure_name, agg_name, alias) + for mi, m in enumerate(measures): + for a in measure_aggs: + alias = f"m{mi}_{a}" + fn = _AGG_SQL[a] + select_terms.append(f"{fn}({_quote_ident(m)}) AS {alias}") + agg_index.append((m, a, alias)) + + # Pedimos top_n+1 grupos para detectar truncado (habia mas que top_n). + sql = ( + f"SELECT {', '.join(select_terms)} " + f"FROM {_quote_ident(table)} " + f"GROUP BY {_quote_ident(group_by)} " + f"ORDER BY n DESC " + f"LIMIT {top_n + 1}" + ) + + # 4. Ejecutar push-down. sandbox=True (default) basta: la tabla ya existe en + # el .db, no necesitamos read_csv/read_blob ni acceso al filesystem. + result = duckdb_query_readonly(db_path, sql, max_rows=top_n + 1) + if result.get("status") != "ok": + return { + "status": "error", + "error": "groupby query failed: " + + str(result.get("error", "unknown")), + } + + rows = result.get("rows", []) + truncated = len(rows) > top_n + if truncated: + rows = rows[:top_n] + + # 5. Reconstruir la estructura por grupo. + groups = [] + for row in rows: + stats = {m: {} for m in measures} + for (m, a, alias) in agg_index: + stats[m][a] = row.get(alias) + groups.append( + {"key": row.get("grp"), "n": row.get("n"), "stats": stats} + ) + + return { + "status": "ok", + "group_by": group_by, + "measures": list(measures), + "aggs": list(aggs), + "n_groups": len(groups), + "truncated": truncated, + "groups": groups, + "note": f"GROUP BY {group_by}: top {len(groups)} grupos por tamanio sobre " + f"{len(measures)} measure(s)", + } + except Exception as e: # noqa: BLE001 + return {"status": "error", "error": str(e)} diff --git a/python/functions/datascience/groupby_stats_duckdb_test.py b/python/functions/datascience/groupby_stats_duckdb_test.py new file mode 100644 index 00000000..e0857b3d --- /dev/null +++ b/python/functions/datascience/groupby_stats_duckdb_test.py @@ -0,0 +1,106 @@ +"""Tests para groupby_stats_duckdb.""" + +import os +import sys + +import duckdb + +# Permitir importar funciones del registry (from infra import ..., from datascience import ...). +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions")) + +from datascience.groupby_stats_duckdb import groupby_stats_duckdb + + +def _make_db(tmp_path, rows): + """Crea una DuckDB con tabla t(g VARCHAR, x DOUBLE) e inserta `rows`.""" + db = os.path.join(str(tmp_path), "t.duckdb") + con = duckdb.connect(db) + con.execute("CREATE TABLE t(g VARCHAR, x DOUBLE)") + con.executemany("INSERT INTO t VALUES (?, ?)", rows) + con.close() + return db + + +def test_agrega_por_grupo_con_valores_conocidos(tmp_path): + # Grupo a: [10, 20, 30] -> n=3, mean=20, min=10, max=30, median=20, std=10. + # Grupo b: [5, 15] -> n=2, mean=10, median=10. + # Grupo c: [100] -> n=1, mean=100, std=None (1 sola fila). + rows = [ + ("a", 10.0), ("a", 20.0), ("a", 30.0), + ("b", 5.0), ("b", 15.0), + ("c", 100.0), + ] + db = _make_db(tmp_path, rows) + res = groupby_stats_duckdb(db, "t", "g", ["x"]) + assert res["status"] == "ok", res + assert res["n_groups"] == 3 + assert res["truncated"] is False + assert res["aggs"] == ["count", "mean", "median", "std", "min", "max"] + + by_key = {g["key"]: g for g in res["groups"]} + assert set(by_key) == {"a", "b", "c"} + + # Grupo a: comprobacion manual de mean/min/max/median/std. + sa = by_key["a"]["stats"]["x"] + assert by_key["a"]["n"] == 3 + assert abs(sa["mean"] - 20.0) < 1e-9 + assert abs(sa["min"] - 10.0) < 1e-9 + assert abs(sa["max"] - 30.0) < 1e-9 + assert abs(sa["median"] - 20.0) < 1e-9 + assert "std" in sa and sa["std"] is not None + assert abs(sa["std"] - 10.0) < 1e-9 # stddev_samp([10,20,30]) = 10 + + # Grupo b: mean y median conocidas. + sb = by_key["b"]["stats"]["x"] + assert by_key["b"]["n"] == 2 + assert abs(sb["mean"] - 10.0) < 1e-9 + assert abs(sb["median"] - 10.0) < 1e-9 + assert "median" in sb and "std" in sb + + # Grupo c: una sola fila -> std None (stddev_samp NULL), mean/min/max definidos. + sc = by_key["c"]["stats"]["x"] + assert by_key["c"]["n"] == 1 + assert abs(sc["mean"] - 100.0) < 1e-9 + assert sc["std"] is None + + +def test_db_inexistente_devuelve_error_sin_lanzar(tmp_path): + db = os.path.join(str(tmp_path), "no_existe.duckdb") + res = groupby_stats_duckdb(db, "t", "g", ["x"]) + assert res["status"] == "error", res + assert isinstance(res["error"], str) and res["error"] + + +def test_measures_vacias_agrega_solo_count(tmp_path): + rows = [("a", 1.0), ("a", 2.0), ("b", 3.0)] + db = _make_db(tmp_path, rows) + res = groupby_stats_duckdb(db, "t", "g", []) + assert res["status"] == "ok", res + by_key = {g["key"]: g for g in res["groups"]} + assert by_key["a"]["n"] == 2 + assert by_key["b"]["n"] == 1 + # Sin measures, stats por grupo es un dict vacio (valido). + assert by_key["a"]["stats"] == {} + assert by_key["b"]["stats"] == {} + + +def test_columna_con_espacio_agrupa_bien(tmp_path): + # Tabla con nombres de columna con espacios -> prueba el quoting con dobles + # comillas tanto en group_by como en la measure. + db = os.path.join(str(tmp_path), "space.duckdb") + con = duckdb.connect(db) + con.execute('CREATE TABLE t("my col" VARCHAR, "the val" DOUBLE)') + con.executemany( + 'INSERT INTO t VALUES (?, ?)', + [("x", 1.0), ("x", 3.0), ("y", 10.0)], + ) + con.close() + + res = groupby_stats_duckdb(db, "t", "my col", ["the val"]) + assert res["status"] == "ok", res + by_key = {g["key"]: g for g in res["groups"]} + assert by_key["x"]["n"] == 2 + assert abs(by_key["x"]["stats"]["the val"]["mean"] - 2.0) < 1e-9 + assert by_key["y"]["n"] == 1 + assert abs(by_key["y"]["stats"]["the val"]["mean"] - 10.0) < 1e-9 diff --git a/python/functions/datascience/pivot_table_duckdb.md b/python/functions/datascience/pivot_table_duckdb.md new file mode 100644 index 00000000..a14d55eb --- /dev/null +++ b/python/functions/datascience/pivot_table_duckdb.md @@ -0,0 +1,92 @@ +--- +name: pivot_table_duckdb +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def pivot_table_duckdb(db_path: str, table: str, index: str, columns: str, value: str, agg: str = 'mean', top_rows: int = 10, top_cols: int = 8) -> dict" +description: "Pivot table (index x columns -> agg(value)) calculada con push-down SQL en DuckDB (GROUP BY en el motor, sin traer filas a RAM) y recortada a las top_rows filas y top_cols columnas con mas observaciones para que quepa entera en un PDF movil / slide PPTX sin cortarse. Version push-down para tablas grandes de la funcion pura `pivot` (que pivota list[dict] en memoria)." +tags: [eda, pivot, duckdb, aggregate, datascience, push-down, report] +uses_functions: [duckdb_query_readonly_py_infra] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: db_path + desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base." + - name: table + desc: "Nombre de la tabla a pivotar. Se interpola citado con dobles comillas (DuckDB no admite parametros para identificadores)." + - name: index + desc: "Columna cuyos valores forman las filas de la pivot (eje vertical)." + - name: columns + desc: "Columna cuyos valores forman las columnas de la pivot (eje horizontal)." + - name: value + desc: "Columna numerica a agregar en cada celda. Ignorada cuando agg='count'." + - name: agg + desc: "Funcion de agregacion: mean, sum, count, min, max, median. mean->avg(), count->COUNT(*). Otro valor devuelve {status:'error'}." + - name: top_rows + desc: "Numero maximo de filas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de index). Default 10." + - name: top_cols + desc: "Numero maximo de columnas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de columns). Default 8." +output: "dict. En exito {status:'ok', index, columns, value, agg, row_labels:[...], col_labels:[...], matrix:[[...]], truncated_rows:bool, truncated_cols:bool, note:str}. matrix tiene len(row_labels) filas y cada fila len(col_labels) celdas (valor agregado o None si la combinacion no existe). truncated_* indica si hubo mas filas/columnas que el top. En error {status:'error', error:str} (no lanza)." +tested: true +tests: ["pivot mean labels y celda conocida", "pivot trunca a top rows y top cols", "pivot count no necesita value real", "pivot db inexistente devuelve error sin lanzar", "pivot agg invalido devuelve error"] +test_file_path: "python/functions/datascience/pivot_table_duckdb_test.py" +file_path: "python/functions/datascience/pivot_table_duckdb.py" +--- + +## Ejemplo + +```python +import duckdb +from datascience import pivot_table_duckdb + +# Tabla DuckDB de prueba estilo titanic: sex x pclass -> mean(fare). +db = "/tmp/pivot_demo.duckdb" +con = duckdb.connect(db) +con.execute( + "CREATE TABLE titanic AS SELECT * FROM (VALUES " + "('male',1,211.3),('female',1,151.5),('male',3,7.9)," + "('female',3,16.7),('male',1,52.0),('female',2,41.6)" + ") t(sex, pclass, fare)" +) +con.close() + +res = pivot_table_duckdb(db, "titanic", index="sex", columns="pclass", value="fare", agg="mean") +print(res["status"]) # ok +print(res["row_labels"]) # ['female', 'male'] (orden por nº de observaciones desc; empate -> etiqueta) +print(res["col_labels"]) # [1, 3, 2] (pclass=1 tiene 3 obs, pclass=3 -> 2, pclass=2 -> 1) +print(res["matrix"]) # [[151.5, 16.7, 41.6], [131.65, 7.9, None]] (male/pclass=2 no existe -> None) +``` + +## Cuando usarla + +Cuando quieres una pivot table (`index` x `columns` -> `agg(value)`) de una tabla +DuckDB con MUCHAS filas y necesitas que el resultado quepa entero en un informe: un +PDF abierto en el movil o un slide PPTX, donde una matriz de 50x30 se cortaria. La +agregacion se hace push-down en el motor (no traes las filas a RAM) y el resultado se +limita a las `top_rows` x `top_cols` combinaciones con mas observaciones. Encaja en el +flujo `eda` para resumir el cruce de dos categoricas (sexo x clase, region x producto) +contra una metrica. Para pivotar un `list[dict]` ya cargado en memoria usa la funcion +pura `pivot_py_datascience`; esta es la version push-down sobre disco. + +## Gotchas + +- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica). +- Recorta a `top_rows` x `top_cols` por numero de observaciones (suma de `COUNT(*)`), + NO por magnitud del valor agregado. Si habia mas filas/columnas, `truncated_rows` / + `truncated_cols` quedan en True y esas combinaciones NO aparecen en la matriz. +- Las celdas sin datos (combinacion `index` x `columns` que no existe en la tabla) se + rellenan con `None`, no con 0: distinguir "cero medido" de "sin observaciones". +- `agg='count'` cuenta filas por celda con `COUNT(*)` e ignora `value` (puedes pasar + cualquier nombre de columna). Para el resto de aggs, `value` debe ser una columna + numerica real o la query fallara con `{status:'error'}`. +- `agg` solo admite mean, sum, count, min, max, median; cualquier otro valor devuelve + `{status:'error'}` sin tocar la base. +- Orden de `row_labels` / `col_labels`: por numero de observaciones descendente, con + desempate estable por etiqueta. No es orden alfabetico ni el de aparicion. +- La query se ejecuta con `sandbox=False` en `duckdb_query_readonly` (uso interno + confiable: el SQL lo construye esta funcion, no un cliente externo). diff --git a/python/functions/datascience/pivot_table_duckdb.py b/python/functions/datascience/pivot_table_duckdb.py new file mode 100644 index 00000000..cee20464 --- /dev/null +++ b/python/functions/datascience/pivot_table_duckdb.py @@ -0,0 +1,176 @@ +"""pivot_table_duckdb — pivot table (index x columns -> agg(value)) con push-down SQL. + +Funcion impura: lee de disco a traves de DuckDB reusando la primitiva read-only del +grupo `duckdb` (`duckdb_query_readonly`). Pertenece al grupo de capacidad `eda` +(exploratory data analysis). + +A diferencia de la funcion pura `pivot` (que pivota un `list[dict]` ya cargado en +memoria), esta version empuja la agregacion al motor de DuckDB (push-down): el +GROUP BY lo resuelve DuckDB y solo se traen los valores agregados, nunca las filas +crudas. Esto la hace apta para tablas grandes. + +Ademas reduce el resultado a las `top_rows` filas y `top_cols` columnas con mas +observaciones, de modo que la pivot quepa entera en un PDF movil / slide PPTX sin +cortarse. Marca `truncated_rows`/`truncated_cols` cuando hubo recorte. + +Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y +devuelve {status:'error', error:str}. +""" + +from collections import defaultdict + +from infra import duckdb_query_readonly + +# Funciones de agregacion permitidas y su nombre en SQL DuckDB. +# mean -> avg; el resto mapea directo. count se trata aparte (COUNT(*), sin value). +_AGG_SQL = { + "mean": "avg", + "sum": "sum", + "count": "count", + "min": "min", + "max": "max", + "median": "median", +} + + +def _quote_ident(ident: str) -> str: + """Cita un identificador SQL con dobles comillas, escapando `"` -> `""`. + + DuckDB no admite parametros posicionales para nombres de tabla/columna, asi que + hay que interpolarlos. El quoting con `"` y el doblado de comillas internas evita + que un nombre rompa la sentencia (mismo patron que correlation_matrix_duckdb). + """ + return '"' + str(ident).replace('"', '""') + '"' + + +def pivot_table_duckdb( + db_path: str, + table: str, + index: str, + columns: str, + value: str, + agg: str = "mean", + top_rows: int = 10, + top_cols: int = 8, +) -> dict: + """Pivot table push-down en DuckDB, recortada a top_rows x top_cols. + + Construye una pivot (filas = valores de `index`, columnas = valores de `columns`, + celda = `agg(value)`) agregando en el motor de DuckDB, y la reduce a las filas y + columnas con mas observaciones para que quepa en un PDF / slide. + + Args: + db_path: ruta al archivo DuckDB. Debe existir (read_only NO crea la base). + table: nombre de la tabla a pivotar. + index: columna cuyos valores forman las filas de la pivot. + columns: columna cuyos valores forman las columnas de la pivot. + value: columna numerica a agregar. Ignorada cuando agg="count". + agg: funcion de agregacion. Una de: "mean", "sum", "count", "min", "max", + "median". mean se traduce a avg(); count a COUNT(*). + top_rows: numero maximo de filas a conservar, elegidas por mayor numero de + observaciones (suma de COUNT(*) por valor de index). Default 10. + top_cols: numero maximo de columnas a conservar, elegidas por mayor numero de + observaciones (suma de COUNT(*) por valor de columns). Default 8. + + Returns: + dict. En exito: + {status:'ok', + index, columns, value, agg, + row_labels:[...], # valores de index, en orden de freq desc + col_labels:[...], # valores de columns, en orden de freq desc + matrix:[[...], ...], # len == len(row_labels); cada fila + # len == len(col_labels); celda = agg o None + truncated_rows:bool, truncated_cols:bool, + note:str} + En error (sin lanzar): {status:'error', error:str}. + """ + try: + if not isinstance(agg, str) or agg not in _AGG_SQL: + return { + "status": "error", + "error": "invalid agg " + + repr(agg) + + "; allowed: " + + ", ".join(sorted(_AGG_SQL)), + } + + # Paso 1 (push-down): agregar (index, columns) -> agg(value) + COUNT(*). + if agg == "count": + agg_expr = "COUNT(*)" + else: + agg_expr = f"{_AGG_SQL[agg]}({_quote_ident(value)})" + + sql = ( + f"SELECT {_quote_ident(index)} AS r, " + f"{_quote_ident(columns)} AS c, " + f"{agg_expr} AS v, " + f"COUNT(*) AS n " + f"FROM {_quote_ident(table)} " + f"GROUP BY {_quote_ident(index)}, {_quote_ident(columns)}" + ) + + # max_rows alto: queremos todos los grupos (index x columns) para elegir el + # top con criterio global. sandbox=False igual que correlation_matrix_duckdb, + # porque db_path es una ruta interna de confianza. + result = duckdb_query_readonly( + db_path, sql, max_rows=1_000_000, sandbox=False + ) + if result.get("status") != "ok": + return { + "status": "error", + "error": "pivot query failed: " + + str(result.get("error", "unknown")), + } + + # Paso 2 (en python): contar observaciones por fila y por columna, y guardar + # el valor agregado de cada celda (r, c). + row_obs: dict = defaultdict(int) + col_obs: dict = defaultdict(int) + cell: dict = {} + for row in result.get("rows", []): + r = row.get("r") + c = row.get("c") + n = row.get("n") or 0 + row_obs[r] += n + col_obs[c] += n + cell[(r, c)] = row.get("v") + + def _top(obs: dict, limit: int): + # Orden: mas observaciones primero; desempate estable por etiqueta. + ranked = sorted(obs.items(), key=lambda kv: (-kv[1], str(kv[0]))) + selected = [label for label, _ in ranked[:limit]] + return selected, len(ranked) > limit + + row_labels, truncated_rows = _top(row_obs, top_rows) + col_labels, truncated_cols = _top(col_obs, top_cols) + + # Paso 3: materializar la matriz; None donde la combinacion no existe. + matrix = [ + [cell.get((r, c)) for c in col_labels] for r in row_labels + ] + + note = ( + f"pivot {agg}({value}) reducida a {len(row_labels)}x{len(col_labels)} " + "(top por observaciones) para caber en PDF/slide" + ) + if agg == "count": + note = ( + f"pivot count(*) reducida a {len(row_labels)}x{len(col_labels)} " + "(top por observaciones) para caber en PDF/slide" + ) + + return { + "status": "ok", + "index": index, + "columns": columns, + "value": value, + "agg": agg, + "row_labels": row_labels, + "col_labels": col_labels, + "matrix": matrix, + "truncated_rows": truncated_rows, + "truncated_cols": truncated_cols, + "note": note, + } + except Exception as e: # noqa: BLE001 + return {"status": "error", "error": str(e)} diff --git a/python/functions/datascience/pivot_table_duckdb_test.py b/python/functions/datascience/pivot_table_duckdb_test.py new file mode 100644 index 00000000..53f85e2e --- /dev/null +++ b/python/functions/datascience/pivot_table_duckdb_test.py @@ -0,0 +1,115 @@ +"""Tests para pivot_table_duckdb.""" + +import os +import sys + +import duckdb + +# Permitir importar funciones del registry (from infra import ..., from datascience import ...). +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions")) + +from datascience.pivot_table_duckdb import pivot_table_duckdb + + +def _make_db(tmp_name: str) -> str: + """Crea una DuckDB con dos categoricas (a, b) y un valor numerico conocido. + + Filas: + a='x', b='y', val=10 + a='x', b='y', val=20 -> mean(x,y) = 15, count(x,y) = 2 + a='x', b='z', val=5 -> mean(x,z) = 5 + a='w', b='y', val=100 -> mean(w,y) = 100 + Observaciones por a: x=3, w=1. Por b: y=3, z=1. + La combinacion (w, z) no existe -> celda None. + """ + db = os.path.join("/tmp", tmp_name) + if os.path.exists(db): + os.remove(db) + con = duckdb.connect(db) + con.execute("CREATE TABLE t (a VARCHAR, b VARCHAR, val DOUBLE)") + con.execute( + "INSERT INTO t VALUES " + "('x','y',10),('x','y',20),('x','z',5),('w','y',100)" + ) + con.close() + return db + + +def test_pivot_mean_labels_y_celda_conocida(): + db = _make_db("pivot_test_mean.duckdb") + res = pivot_table_duckdb(db, "t", index="a", columns="b", value="val", agg="mean") + assert res["status"] == "ok", res + # Filas ordenadas por observaciones desc: x (3) antes que w (1). + assert res["row_labels"] == ["x", "w"], res["row_labels"] + # Columnas ordenadas por observaciones desc: y (3) antes que z (1). + assert res["col_labels"] == ["y", "z"], res["col_labels"] + # matrix[0][0] = mean(a='x', b='y') = (10 + 20) / 2 = 15. + assert abs(res["matrix"][0][0] - 15.0) < 1e-9, res["matrix"] + # matrix[0][1] = mean(a='x', b='z') = 5. + assert abs(res["matrix"][0][1] - 5.0) < 1e-9, res["matrix"] + # matrix[1][0] = mean(a='w', b='y') = 100. + assert abs(res["matrix"][1][0] - 100.0) < 1e-9, res["matrix"] + # (w, z) no existe -> None. + assert res["matrix"][1][1] is None, res["matrix"] + # Sin truncado con los defaults (top_rows=10, top_cols=8). + assert res["truncated_rows"] is False + assert res["truncated_cols"] is False + # La matriz es rectangular consistente con las etiquetas. + assert len(res["matrix"]) == len(res["row_labels"]) + for fila in res["matrix"]: + assert len(fila) == len(res["col_labels"]) + + +def test_pivot_trunca_a_top_rows_y_top_cols(): + db = _make_db("pivot_test_trunc.duckdb") + res = pivot_table_duckdb( + db, "t", index="a", columns="b", value="val", agg="mean", + top_rows=1, top_cols=1, + ) + assert res["status"] == "ok", res + # Solo la fila/columna mas frecuente sobrevive. + assert res["row_labels"] == ["x"], res["row_labels"] + assert res["col_labels"] == ["y"], res["col_labels"] + assert res["matrix"] == [[15.0]], res["matrix"] + # Habia mas de 1 fila y mas de 1 columna -> truncado en ambos ejes. + assert res["truncated_rows"] is True + assert res["truncated_cols"] is True + + +def test_pivot_count_no_necesita_value_real(): + db = _make_db("pivot_test_count.duckdb") + # value apunta a una columna real pero count(*) la ignora; tambien valdria un + # nombre cualquiera. Verificamos que count funciona igualmente. + res = pivot_table_duckdb( + db, "t", index="a", columns="b", value="val", agg="count" + ) + assert res["status"] == "ok", res + assert res["row_labels"] == ["x", "w"] + assert res["col_labels"] == ["y", "z"] + # count(a='x', b='y') = 2 observaciones. + assert res["matrix"][0][0] == 2, res["matrix"] + # count(a='x', b='z') = 1. + assert res["matrix"][0][1] == 1, res["matrix"] + # count(a='w', b='y') = 1. + assert res["matrix"][1][0] == 1, res["matrix"] + # (w, z) no existe -> None. + assert res["matrix"][1][1] is None, res["matrix"] + + +def test_pivot_db_inexistente_devuelve_error_sin_lanzar(): + res = pivot_table_duckdb( + "/nonexistent/path/does_not_exist.duckdb", + "t", index="a", columns="b", value="val", agg="mean", + ) + assert res["status"] == "error", res + assert isinstance(res["error"], str) + + +def test_pivot_agg_invalido_devuelve_error(): + db = _make_db("pivot_test_badagg.duckdb") + res = pivot_table_duckdb( + db, "t", index="a", columns="b", value="val", agg="stddev" + ) + assert res["status"] == "error", res + assert "invalid agg" in res["error"] diff --git a/python/functions/datascience/select_groupby_keys.md b/python/functions/datascience/select_groupby_keys.md new file mode 100644 index 00000000..d97332eb --- /dev/null +++ b/python/functions/datascience/select_groupby_keys.md @@ -0,0 +1,158 @@ +--- +id: select_groupby_keys_py_datascience +name: select_groupby_keys +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: pure +signature: "def select_groupby_keys(profile: dict, max_keys: int = 3, max_card: int = 20, max_measures: int = 4) -> dict" +description: "Elige deterministicamente las columnas categoricas mas interesantes para GROUP BY, las numericas medida y pares pivote a partir de un TableProfile del grupo eda. Respaldo cuantitativo para el capitulo de agregacion/OLAP de un EDA. Funcion pura, no muta el input, nunca lanza." +tags: [eda, aggregation, groupby, olap, profiling, datascience] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "" +imports: [] +example: | + from datascience import select_groupby_keys + profile = { + "n_rows": 891, + "key_candidates": ["passenger_id"], + "columns": [ + {"name": "sex", "inferred_type": "categorical", "distinct_count": 2, + "unique_pct": 0.002, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 1.8}, "numeric": None}, + {"name": "pclass", "inferred_type": "categorical", "distinct_count": 3, + "unique_pct": 0.003, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 2.5}, "numeric": None}, + {"name": "fare", "inferred_type": "numeric", "distinct_count": 200, + "unique_pct": 0.2, "null_pct": 0.0, "flags": [], + "numeric": {"std": 49.7, "cv": 1.54}, "categorical": None}, + ], + } + select_groupby_keys(profile) + # {"group_keys": [{"col": "sex", ...}, {"col": "pclass", ...}], + # "measures": ["fare"], + # "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}], + # "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s)."} +tested: true +tests: + - "test_titanic_picks_good_cats_excludes_id_and_constant" + - "test_titanic_measures_exclude_id_constant_and_keep_numerics" + - "test_titanic_generates_one_pivot" + - "test_empty_profile_returns_all_empty_and_does_not_crash" + - "test_none_profile_does_not_crash" + - "test_only_numerics_yields_empty_group_keys_and_no_pivots" + - "test_high_cardinality_and_max_card_are_excluded" + - "test_max_keys_limits_group_keys" + - "test_three_keys_cap_pivots_to_two" + - "test_does_not_mutate_input" +test_file_path: "python/functions/datascience/select_groupby_keys_test.py" +file_path: "python/functions/datascience/select_groupby_keys.py" +params: + - name: profile + desc: > + TableProfile dict del grupo eda (p.ej. salida de summarize_table_duckdb). + Se lee de forma defensiva (.get / or [] / isinstance). Claves usadas: + columns (list[ColumnProfile]), key_candidates (list de nombres de columna + o dicts {name}), n_rows. Cada ColumnProfile usa: name, inferred_type + ("numeric"|"categorical"|"datetime"|"text"|"boolean"), distinct_count, + unique_pct (0..1), null_pct (0..1), flags (list[str], reconoce + "possible_id"/"high_cardinality"/"constant"), numeric ({std, cv, ...}|None) + y categorical ({imbalance, mode_pct, ...}|None). + - name: max_keys + desc: "Numero maximo de claves de grupo (group_keys) a devolver. Default 3." + - name: max_card + desc: > + Cardinalidad maxima (distinct_count) que una columna categorica puede + tener para seguir siendo candidata a clave de grupo. Default 20. + - name: max_measures + desc: "Numero maximo de columnas medida (nombres) a devolver. Default 4." +output: > + dict con group_keys (list de {col, cardinality, score} ordenada por score + desc), measures (list[str] de nombres de columnas numericas ordenadas por + dispersion), pivots (list de {index, columns, value}, hasta 2 pares + categorica x categorica con la primera measure como valor) y note (str, + resumen corto en espanol de lo elegido). Ante profile vacio/None devuelve + todas las listas vacias y una note descriptiva; nunca lanza. +--- + +## Ejemplo + +```python +from datascience import select_groupby_keys + +# TableProfile estilo titanic: 2 categoricas buenas, 1 numerica medida, +# 1 id secuencial (descartado) y un key_candidate declarado. +profile = { + "n_rows": 891, + "key_candidates": ["passenger_id"], + "columns": [ + {"name": "sex", "inferred_type": "categorical", "distinct_count": 2, + "unique_pct": 0.002, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 1.8}, "numeric": None}, + {"name": "pclass", "inferred_type": "categorical", "distinct_count": 3, + "unique_pct": 0.003, "null_pct": 0.0, "flags": [], + "categorical": {"imbalance": 2.5}, "numeric": None}, + {"name": "fare", "inferred_type": "numeric", "distinct_count": 200, + "unique_pct": 0.2, "null_pct": 0.0, "flags": [], + "numeric": {"std": 49.7, "cv": 1.54}, "categorical": None}, + {"name": "passenger_id", "inferred_type": "numeric", "distinct_count": 891, + "unique_pct": 1.0, "null_pct": 0.0, "flags": ["possible_id"], + "numeric": {"std": 257.4, "cv": 0.58}, "categorical": None}, + ], +} + +select_groupby_keys(profile) +# { +# "group_keys": [ +# {"col": "sex", "cardinality": 2, "score": 0.5556}, +# {"col": "pclass", "cardinality": 3, "score": 0.4}, +# ], +# "measures": ["fare"], # passenger_id excluido (id secuencial) +# "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}], +# "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s).", +# } +``` + +## Cuando usarla + +Cuando hayas perfilado una tabla con el grupo `eda` (p.ej. +`summarize_table_duckdb`) y necesites decidir, sin mirar los datos, por qué +columnas merece la pena agrupar (GROUP BY) y qué métricas numéricas agregar: +el respaldo cuantitativo del capítulo de agregación/OLAP de un AutomaticEDA, o +para proponer pivotes en un dashboard. Es la capa de selección sobre el +TableProfile crudo: lee el perfil, ordena candidatos de forma determinista y +no toca los datos. + +## Notas + +Función pura, sin I/O ni dependencias externas (solo stdlib), no muta +`profile`. Lectura defensiva total (`.get`, `or []`, `isinstance`): un `{}` o +`None` produce `{"group_keys": [], "measures": [], "pivots": [], "note": ...}` +y nunca lanza. + +Criterios de selección (deterministas): + +- **group_keys** — candidatas con `inferred_type` en `("categorical","boolean")`. + Se descartan las que estén en `key_candidates`, con flag + `possible_id`/`high_cardinality`/`constant`, con `distinct_count` fuera de + `[2, max_card]`, o all-null (`null_pct >= 0.999`). `score = card_score * + balance_score`: `card_score` mantiene un plateau para cardinalidad moderada + (2..12) y decae hacia `max_card`; `balance_score = 1/imbalance` usando + `categorical.imbalance` si está, aproximando con `mode_pct` si no, o un valor + neutro (0.5) en último caso. Devuelve hasta `max_keys`, ordenadas por score + desc (empates por orden de columna). +- **measures** — candidatas con `inferred_type` en + `("numeric","integer","float")`. Se descartan id-like (flag `possible_id` y + `unique_pct >= 0.99`) y constantes (`numeric.std` == 0 o None). Se rankean por + dispersión informativa: `abs(cv)` si está, si no `abs(std)`. Devuelve hasta + `max_measures` **nombres** (strings). +- **pivots** — hasta 2 pares `(group_keys[i].col, group_keys[j].col)` con i dict: + """Select GROUP BY keys, measures and pivot pairs from a TableProfile. + + Reads everything defensively (``.get(...)``, ``or []``, ``isinstance``) and + never raises. With an empty/None profile it returns every list empty. + + Selection rules (deterministic): + + - **group_keys** (categorical columns to group by): candidates have + ``inferred_type`` in ``("categorical", "boolean")``. Discarded if they are + in ``profile['key_candidates']``, carry a ``possible_id`` / + ``high_cardinality`` / ``constant`` flag, have ``distinct_count`` outside + ``[2, max_card]``, or are all-null (``null_pct >= 0.999``). Each survivor + gets ``score = card_score * balance_score`` where ``card_score`` keeps a + plateau for moderate cardinality (2..12) and decays towards ``max_card``, + and ``balance_score = 1 / imbalance`` (``categorical.imbalance`` if + present, else approximated from ``mode_pct``, else a neutral default). + The top ``max_keys`` by score (desc, ties by column order) are returned. + + - **measures** (numeric columns to aggregate): candidates have + ``inferred_type`` in ``("numeric", "integer", "float")``. Discarded if + id-like (``possible_id`` flag *and* ``unique_pct >= 0.99``) or constant + (``numeric.std`` is ``0`` or ``None``). Ranked by informative dispersion: + ``abs(cv)`` when available, else ``abs(std)``. The top ``max_measures`` + **names** are returned. + + - **pivots**: up to 2 ``(group_keys[i].col, group_keys[j].col)`` pairs with + ``i < j``, using the first measure as the aggregated value. Empty when + fewer than 2 group keys were selected. + + Args: + profile: TableProfile dict of the ``eda`` group. Relevant keys: + ``columns`` (list[ColumnProfile]), ``key_candidates`` (list of + column names or ``{name}`` dicts), ``n_rows``. Each ColumnProfile + uses: ``name``, ``inferred_type``, ``distinct_count``, + ``unique_pct`` (0..1), ``null_pct`` (0..1), ``flags`` (list[str]), + ``numeric`` ({std, cv, ...}|None), ``categorical`` + ({imbalance, mode_pct, ...}|None). + max_keys: Maximum number of group-by keys to return. Default 3. + max_card: Maximum cardinality (``distinct_count``) a categorical column + may have to still qualify as a group key. Default 20. + max_measures: Maximum number of measure names to return. Default 4. + + Returns: + dict with: + group_keys (list[{col, cardinality, score}], ordered by score desc), + measures (list[str], numeric column names ordered by dispersion), + pivots (list[{index, columns, value}], up to 2 pairs), + note (str, short summary of what was chosen). + """ + if not isinstance(profile, dict): + profile = {} + + try: + max_keys = int(max_keys) + except (TypeError, ValueError): + max_keys = 3 + try: + max_card = int(max_card) + except (TypeError, ValueError): + max_card = 20 + try: + max_measures = int(max_measures) + except (TypeError, ValueError): + max_measures = 4 + max_keys = max(max_keys, 0) + max_card = max(max_card, 2) + max_measures = max(max_measures, 0) + + columns = profile.get("columns") or [] + if not isinstance(columns, (list, tuple)): + columns = [] + + key_names = _key_candidate_names(profile.get("key_candidates")) + + group_keys = _select_group_keys(columns, key_names, max_keys, max_card) + measures = _select_measures(columns, max_measures) + pivots = _select_pivots(group_keys, measures) + + return { + "group_keys": group_keys, + "measures": measures, + "pivots": pivots, + "note": _build_note(group_keys, measures, pivots), + } + + +# --------------------------------------------------------------------------- +# group_keys +# --------------------------------------------------------------------------- + +_GROUP_TYPES = ("categorical", "boolean") +_DISQUALIFYING_FLAGS = frozenset({"possible_id", "high_cardinality", "constant"}) +_CARD_PLATEAU_HI = 12 # cardinalities 2..12 are all "moderate" (best). + + +def _select_group_keys(columns, key_names, max_keys, max_card) -> list: + """Rank categorical/boolean columns suitable for GROUP BY.""" + scored = [] + for idx, col in enumerate(columns): + if not isinstance(col, dict): + continue + if (col.get("inferred_type") or "") not in _GROUP_TYPES: + continue + + name = col.get("name") + if name is None: + continue + if name in key_names: + continue + + flags = _as_set(col.get("flags")) + if flags & _DISQUALIFYING_FLAGS: + continue + + if _num(col.get("null_pct"), 0.0) >= 0.999: + continue + + card = _num(col.get("distinct_count"), 0.0) + if card < 2 or card > max_card: + continue + card_i = int(card) + + score = _card_score(card_i, max_card) * _balance_score(col.get("categorical")) + scored.append((round(score, 6), idx, name, card_i)) + + # Deterministic: higher score first, ties broken by original column order. + scored.sort(key=lambda t: (-t[0], t[1])) + + out = [] + for score, _idx, name, card_i in scored[:max_keys]: + out.append({"col": name, "cardinality": card_i, "score": score}) + return out + + +def _card_score(card: int, max_card: int) -> float: + """Prefer moderate cardinality; plateau at 2..12, decay towards max_card.""" + if card <= 1: + return 0.0 + if card <= _CARD_PLATEAU_HI: + return 1.0 + denom = max(max_card - _CARD_PLATEAU_HI, 1) + over = card - _CARD_PLATEAU_HI + return max(0.1, 1.0 - over / denom) + + +def _balance_score(categorical) -> float: + """1.0 for a perfectly balanced category, decaying as imbalance grows. + + Uses ``categorical.imbalance`` (max_count/min_count, >= 1) when available; + otherwise approximates from ``mode_pct`` (top-class dominance); otherwise a + neutral default so the column is still selectable. + """ + if isinstance(categorical, dict): + imbalance = categorical.get("imbalance") + if isinstance(imbalance, (int, float)) and imbalance >= 1.0: + return 1.0 / float(imbalance) + mode_pct = categorical.get("mode_pct") + if isinstance(mode_pct, (int, float)): + return _clamp(1.0 - float(mode_pct), 0.0, 1.0) + return 0.5 + + +# --------------------------------------------------------------------------- +# measures +# --------------------------------------------------------------------------- + +_NUMERIC_TYPES = ("numeric", "integer", "float") + + +def _select_measures(columns, max_measures) -> list: + """Rank numeric columns by informative dispersion (cv, else std).""" + scored = [] + for idx, col in enumerate(columns): + if not isinstance(col, dict): + continue + if (col.get("inferred_type") or "") not in _NUMERIC_TYPES: + continue + + name = col.get("name") + if name is None: + continue + + flags = _as_set(col.get("flags")) + unique_pct = _num(col.get("unique_pct"), 0.0) + if "possible_id" in flags and unique_pct >= 0.99: + continue # sequential id, not a measure. + + numeric = col.get("numeric") + std = numeric.get("std") if isinstance(numeric, dict) else None + if not isinstance(std, (int, float)) or std == 0: + continue # constant or unknown spread -> not informative. + + cv = numeric.get("cv") if isinstance(numeric, dict) else None + if isinstance(cv, (int, float)): + dispersion = abs(float(cv)) + else: + dispersion = abs(float(std)) + + scored.append((dispersion, idx, name)) + + # Higher dispersion first, ties broken by original column order. + scored.sort(key=lambda t: (-t[0], t[1])) + return [name for _disp, _idx, name in scored[:max_measures]] + + +# --------------------------------------------------------------------------- +# pivots +# --------------------------------------------------------------------------- + + +def _select_pivots(group_keys, measures) -> list: + """Up to 2 (cat_a, cat_b) pairs from the chosen group keys.""" + if not isinstance(group_keys, list) or len(group_keys) < 2: + return [] + value = measures[0] if measures else None + pairs = [] + n = len(group_keys) + for i in range(n): + for j in range(i + 1, n): + pairs.append({ + "index": group_keys[i].get("col"), + "columns": group_keys[j].get("col"), + "value": value, + }) + if len(pairs) >= 2: + return pairs + return pairs + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _build_note(group_keys, measures, pivots) -> str: + """One-line Spanish summary of the selection.""" + parts = [] + if group_keys: + cols = ", ".join(str(g.get("col")) for g in group_keys) + parts.append(f"{len(group_keys)} clave(s) de grupo: {cols}") + else: + parts.append("sin categóricas agrupables") + if measures: + parts.append(f"{len(measures)} medida(s): " + ", ".join(str(m) for m in measures)) + else: + parts.append("sin medidas numéricas") + if pivots: + parts.append(f"{len(pivots)} pivot(s)") + return "; ".join(parts) + "." + + +def _key_candidate_names(key_candidates) -> set: + """Normalize ``key_candidates`` (strings or ``{name}`` dicts) to a name set.""" + names = set() + if not isinstance(key_candidates, (list, tuple)): + return names + for entry in key_candidates: + if isinstance(entry, str): + names.add(entry) + elif isinstance(entry, dict): + nm = entry.get("name") or entry.get("col") + if nm is not None: + names.add(nm) + return names + + +def _as_set(flags) -> set: + """Coerce a flags value into a set, tolerating None / non-iterables.""" + if isinstance(flags, (list, tuple, set)): + return set(flags) + return set() + + +def _num(value, default: float) -> float: + """Best-effort float conversion with a fallback default.""" + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _clamp(x: float, lo: float, hi: float) -> float: + """Recorta x al rango [lo, hi].""" + if x < lo: + return lo + if x > hi: + return hi + return x diff --git a/python/functions/datascience/select_groupby_keys_test.py b/python/functions/datascience/select_groupby_keys_test.py new file mode 100644 index 00000000..0aaf167d --- /dev/null +++ b/python/functions/datascience/select_groupby_keys_test.py @@ -0,0 +1,213 @@ +"""Tests para select_groupby_keys (grupo eda, dominio datascience).""" + +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +from select_groupby_keys import select_groupby_keys + + +def _cat_col(name, card, *, imbalance=2.0, flags=None, null_pct=0.0): + """ColumnProfile categorico minimo con bloque categorical.""" + return { + "name": name, + "inferred_type": "categorical", + "distinct_count": card, + "unique_pct": card / 1000.0, + "null_pct": null_pct, + "flags": flags or [], + "numeric": None, + "categorical": {"imbalance": imbalance, "mode_pct": 0.5, "n_distinct": card}, + } + + +def _num_col(name, *, std, cv, flags=None, unique_pct=0.1): + """ColumnProfile numerico minimo con bloque numeric.""" + return { + "name": name, + "inferred_type": "numeric", + "distinct_count": 200, + "unique_pct": unique_pct, + "null_pct": 0.0, + "flags": flags or [], + "numeric": {"std": std, "cv": cv}, + "categorical": None, + } + + +def _titanic_like_profile() -> dict: + """Perfil estilo titanic: 2 categoricas buenas, 2 numericas, 1 id, 1 constante.""" + return { + "n_rows": 891, + "key_candidates": ["passenger_id"], + "columns": [ + _cat_col("sex", 2, imbalance=1.8), + _cat_col("pclass", 3, imbalance=2.5), + _num_col("age", std=14.5, cv=0.49), + _num_col("fare", std=49.7, cv=1.54), + # id secuencial: flag possible_id + unique_pct alto. + { + "name": "passenger_id", + "inferred_type": "numeric", + "distinct_count": 891, + "unique_pct": 1.0, + "null_pct": 0.0, + "flags": ["possible_id"], + "numeric": {"std": 257.4, "cv": 0.58}, + "categorical": None, + }, + # columna constante: flag constant + std 0. + { + "name": "embarked_const", + "inferred_type": "categorical", + "distinct_count": 1, + "unique_pct": 0.001, + "null_pct": 0.0, + "flags": ["constant"], + "numeric": None, + "categorical": {"imbalance": 1.0}, + }, + ], + } + + +def test_titanic_picks_good_cats_excludes_id_and_constant(): + out = select_groupby_keys(_titanic_like_profile()) + + # Elige las dos categoricas buenas. + chosen_cols = {g["col"] for g in out["group_keys"]} + assert chosen_cols == {"sex", "pclass"} + + # Excluye la constante y el key_candidate. + assert "embarked_const" not in chosen_cols + assert "passenger_id" not in chosen_cols + + # Cada group key trae col, cardinality y score. + for g in out["group_keys"]: + assert set(g.keys()) == {"col", "cardinality", "score"} + assert isinstance(g["score"], float) + by_col = {g["col"]: g for g in out["group_keys"]} + assert by_col["sex"]["cardinality"] == 2 + assert by_col["pclass"]["cardinality"] == 3 + + # Ordenadas por score descendente. + scores = [g["score"] for g in out["group_keys"]] + assert scores == sorted(scores, reverse=True) + + +def test_titanic_measures_exclude_id_constant_and_keep_numerics(): + out = select_groupby_keys(_titanic_like_profile()) + + # Solo nombres (strings) de numericas informativas, sin el id secuencial. + assert all(isinstance(m, str) for m in out["measures"]) + assert "passenger_id" not in out["measures"] + assert set(out["measures"]) == {"age", "fare"} + + # fare tiene mayor cv (1.54 > 0.49) -> primero. + assert out["measures"][0] == "fare" + + +def test_titanic_generates_one_pivot(): + out = select_groupby_keys(_titanic_like_profile()) + + # Con 2 group keys -> exactamente 1 pivot. + assert len(out["pivots"]) == 1 + pivot = out["pivots"][0] + assert set(pivot.keys()) == {"index", "columns", "value"} + assert {pivot["index"], pivot["columns"]} == {"sex", "pclass"} + # El valor es la primera measure (fare). + assert pivot["value"] == "fare" + + +def test_empty_profile_returns_all_empty_and_does_not_crash(): + out = select_groupby_keys({}) + assert out["group_keys"] == [] + assert out["measures"] == [] + assert out["pivots"] == [] + assert isinstance(out["note"], str) + + +def test_none_profile_does_not_crash(): + out = select_groupby_keys(None) + assert out == { + "group_keys": [], + "measures": [], + "pivots": [], + "note": out["note"], + } + assert isinstance(out["note"], str) + + +def test_only_numerics_yields_empty_group_keys_and_no_pivots(): + profile = { + "n_rows": 500, + "key_candidates": [], + "columns": [ + _num_col("price", std=12.0, cv=0.6), + _num_col("weight", std=3.0, cv=0.2), + ], + } + out = select_groupby_keys(profile) + assert out["group_keys"] == [] + assert out["pivots"] == [] + # Las numericas si se eligen como measures. + assert set(out["measures"]) == {"price", "weight"} + assert out["measures"][0] == "price" # mayor cv. + + +def test_high_cardinality_and_max_card_are_excluded(): + profile = { + "n_rows": 1000, + "key_candidates": [], + "columns": [ + _cat_col("city", 50, flags=["high_cardinality"]), # flag -> fuera. + _cat_col("zone", 35), # card 35 > max_card 20 -> fuera. + _cat_col("region", 5), # valida. + ], + } + out = select_groupby_keys(profile, max_card=20) + assert {g["col"] for g in out["group_keys"]} == {"region"} + + +def test_max_keys_limits_group_keys(): + profile = { + "n_rows": 1000, + "key_candidates": [], + "columns": [ + _cat_col("a", 4, imbalance=1.0), + _cat_col("b", 5, imbalance=1.2), + _cat_col("c", 6, imbalance=1.5), + _cat_col("d", 7, imbalance=2.0), + ], + } + out = select_groupby_keys(profile, max_keys=2) + assert len(out["group_keys"]) == 2 + # Hasta 2 pivots con >=2 keys (aqui exactamente 1 par posible entre 2 keys). + assert len(out["pivots"]) == 1 + + +def test_three_keys_cap_pivots_to_two(): + profile = { + "n_rows": 1000, + "key_candidates": [], + "columns": [ + _cat_col("a", 4, imbalance=1.0), + _cat_col("b", 5, imbalance=1.1), + _cat_col("c", 6, imbalance=1.2), + _num_col("m", std=10.0, cv=0.5), + ], + } + out = select_groupby_keys(profile, max_keys=3) + assert len(out["group_keys"]) == 3 + # 3 keys -> 3 pares posibles, capado a 2. + assert len(out["pivots"]) == 2 + for p in out["pivots"]: + assert p["value"] == "m" + + +def test_does_not_mutate_input(): + profile = _titanic_like_profile() + before = repr(profile) + select_groupby_keys(profile) + assert repr(profile) == before diff --git a/python/functions/datascience/suggest_aggregations_llm.md b/python/functions/datascience/suggest_aggregations_llm.md new file mode 100644 index 00000000..2b4a79fd --- /dev/null +++ b/python/functions/datascience/suggest_aggregations_llm.md @@ -0,0 +1,96 @@ +--- +name: suggest_aggregations_llm +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def suggest_aggregations_llm(profile: dict, candidates: dict, max_aggs: int = 4, model: str = \"claude-haiku-4-5-20251001\") -> dict" +description: "MUST-11.1 del capitulo AGREGACION del AutomaticEDA (grupo eda). Dado el TableProfile de una tabla y los candidatos cuantitativos de select_groupby_keys ({group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}), con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una razon corta cada uno, evitando la explosion combinatoria (no todo contra todo). Privacidad/coste: NO envia filas crudas, solo el resumen AGREGADO de los candidatos (tabla, columnas categoricas con cardinalidad/score, medidas, pivots). Reusa ask_llm del grupo claude-direct (API directa con token OAuth de Claude). Impura, dict-no-throw: NUNCA lanza y SIEMPRE devuelve algo usable; si el LLM falla, el JSON no parsea o no hay seleccion valida, cae a un fallback determinista construido desde los candidatos (source='fallback'). Toda columna que el LLM invente se descarta." +tags: [eda, claude-direct, llm, aggregation, groupby, pivot, datascience, automatic-eda] +params: + - name: profile + desc: "TableProfile del grupo eda. Solo se usa profile['table'] para nombrar la tabla en el prompt; puede ir vacio o sin esa clave (se usa '(tabla sin nombre)')." + - name: candidates + desc: "Salida de select_groupby_keys: {group_keys:[{col, cardinality, score}], measures:[str], pivots:[{index, columns, value}]}. group_keys = columnas categoricas candidatas para GROUP BY; measures = columnas numericas a agregar (sum/avg); pivots = cruces index x columns -> value sugeridos. Cualquier columna que el LLM elija debe existir aqui o se descarta. None o no-dict se trata como vacio." + - name: max_aggs + desc: "Tope de agregaciones a devolver. Default 4. Valores <1 o no-int se normalizan a 4. Limita tanto la seleccion del LLM como el fallback determinista, para evitar la explosion combinatoria." + - name: model + desc: "id del modelo Anthropic a usar en la unica llamada. Default 'claude-haiku-4-5-20251001' (haiku, coste bajo, ~2-3s). Para razones mas finas, pasar p.ej. 'claude-opus-4-8'." +output: "dict dict-no-throw: {status:'ok', source:'llm'|'fallback', aggregations:[{group_by:str, measures:[str], why:str}], pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}. source=='llm' si el LLM produjo al menos una agregacion valida (columnas existentes en candidates); en cualquier otro caso (LLM caido, JSON invalido, seleccion vacia, sin candidatos) source=='fallback' y aggregations/pivots se derivan de candidates con why='selección cuantitativa (sin LLM)'. NUNCA lanza." +uses_functions: [ask_llm_py_core, select_groupby_keys_py_datascience] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +tested: true +tests: ["test_extract_json_object", "test_extract_json_wrapped_in_fences_and_junk", "test_extract_json_non_json_returns_none", "test_validate_aggregations_drops_invalid_columns", "test_llm_path_uses_selection", "test_llm_path_respects_max_aggs", "test_llm_invented_column_is_discarded", "test_fallback_on_empty_llm_response", "test_fallback_on_unparseable_response", "test_fallback_respects_max_aggs", "test_fallback_when_llm_raises", "test_no_candidates_returns_empty_fallback", "test_non_dict_candidates_does_not_raise"] +test_file_path: "python/functions/datascience/suggest_aggregations_llm_test.py" +file_path: "python/functions/datascience/suggest_aggregations_llm.py" +--- + +## Ejemplo + +```python +import sys, os +sys.path.insert(0, os.path.join("python", "functions")) + +from datascience.suggest_aggregations_llm import suggest_aggregations_llm + +profile = {"table": "ventas"} + +# candidates = salida de select_groupby_keys (aqui literal de ejemplo). +candidates = { + "group_keys": [ + {"col": "categoria", "cardinality": 8, "score": 0.91}, + {"col": "region", "cardinality": 5, "score": 0.74}, + {"col": "canal", "cardinality": 3, "score": 0.60}, + ], + "measures": ["importe", "unidades"], + "pivots": [ + {"index": "categoria", "columns": "region", "value": "importe"}, + ], +} + +out = suggest_aggregations_llm(profile, candidates, max_aggs=4) # haiku por defecto + +print("fuente:", out["source"]) # "llm" o "fallback" si no hay red +for agg in out["aggregations"]: + print(f"GROUP BY {agg['group_by']} -> {agg['measures']} ({agg['why']})") +for piv in out["pivots"]: + print(f"pivot {piv['index']} x {piv['columns']} = {piv['value']} ({piv['why']})") +``` + +## Cuando usarla + +Justo despues de `select_groupby_keys` en el capitulo AGREGACION del AutomaticEDA: +cuando ya tienes los candidatos cuantitativos (columnas categoricas con cardinalidad, +medidas numericas y pivots posibles) y quieres que un LLM se quede con las K +agregaciones y pivots MAS INFORMATIVOS en vez de generar "todo contra todo". Usala para +priorizar el plan de analisis de grupos antes de materializar las tablas con +`aggregate_by_group` / pivots, manteniendo el coste y el ruido bajos. Si no hay red o +credenciales, sigue funcionando con un fallback determinista, asi que es seguro +ponerla en un pipeline. + +## Gotchas + +- **Impura: hace 1 llamada de red al LLM.** No es determinista ni gratis. Latencia + tipica ~2-3s con haiku. Una sola llamada cubre toda la seleccion. +- **Requiere token OAuth de Claude** en `~/.claude/.credentials.json` (via `ask_llm` / + grupo `claude-direct`). Sin token / sin red NO lanza: cae al **fallback + determinista** (`source="fallback"`) construido desde `candidates` + (group_keys x measures hasta `max_aggs`, pivots tal cual) con + `why="selección cuantitativa (sin LLM)"`. Comprueba `out["source"]` para saber si la + seleccion vino del LLM o del fallback. +- **NO envia filas crudas al LLM**, solo el resumen AGREGADO de los candidatos. Esto + exige que `candidates` venga ya calculado por `select_groupby_keys` (cardinalidades, + scores, medidas, pivots). +- **Valida columnas inventadas**: si el LLM propone un `group_by`/`measure`/`index`/ + `columns` que no esta en `candidates`, esa entrada se descarta (las medidas se + recortan a las validas). Si tras validar no queda ninguna agregacion, cae al + fallback completo. +- **`max_aggs` acota la explosion combinatoria** tanto en el camino LLM como en el + fallback. Subirlo demasiado reintroduce el ruido que esta funcion evita. +- **Modelo `haiku` por defecto** para coste bajo; sube a `claude-opus-4-8` si necesitas + razones (`why`) mas finas (mas caro y lento). diff --git a/python/functions/datascience/suggest_aggregations_llm.py b/python/functions/datascience/suggest_aggregations_llm.py new file mode 100644 index 00000000..b7fc4ac5 --- /dev/null +++ b/python/functions/datascience/suggest_aggregations_llm.py @@ -0,0 +1,405 @@ +"""suggest_aggregations_llm — el LLM elige las agregaciones mas informativas (grupo `eda`). + +MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Dado el `TableProfile` de una +tabla y los CANDIDATOS cuantitativos que produce `select_groupby_keys` +(`{group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}`), +con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x +medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una +razon corta cada uno. El objetivo es evitar la explosion combinatoria: en vez de +"todo contra todo", el LLM se queda con lo que mas informa. + +Privacidad y coste: NO se envian filas crudas al LLM. El prompt solo lleva el resumen +AGREGADO de los candidatos (nombre de la tabla, columnas categoricas con su +cardinalidad/score, medidas y pivots posibles). Una sola llamada barata. + +Reusa `ask_llm` del registry (grupo claude-direct, API directa con el token OAuth de +Claude en ~/.claude/.credentials.json, arranque 0). Impura: una llamada de red. + +Estilo dict-no-throw con FALLBACK DETERMINISTA: la funcion NUNCA lanza y SIEMPRE +devuelve algo usable. Si `ask_llm` falla (devuelve ""), el JSON no parsea, o el LLM no +produce ninguna seleccion valida, se construye la respuesta directamente desde los +candidatos (group_keys x measures hasta max_aggs, pivots tal cual) con +`source="fallback"`. Ademas, toda columna que el LLM invente (no presente en los +candidatos) se descarta. +""" + +import json + +from core.ask_llm import ask_llm + +_SYSTEM = ( + "Eres un analista de datos conciso. Te dan los CANDIDATOS AGREGADOS de una tabla " + "(columnas categoricas para GROUP BY con su cardinalidad, medidas numericas y " + "pivots posibles) y eliges las agregaciones y pivots MAS INFORMATIVOS para " + "entender los grupos, evitando la explosion combinatoria (no todo contra todo). " + "No recibes filas crudas. Responde en espanol. Responde SIEMPRE y SOLO con un " + "unico objeto JSON valido, sin texto alrededor ni fences de markdown, con la forma " + '{"aggregations": [{"group_by": "", "measures": ["", ...], ' + '"why": ""}], "pivots": [{"index": "", "columns": "", ' + '"value": "", "why": ""}]}. Usa SOLO nombres de columna ' + "que aparezcan en los candidatos; no inventes nombres." +) + + +def _fmt_num(value) -> str: + """Formatea un numero de forma compacta para el prompt (None -> '?').""" + if value is None: + return "?" + if isinstance(value, bool): + return str(value) + if isinstance(value, float): + if value == int(value): + return str(int(value)) + return f"{value:.4g}" + return str(value) + + +def _candidate_view(candidates: dict): + """Extrae las vistas utiles de los candidatos. Funcion interna PURA. + + Devuelve la tupla (group_cols, measures, measure_set, pivots, group_keys): + - group_cols: set de nombres de columna categorica validas (de group_keys[].col). + - measures: lista de medidas numericas (str) preservando orden. + - measure_set: set de las medidas para validar pertenencia rapido. + - pivots: lista de pivots candidatos (dicts) tal cual vienen. + - group_keys: lista de dicts {col, cardinality, score} ya filtrada a entradas validas. + + Tolera estructuras incompletas o de tipo incorrecto sin lanzar. + """ + candidates = candidates if isinstance(candidates, dict) else {} + + gk_raw = candidates.get("group_keys") + group_keys = [] + if isinstance(gk_raw, list): + for gk in gk_raw: + if isinstance(gk, dict) and isinstance(gk.get("col"), str): + group_keys.append(gk) + group_cols = {gk["col"] for gk in group_keys} + + m_raw = candidates.get("measures") + measures = [m for m in m_raw if isinstance(m, str)] if isinstance(m_raw, list) else [] + measure_set = set(measures) + + p_raw = candidates.get("pivots") + pivots = p_raw if isinstance(p_raw, list) else [] + + return group_cols, measures, measure_set, pivots, group_keys + + +def _sorted_group_cols(group_keys: list) -> list: + """Nombres de columna categorica ordenados por score descendente. PURA.""" + + def _score(gk): + s = gk.get("score") + if isinstance(s, (int, float)) and not isinstance(s, bool): + return s + return 0.0 + + return [gk["col"] for gk in sorted(group_keys, key=_score, reverse=True)] + + +def _build_prompt(profile: dict, candidates: dict, max_aggs: int) -> str: + """Construye el prompt compacto SOLO con agregados. Funcion interna PURA. + + No toca red ni disco: testeable sin credenciales. Incluye el nombre de la tabla, + las columnas categoricas candidatas (con cardinalidad y score), las medidas + numericas y los pivots candidatos. Nunca filas crudas. + + Args: + profile: TableProfile (se usa solo profile['table'] para nombrar la tabla). + candidates: salida de select_groupby_keys. + max_aggs: tope de agregaciones a pedir. + + Returns: + El texto del prompt. + """ + profile = profile if isinstance(profile, dict) else {} + candidates = candidates if isinstance(candidates, dict) else {} + + table = profile.get("table") + table = str(table) if table is not None else "(tabla sin nombre)" + + lines = [ + f"Tabla: {table}", + ( + "Tarea: elegir las agregaciones (GROUP BY categorica x medidas numericas) y " + "los pivots MAS INFORMATIVOS para un analisis de grupos. Evita la explosion " + "combinatoria: NO combines todo contra todo, prioriza lo que mas informa." + ), + f"Devuelve a lo sumo {max_aggs} agregaciones.", + "", + "Columnas categoricas candidatas para GROUP BY (col: cardinalidad, score):", + ] + + group_keys = candidates.get("group_keys") or [] + for gk in group_keys: + if not isinstance(gk, dict) or not isinstance(gk.get("col"), str): + continue + lines.append( + f" - {gk['col']}: cardinalidad={_fmt_num(gk.get('cardinality'))}, " + f"score={_fmt_num(gk.get('score'))}" + ) + + measures = candidates.get("measures") or [] + lines.append("") + lines.append("Medidas numericas disponibles (para sum/avg por grupo):") + lines.append(" " + ", ".join(str(m) for m in measures if isinstance(m, str))) + + pivots = candidates.get("pivots") or [] + if pivots: + lines.append("") + lines.append("Pivots candidatos (index x columns -> value):") + for p in pivots: + if not isinstance(p, dict): + continue + lines.append( + f" - index={p.get('index')}, columns={p.get('columns')}, " + f"value={p.get('value')}" + ) + + lines.append("") + lines.append( + "Usa SOLO columnas de las listas anteriores; no inventes nombres. Responde " + "SOLO con el JSON descrito en las instrucciones del sistema." + ) + return "\n".join(lines) + + +def _extract_json(text: str): + """Extrae el primer bloque JSON (objeto o array) de la respuesta. PURA. + + Localiza el bloque que empieza antes (el primer '{' o el primer '[') y, para ese + delimitador, hace json.loads del rango hasta su ultimo cierre. Tolera texto basura + alrededor y fences ```json. NUNCA lanza: ante cualquier fallo devuelve None. + + Args: + text: respuesta cruda del LLM. + + Returns: + El objeto/lista deserializado, o None si no se pudo parsear. + """ + if not text or not isinstance(text, str): + return None + + opens = [] + i_obj = text.find("{") + if i_obj != -1: + opens.append((i_obj, "{", "}")) + i_arr = text.find("[") + if i_arr != -1: + opens.append((i_arr, "[", "]")) + opens.sort() + + for _, open_c, close_c in opens: + start = text.find(open_c) + end = text.rfind(close_c) + if start != -1 and end != -1 and end > start: + try: + return json.loads(text[start : end + 1]) + except (ValueError, TypeError): + continue + return None + + +def _validate_aggregations(raw_aggs, group_cols: set, measure_set: set, max_aggs: int) -> list: + """Filtra las agregaciones del LLM a las que usan SOLO columnas candidatas. PURA. + + Descarta cualquier agregacion cuyo group_by no este en group_cols o que no tenga + al menos una medida valida. Recorta las medidas a las presentes en measure_set. + Limita el resultado a max_aggs entradas. + """ + out = [] + if not isinstance(raw_aggs, list): + return out + for item in raw_aggs: + if not isinstance(item, dict): + continue + gb = item.get("group_by") + if not isinstance(gb, str) or gb not in group_cols: + continue # columna inventada -> se descarta + raw_measures = item.get("measures") + if isinstance(raw_measures, str): + raw_measures = [raw_measures] + if not isinstance(raw_measures, list): + continue + measures = [m for m in raw_measures if isinstance(m, str) and m in measure_set] + if not measures: + continue # sin medidas validas -> agregacion inutil + why = item.get("why") + why = str(why) if why is not None else "" + out.append({"group_by": gb, "measures": measures, "why": why}) + if len(out) >= max_aggs: + break + return out + + +def _validate_pivots(raw_pivots, group_cols: set, measure_set: set) -> list: + """Filtra los pivots del LLM a los que usan SOLO columnas candidatas. PURA. + + Descarta el pivot si index o columns no son columnas categoricas validas. Si el + value no es una medida valida, lo deja en None (un pivot de conteo sigue siendo util). + """ + out = [] + if not isinstance(raw_pivots, list): + return out + for item in raw_pivots: + if not isinstance(item, dict): + continue + idx = item.get("index") + cols = item.get("columns") + if not (isinstance(idx, str) and idx in group_cols): + continue + if not (isinstance(cols, str) and cols in group_cols): + continue + val = item.get("value") + if not (isinstance(val, str) and val in measure_set): + val = None + why = item.get("why") + why = str(why) if why is not None else "" + out.append({"index": idx, "columns": cols, "value": val, "why": why}) + return out + + +def _fallback_aggregations(group_cols_sorted: list, measures: list, max_aggs: int) -> list: + """Agregaciones deterministas: cada columna categorica x todas las medidas. PURA.""" + out = [] + for col in group_cols_sorted: + out.append( + { + "group_by": col, + "measures": list(measures), + "why": "selección cuantitativa (sin LLM)", + } + ) + if len(out) >= max_aggs: + break + return out + + +def _fallback_pivots(cand_pivots: list) -> list: + """Normaliza los pivots candidatos a la forma de salida (tal cual + why). PURA.""" + out = [] + if not isinstance(cand_pivots, list): + return out + for p in cand_pivots: + if not isinstance(p, dict): + continue + idx = p.get("index") + cols = p.get("columns") + if not (isinstance(idx, str) and isinstance(cols, str)): + continue + val = p.get("value") + if not isinstance(val, str): + val = None + out.append( + { + "index": idx, + "columns": cols, + "value": val, + "why": "selección cuantitativa (sin LLM)", + } + ) + return out + + +def suggest_aggregations_llm( + profile: dict, + candidates: dict, + max_aggs: int = 4, + model: str = "claude-haiku-4-5-20251001", +) -> dict: + """Elige las agregaciones y pivots mas informativos con UNA llamada al LLM. + + MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Toma el perfil de la tabla y + los candidatos cuantitativos (salida de select_groupby_keys) y deja que el LLM + seleccione/ordene las K agregaciones (GROUP BY categorica x medidas) y los pivots + mas utiles, con una razon corta cada uno, evitando la explosion combinatoria. + + Privacidad/coste: solo viaja al LLM el resumen AGREGADO de los candidatos, nunca + filas crudas. Una sola llamada barata. + + dict-no-throw con fallback determinista: NUNCA lanza. Si el LLM falla, el JSON no + parsea, o no produce seleccion valida -> construye la respuesta desde los candidatos + (group_keys x measures hasta max_aggs, pivots tal cual) con source="fallback". Las + columnas que el LLM invente (no presentes en los candidatos) se descartan. + + Args: + profile: TableProfile del grupo eda. Solo se usa profile['table'] para nombrar + la tabla en el prompt; puede ir vacio. + candidates: salida de select_groupby_keys, con la forma + {group_keys:[{col,cardinality,score}], measures:[str], + pivots:[{index,columns,value}]}. + max_aggs: tope de agregaciones a devolver. Default 4. Valores <1 o no-int se + normalizan a 4. + model: id del modelo Anthropic. Default 'claude-haiku-4-5-20251001' (haiku, + coste bajo, ~2-3s). + + Returns: + dict {status:"ok", source:"llm"|"fallback", + aggregations:[{group_by:str, measures:[str], why:str}], + pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}. + source=="llm" si el LLM produjo al menos una agregacion valida; en cualquier + otro caso "fallback". NUNCA lanza. + """ + if not isinstance(candidates, dict): + candidates = {} + if isinstance(max_aggs, bool) or not isinstance(max_aggs, int) or max_aggs < 1: + max_aggs = 4 + + group_cols, measures, measure_set, cand_pivots, group_keys = _candidate_view(candidates) + group_cols_sorted = _sorted_group_cols(group_keys) + + # Sin material suficiente para agregar: no merece la pena llamar al LLM. + if not group_cols or not measures: + return { + "status": "ok", + "source": "fallback", + "aggregations": [], + "pivots": _fallback_pivots(cand_pivots), + "note": "sin candidatos suficientes para agregar", + } + + prompt = _build_prompt(profile, candidates, max_aggs) + try: + text = ask_llm(prompt, model=model, system=_SYSTEM, echo=False) + except Exception: # noqa: BLE001 — degradacion: cualquier fallo de red/LLM. + text = "" + + parsed = _extract_json(text) + if parsed is not None: + if isinstance(parsed, dict): + raw_aggs = parsed.get("aggregations") + raw_pivots = parsed.get("pivots") + elif isinstance(parsed, list): + raw_aggs = parsed + raw_pivots = None + else: + raw_aggs = None + raw_pivots = None + + aggs = _validate_aggregations(raw_aggs, group_cols, measure_set, max_aggs) + if aggs: + pivots = _validate_pivots(raw_pivots, group_cols, measure_set) + if not pivots: + pivots = _fallback_pivots(cand_pivots) + return { + "status": "ok", + "source": "llm", + "aggregations": aggs, + "pivots": pivots, + "note": f"{len(aggs)} agregaciones y {len(pivots)} pivots seleccionados por el LLM", + } + + # Fallback determinista. + note = ( + "LLM no disponible; selección cuantitativa determinista" + if not text + else "LLM sin selección válida; selección cuantitativa determinista" + ) + return { + "status": "ok", + "source": "fallback", + "aggregations": _fallback_aggregations(group_cols_sorted, measures, max_aggs), + "pivots": _fallback_pivots(cand_pivots), + "note": note, + } diff --git a/python/functions/datascience/suggest_aggregations_llm_test.py b/python/functions/datascience/suggest_aggregations_llm_test.py new file mode 100644 index 00000000..29a4f2a9 --- /dev/null +++ b/python/functions/datascience/suggest_aggregations_llm_test.py @@ -0,0 +1,198 @@ +"""Tests para suggest_aggregations_llm. + +NO acceden a red ni a credenciales: las funciones internas (_build_prompt, +_extract_json, _validate_*, _fallback_*) son puras y testeables aisladas; la unica +via que llamaria al LLM (suggest_aggregations_llm) se prueba reemplazando el simbolo +`ask_llm` del modulo bajo prueba con una funcion simulada. Los candidatos van +literales en el test: NO se importa select_groupby_keys. + +Cubre golden (LLM ok con columnas validas), edge (max_aggs respetado, sin candidatos) +y error (LLM caido -> fallback, JSON invalido -> fallback, columna inventada -> se +descarta). Todos sin tocar la red. +""" + +import json + +import datascience.suggest_aggregations_llm as M +from datascience.suggest_aggregations_llm import ( + _extract_json, + _validate_aggregations, + suggest_aggregations_llm, +) + +# Candidatos de ejemplo con la forma que produce select_groupby_keys (literales). +_CANDIDATES = { + "group_keys": [ + {"col": "categoria", "cardinality": 8, "score": 0.91}, + {"col": "region", "cardinality": 5, "score": 0.74}, + {"col": "canal", "cardinality": 3, "score": 0.60}, + ], + "measures": ["importe", "unidades"], + "pivots": [ + {"index": "categoria", "columns": "region", "value": "importe"}, + ], +} +_PROFILE = {"table": "ventas"} + + +def _fake_returner(text): + """Devuelve un ask_llm simulado que ignora args y retorna `text`.""" + + def _fake(prompt, model="x", system="", echo=True, **kwargs): + return text + + return _fake + + +# --- _extract_json (parser puro, sin red) --- + + +def test_extract_json_object(): + obj = {"aggregations": [{"group_by": "categoria", "measures": ["importe"], "why": "x"}]} + assert _extract_json(json.dumps(obj)) == obj + + +def test_extract_json_wrapped_in_fences_and_junk(): + obj = {"aggregations": [], "pivots": []} + text = "Claro, aqui tienes:\n```json\n" + json.dumps(obj) + "\n```\nFin." + assert _extract_json(text) == obj + + +def test_extract_json_non_json_returns_none(): + assert _extract_json("no hay json aqui") is None + assert _extract_json("") is None + assert _extract_json(None) is None + + +# --- _validate_aggregations (puro) --- + + +def test_validate_aggregations_drops_invalid_columns(): + group_cols = {"categoria", "region"} + measure_set = {"importe", "unidades"} + raw = [ + {"group_by": "categoria", "measures": ["importe", "inventada"], "why": "ok"}, + {"group_by": "no_existe", "measures": ["importe"], "why": "mala"}, + {"group_by": "region", "measures": ["solo_inventada"], "why": "sin medidas"}, + ] + out = _validate_aggregations(raw, group_cols, measure_set, max_aggs=4) + # Solo sobrevive la primera, con las medidas recortadas a las validas. + assert out == [{"group_by": "categoria", "measures": ["importe"], "why": "ok"}] + + +# --- suggest_aggregations_llm: camino LLM (golden) --- + + +def test_llm_path_uses_selection(monkeypatch): + llm_obj = { + "aggregations": [ + {"group_by": "categoria", "measures": ["importe"], "why": "ventas por familia"}, + {"group_by": "region", "measures": ["importe", "unidades"], "why": "reparto geografico"}, + ], + "pivots": [ + {"index": "categoria", "columns": "region", "value": "importe", "why": "cruce clave"}, + ], + } + monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj))) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["status"] == "ok" + assert out["source"] == "llm" + assert out["aggregations"] == llm_obj["aggregations"] + assert out["pivots"][0]["index"] == "categoria" + assert out["pivots"][0]["why"] == "cruce clave" + + +def test_llm_path_respects_max_aggs(monkeypatch): + llm_obj = { + "aggregations": [ + {"group_by": "categoria", "measures": ["importe"], "why": "a"}, + {"group_by": "region", "measures": ["importe"], "why": "b"}, + {"group_by": "canal", "measures": ["unidades"], "why": "c"}, + ], + "pivots": [], + } + monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj))) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2) + assert out["source"] == "llm" + assert len(out["aggregations"]) == 2 + + +def test_llm_invented_column_is_discarded(monkeypatch): + # El LLM mezcla una agregacion valida con otra de columna inexistente. + llm_obj = { + "aggregations": [ + {"group_by": "categoria", "measures": ["importe"], "why": "valida"}, + {"group_by": "columna_fantasma", "measures": ["importe"], "why": "inventada"}, + ], + "pivots": [ + {"index": "fantasma", "columns": "region", "value": "importe", "why": "mala"}, + ], + } + monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj))) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["source"] == "llm" + # La agregacion inventada se descarta; queda solo la valida. + assert [a["group_by"] for a in out["aggregations"]] == ["categoria"] + # El pivot con index fantasma se descarta -> cae a los pivots de candidates. + assert all(p["index"] in {"categoria", "region", "canal"} for p in out["pivots"]) + + +# --- suggest_aggregations_llm: fallback determinista (error paths) --- + + +def test_fallback_on_empty_llm_response(monkeypatch): + monkeypatch.setattr(M, "ask_llm", _fake_returner("")) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=4) + assert out["status"] == "ok" + assert out["source"] == "fallback" + # Las agregaciones se derivan de candidates (una por group_key, con todas las medidas). + assert out["aggregations"][0]["group_by"] in {"categoria", "region", "canal"} + assert out["aggregations"][0]["measures"] == ["importe", "unidades"] + assert out["aggregations"][0]["why"] == "selección cuantitativa (sin LLM)" + # Pivots tal cual de candidates. + assert out["pivots"][0]["index"] == "categoria" + + +def test_fallback_on_unparseable_response(monkeypatch): + monkeypatch.setattr(M, "ask_llm", _fake_returner("esto no es JSON {roto")) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["source"] == "fallback" + assert len(out["aggregations"]) >= 1 + + +def test_fallback_respects_max_aggs(monkeypatch): + monkeypatch.setattr(M, "ask_llm", _fake_returner("")) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2) + assert out["source"] == "fallback" + assert len(out["aggregations"]) == 2 + + +def test_fallback_when_llm_raises(monkeypatch): + def _boom(*args, **kwargs): + raise RuntimeError("sin red") + + monkeypatch.setattr(M, "ask_llm", _boom) + + out = suggest_aggregations_llm(_PROFILE, _CANDIDATES) + assert out["source"] == "fallback" + assert out["aggregations"] # no vacio, no lanza + + +def test_no_candidates_returns_empty_fallback(): + # Sin red porque ni siquiera se llama al LLM (no hay material). + out = suggest_aggregations_llm(_PROFILE, {"group_keys": [], "measures": [], "pivots": []}) + assert out["status"] == "ok" + assert out["source"] == "fallback" + assert out["aggregations"] == [] + + +def test_non_dict_candidates_does_not_raise(): + out = suggest_aggregations_llm(_PROFILE, None) + assert out["status"] == "ok" + assert out["aggregations"] == []