Files
fn_registry/python/functions/datascience/suggest_aggregations_llm.py
T
egutierrez 96da9e3015 feat(eda): funciones de agregación/OLAP para AutomaticEDA (groupby/pivot push-down + selección LLM)
Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION:
- select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile.
- groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo).
- pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar.
- suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-30 15:33:55 +02:00

406 lines
16 KiB
Python

"""suggest_aggregations_llm — el LLM elige las agregaciones mas informativas (grupo `eda`).
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Dado el `TableProfile` de una
tabla y los CANDIDATOS cuantitativos que produce `select_groupby_keys`
(`{group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}`),
con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x
medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una
razon corta cada uno. El objetivo es evitar la explosion combinatoria: en vez de
"todo contra todo", el LLM se queda con lo que mas informa.
Privacidad y coste: NO se envian filas crudas al LLM. El prompt solo lleva el resumen
AGREGADO de los candidatos (nombre de la tabla, columnas categoricas con su
cardinalidad/score, medidas y pivots posibles). Una sola llamada barata.
Reusa `ask_llm` del registry (grupo claude-direct, API directa con el token OAuth de
Claude en ~/.claude/.credentials.json, arranque 0). Impura: una llamada de red.
Estilo dict-no-throw con FALLBACK DETERMINISTA: la funcion NUNCA lanza y SIEMPRE
devuelve algo usable. Si `ask_llm` falla (devuelve ""), el JSON no parsea, o el LLM no
produce ninguna seleccion valida, se construye la respuesta directamente desde los
candidatos (group_keys x measures hasta max_aggs, pivots tal cual) con
`source="fallback"`. Ademas, toda columna que el LLM invente (no presente en los
candidatos) se descarta.
"""
import json
from core.ask_llm import ask_llm
_SYSTEM = (
"Eres un analista de datos conciso. Te dan los CANDIDATOS AGREGADOS de una tabla "
"(columnas categoricas para GROUP BY con su cardinalidad, medidas numericas y "
"pivots posibles) y eliges las agregaciones y pivots MAS INFORMATIVOS para "
"entender los grupos, evitando la explosion combinatoria (no todo contra todo). "
"No recibes filas crudas. Responde en espanol. Responde SIEMPRE y SOLO con un "
"unico objeto JSON valido, sin texto alrededor ni fences de markdown, con la forma "
'{"aggregations": [{"group_by": "<col categorica>", "measures": ["<medida>", ...], '
'"why": "<razon corta>"}], "pivots": [{"index": "<col>", "columns": "<col>", '
'"value": "<medida o null>", "why": "<razon corta>"}]}. Usa SOLO nombres de columna '
"que aparezcan en los candidatos; no inventes nombres."
)
def _fmt_num(value) -> str:
"""Formatea un numero de forma compacta para el prompt (None -> '?')."""
if value is None:
return "?"
if isinstance(value, bool):
return str(value)
if isinstance(value, float):
if value == int(value):
return str(int(value))
return f"{value:.4g}"
return str(value)
def _candidate_view(candidates: dict):
"""Extrae las vistas utiles de los candidatos. Funcion interna PURA.
Devuelve la tupla (group_cols, measures, measure_set, pivots, group_keys):
- group_cols: set de nombres de columna categorica validas (de group_keys[].col).
- measures: lista de medidas numericas (str) preservando orden.
- measure_set: set de las medidas para validar pertenencia rapido.
- pivots: lista de pivots candidatos (dicts) tal cual vienen.
- group_keys: lista de dicts {col, cardinality, score} ya filtrada a entradas validas.
Tolera estructuras incompletas o de tipo incorrecto sin lanzar.
"""
candidates = candidates if isinstance(candidates, dict) else {}
gk_raw = candidates.get("group_keys")
group_keys = []
if isinstance(gk_raw, list):
for gk in gk_raw:
if isinstance(gk, dict) and isinstance(gk.get("col"), str):
group_keys.append(gk)
group_cols = {gk["col"] for gk in group_keys}
m_raw = candidates.get("measures")
measures = [m for m in m_raw if isinstance(m, str)] if isinstance(m_raw, list) else []
measure_set = set(measures)
p_raw = candidates.get("pivots")
pivots = p_raw if isinstance(p_raw, list) else []
return group_cols, measures, measure_set, pivots, group_keys
def _sorted_group_cols(group_keys: list) -> list:
"""Nombres de columna categorica ordenados por score descendente. PURA."""
def _score(gk):
s = gk.get("score")
if isinstance(s, (int, float)) and not isinstance(s, bool):
return s
return 0.0
return [gk["col"] for gk in sorted(group_keys, key=_score, reverse=True)]
def _build_prompt(profile: dict, candidates: dict, max_aggs: int) -> str:
"""Construye el prompt compacto SOLO con agregados. Funcion interna PURA.
No toca red ni disco: testeable sin credenciales. Incluye el nombre de la tabla,
las columnas categoricas candidatas (con cardinalidad y score), las medidas
numericas y los pivots candidatos. Nunca filas crudas.
Args:
profile: TableProfile (se usa solo profile['table'] para nombrar la tabla).
candidates: salida de select_groupby_keys.
max_aggs: tope de agregaciones a pedir.
Returns:
El texto del prompt.
"""
profile = profile if isinstance(profile, dict) else {}
candidates = candidates if isinstance(candidates, dict) else {}
table = profile.get("table")
table = str(table) if table is not None else "(tabla sin nombre)"
lines = [
f"Tabla: {table}",
(
"Tarea: elegir las agregaciones (GROUP BY categorica x medidas numericas) y "
"los pivots MAS INFORMATIVOS para un analisis de grupos. Evita la explosion "
"combinatoria: NO combines todo contra todo, prioriza lo que mas informa."
),
f"Devuelve a lo sumo {max_aggs} agregaciones.",
"",
"Columnas categoricas candidatas para GROUP BY (col: cardinalidad, score):",
]
group_keys = candidates.get("group_keys") or []
for gk in group_keys:
if not isinstance(gk, dict) or not isinstance(gk.get("col"), str):
continue
lines.append(
f" - {gk['col']}: cardinalidad={_fmt_num(gk.get('cardinality'))}, "
f"score={_fmt_num(gk.get('score'))}"
)
measures = candidates.get("measures") or []
lines.append("")
lines.append("Medidas numericas disponibles (para sum/avg por grupo):")
lines.append(" " + ", ".join(str(m) for m in measures if isinstance(m, str)))
pivots = candidates.get("pivots") or []
if pivots:
lines.append("")
lines.append("Pivots candidatos (index x columns -> value):")
for p in pivots:
if not isinstance(p, dict):
continue
lines.append(
f" - index={p.get('index')}, columns={p.get('columns')}, "
f"value={p.get('value')}"
)
lines.append("")
lines.append(
"Usa SOLO columnas de las listas anteriores; no inventes nombres. Responde "
"SOLO con el JSON descrito en las instrucciones del sistema."
)
return "\n".join(lines)
def _extract_json(text: str):
"""Extrae el primer bloque JSON (objeto o array) de la respuesta. PURA.
Localiza el bloque que empieza antes (el primer '{' o el primer '[') y, para ese
delimitador, hace json.loads del rango hasta su ultimo cierre. Tolera texto basura
alrededor y fences ```json. NUNCA lanza: ante cualquier fallo devuelve None.
Args:
text: respuesta cruda del LLM.
Returns:
El objeto/lista deserializado, o None si no se pudo parsear.
"""
if not text or not isinstance(text, str):
return None
opens = []
i_obj = text.find("{")
if i_obj != -1:
opens.append((i_obj, "{", "}"))
i_arr = text.find("[")
if i_arr != -1:
opens.append((i_arr, "[", "]"))
opens.sort()
for _, open_c, close_c in opens:
start = text.find(open_c)
end = text.rfind(close_c)
if start != -1 and end != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except (ValueError, TypeError):
continue
return None
def _validate_aggregations(raw_aggs, group_cols: set, measure_set: set, max_aggs: int) -> list:
"""Filtra las agregaciones del LLM a las que usan SOLO columnas candidatas. PURA.
Descarta cualquier agregacion cuyo group_by no este en group_cols o que no tenga
al menos una medida valida. Recorta las medidas a las presentes en measure_set.
Limita el resultado a max_aggs entradas.
"""
out = []
if not isinstance(raw_aggs, list):
return out
for item in raw_aggs:
if not isinstance(item, dict):
continue
gb = item.get("group_by")
if not isinstance(gb, str) or gb not in group_cols:
continue # columna inventada -> se descarta
raw_measures = item.get("measures")
if isinstance(raw_measures, str):
raw_measures = [raw_measures]
if not isinstance(raw_measures, list):
continue
measures = [m for m in raw_measures if isinstance(m, str) and m in measure_set]
if not measures:
continue # sin medidas validas -> agregacion inutil
why = item.get("why")
why = str(why) if why is not None else ""
out.append({"group_by": gb, "measures": measures, "why": why})
if len(out) >= max_aggs:
break
return out
def _validate_pivots(raw_pivots, group_cols: set, measure_set: set) -> list:
"""Filtra los pivots del LLM a los que usan SOLO columnas candidatas. PURA.
Descarta el pivot si index o columns no son columnas categoricas validas. Si el
value no es una medida valida, lo deja en None (un pivot de conteo sigue siendo util).
"""
out = []
if not isinstance(raw_pivots, list):
return out
for item in raw_pivots:
if not isinstance(item, dict):
continue
idx = item.get("index")
cols = item.get("columns")
if not (isinstance(idx, str) and idx in group_cols):
continue
if not (isinstance(cols, str) and cols in group_cols):
continue
val = item.get("value")
if not (isinstance(val, str) and val in measure_set):
val = None
why = item.get("why")
why = str(why) if why is not None else ""
out.append({"index": idx, "columns": cols, "value": val, "why": why})
return out
def _fallback_aggregations(group_cols_sorted: list, measures: list, max_aggs: int) -> list:
"""Agregaciones deterministas: cada columna categorica x todas las medidas. PURA."""
out = []
for col in group_cols_sorted:
out.append(
{
"group_by": col,
"measures": list(measures),
"why": "selección cuantitativa (sin LLM)",
}
)
if len(out) >= max_aggs:
break
return out
def _fallback_pivots(cand_pivots: list) -> list:
"""Normaliza los pivots candidatos a la forma de salida (tal cual + why). PURA."""
out = []
if not isinstance(cand_pivots, list):
return out
for p in cand_pivots:
if not isinstance(p, dict):
continue
idx = p.get("index")
cols = p.get("columns")
if not (isinstance(idx, str) and isinstance(cols, str)):
continue
val = p.get("value")
if not isinstance(val, str):
val = None
out.append(
{
"index": idx,
"columns": cols,
"value": val,
"why": "selección cuantitativa (sin LLM)",
}
)
return out
def suggest_aggregations_llm(
profile: dict,
candidates: dict,
max_aggs: int = 4,
model: str = "claude-haiku-4-5-20251001",
) -> dict:
"""Elige las agregaciones y pivots mas informativos con UNA llamada al LLM.
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Toma el perfil de la tabla y
los candidatos cuantitativos (salida de select_groupby_keys) y deja que el LLM
seleccione/ordene las K agregaciones (GROUP BY categorica x medidas) y los pivots
mas utiles, con una razon corta cada uno, evitando la explosion combinatoria.
Privacidad/coste: solo viaja al LLM el resumen AGREGADO de los candidatos, nunca
filas crudas. Una sola llamada barata.
dict-no-throw con fallback determinista: NUNCA lanza. Si el LLM falla, el JSON no
parsea, o no produce seleccion valida -> construye la respuesta desde los candidatos
(group_keys x measures hasta max_aggs, pivots tal cual) con source="fallback". Las
columnas que el LLM invente (no presentes en los candidatos) se descartan.
Args:
profile: TableProfile del grupo eda. Solo se usa profile['table'] para nombrar
la tabla en el prompt; puede ir vacio.
candidates: salida de select_groupby_keys, con la forma
{group_keys:[{col,cardinality,score}], measures:[str],
pivots:[{index,columns,value}]}.
max_aggs: tope de agregaciones a devolver. Default 4. Valores <1 o no-int se
normalizan a 4.
model: id del modelo Anthropic. Default 'claude-haiku-4-5-20251001' (haiku,
coste bajo, ~2-3s).
Returns:
dict {status:"ok", source:"llm"|"fallback",
aggregations:[{group_by:str, measures:[str], why:str}],
pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}.
source=="llm" si el LLM produjo al menos una agregacion valida; en cualquier
otro caso "fallback". NUNCA lanza.
"""
if not isinstance(candidates, dict):
candidates = {}
if isinstance(max_aggs, bool) or not isinstance(max_aggs, int) or max_aggs < 1:
max_aggs = 4
group_cols, measures, measure_set, cand_pivots, group_keys = _candidate_view(candidates)
group_cols_sorted = _sorted_group_cols(group_keys)
# Sin material suficiente para agregar: no merece la pena llamar al LLM.
if not group_cols or not measures:
return {
"status": "ok",
"source": "fallback",
"aggregations": [],
"pivots": _fallback_pivots(cand_pivots),
"note": "sin candidatos suficientes para agregar",
}
prompt = _build_prompt(profile, candidates, max_aggs)
try:
text = ask_llm(prompt, model=model, system=_SYSTEM, echo=False)
except Exception: # noqa: BLE001 — degradacion: cualquier fallo de red/LLM.
text = ""
parsed = _extract_json(text)
if parsed is not None:
if isinstance(parsed, dict):
raw_aggs = parsed.get("aggregations")
raw_pivots = parsed.get("pivots")
elif isinstance(parsed, list):
raw_aggs = parsed
raw_pivots = None
else:
raw_aggs = None
raw_pivots = None
aggs = _validate_aggregations(raw_aggs, group_cols, measure_set, max_aggs)
if aggs:
pivots = _validate_pivots(raw_pivots, group_cols, measure_set)
if not pivots:
pivots = _fallback_pivots(cand_pivots)
return {
"status": "ok",
"source": "llm",
"aggregations": aggs,
"pivots": pivots,
"note": f"{len(aggs)} agregaciones y {len(pivots)} pivots seleccionados por el LLM",
}
# Fallback determinista.
note = (
"LLM no disponible; selección cuantitativa determinista"
if not text
else "LLM sin selección válida; selección cuantitativa determinista"
)
return {
"status": "ok",
"source": "fallback",
"aggregations": _fallback_aggregations(group_cols_sorted, measures, max_aggs),
"pivots": _fallback_pivots(cand_pivots),
"note": note,
}