96da9e3015
Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION: - select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile. - groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo). - pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar. - suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
406 lines
16 KiB
Python
406 lines
16 KiB
Python
"""suggest_aggregations_llm — el LLM elige las agregaciones mas informativas (grupo `eda`).
|
|
|
|
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Dado el `TableProfile` de una
|
|
tabla y los CANDIDATOS cuantitativos que produce `select_groupby_keys`
|
|
(`{group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}`),
|
|
con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x
|
|
medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una
|
|
razon corta cada uno. El objetivo es evitar la explosion combinatoria: en vez de
|
|
"todo contra todo", el LLM se queda con lo que mas informa.
|
|
|
|
Privacidad y coste: NO se envian filas crudas al LLM. El prompt solo lleva el resumen
|
|
AGREGADO de los candidatos (nombre de la tabla, columnas categoricas con su
|
|
cardinalidad/score, medidas y pivots posibles). Una sola llamada barata.
|
|
|
|
Reusa `ask_llm` del registry (grupo claude-direct, API directa con el token OAuth de
|
|
Claude en ~/.claude/.credentials.json, arranque 0). Impura: una llamada de red.
|
|
|
|
Estilo dict-no-throw con FALLBACK DETERMINISTA: la funcion NUNCA lanza y SIEMPRE
|
|
devuelve algo usable. Si `ask_llm` falla (devuelve ""), el JSON no parsea, o el LLM no
|
|
produce ninguna seleccion valida, se construye la respuesta directamente desde los
|
|
candidatos (group_keys x measures hasta max_aggs, pivots tal cual) con
|
|
`source="fallback"`. Ademas, toda columna que el LLM invente (no presente en los
|
|
candidatos) se descarta.
|
|
"""
|
|
|
|
import json
|
|
|
|
from core.ask_llm import ask_llm
|
|
|
|
_SYSTEM = (
|
|
"Eres un analista de datos conciso. Te dan los CANDIDATOS AGREGADOS de una tabla "
|
|
"(columnas categoricas para GROUP BY con su cardinalidad, medidas numericas y "
|
|
"pivots posibles) y eliges las agregaciones y pivots MAS INFORMATIVOS para "
|
|
"entender los grupos, evitando la explosion combinatoria (no todo contra todo). "
|
|
"No recibes filas crudas. Responde en espanol. Responde SIEMPRE y SOLO con un "
|
|
"unico objeto JSON valido, sin texto alrededor ni fences de markdown, con la forma "
|
|
'{"aggregations": [{"group_by": "<col categorica>", "measures": ["<medida>", ...], '
|
|
'"why": "<razon corta>"}], "pivots": [{"index": "<col>", "columns": "<col>", '
|
|
'"value": "<medida o null>", "why": "<razon corta>"}]}. Usa SOLO nombres de columna '
|
|
"que aparezcan en los candidatos; no inventes nombres."
|
|
)
|
|
|
|
|
|
def _fmt_num(value) -> str:
|
|
"""Formatea un numero de forma compacta para el prompt (None -> '?')."""
|
|
if value is None:
|
|
return "?"
|
|
if isinstance(value, bool):
|
|
return str(value)
|
|
if isinstance(value, float):
|
|
if value == int(value):
|
|
return str(int(value))
|
|
return f"{value:.4g}"
|
|
return str(value)
|
|
|
|
|
|
def _candidate_view(candidates: dict):
|
|
"""Extrae las vistas utiles de los candidatos. Funcion interna PURA.
|
|
|
|
Devuelve la tupla (group_cols, measures, measure_set, pivots, group_keys):
|
|
- group_cols: set de nombres de columna categorica validas (de group_keys[].col).
|
|
- measures: lista de medidas numericas (str) preservando orden.
|
|
- measure_set: set de las medidas para validar pertenencia rapido.
|
|
- pivots: lista de pivots candidatos (dicts) tal cual vienen.
|
|
- group_keys: lista de dicts {col, cardinality, score} ya filtrada a entradas validas.
|
|
|
|
Tolera estructuras incompletas o de tipo incorrecto sin lanzar.
|
|
"""
|
|
candidates = candidates if isinstance(candidates, dict) else {}
|
|
|
|
gk_raw = candidates.get("group_keys")
|
|
group_keys = []
|
|
if isinstance(gk_raw, list):
|
|
for gk in gk_raw:
|
|
if isinstance(gk, dict) and isinstance(gk.get("col"), str):
|
|
group_keys.append(gk)
|
|
group_cols = {gk["col"] for gk in group_keys}
|
|
|
|
m_raw = candidates.get("measures")
|
|
measures = [m for m in m_raw if isinstance(m, str)] if isinstance(m_raw, list) else []
|
|
measure_set = set(measures)
|
|
|
|
p_raw = candidates.get("pivots")
|
|
pivots = p_raw if isinstance(p_raw, list) else []
|
|
|
|
return group_cols, measures, measure_set, pivots, group_keys
|
|
|
|
|
|
def _sorted_group_cols(group_keys: list) -> list:
|
|
"""Nombres de columna categorica ordenados por score descendente. PURA."""
|
|
|
|
def _score(gk):
|
|
s = gk.get("score")
|
|
if isinstance(s, (int, float)) and not isinstance(s, bool):
|
|
return s
|
|
return 0.0
|
|
|
|
return [gk["col"] for gk in sorted(group_keys, key=_score, reverse=True)]
|
|
|
|
|
|
def _build_prompt(profile: dict, candidates: dict, max_aggs: int) -> str:
|
|
"""Construye el prompt compacto SOLO con agregados. Funcion interna PURA.
|
|
|
|
No toca red ni disco: testeable sin credenciales. Incluye el nombre de la tabla,
|
|
las columnas categoricas candidatas (con cardinalidad y score), las medidas
|
|
numericas y los pivots candidatos. Nunca filas crudas.
|
|
|
|
Args:
|
|
profile: TableProfile (se usa solo profile['table'] para nombrar la tabla).
|
|
candidates: salida de select_groupby_keys.
|
|
max_aggs: tope de agregaciones a pedir.
|
|
|
|
Returns:
|
|
El texto del prompt.
|
|
"""
|
|
profile = profile if isinstance(profile, dict) else {}
|
|
candidates = candidates if isinstance(candidates, dict) else {}
|
|
|
|
table = profile.get("table")
|
|
table = str(table) if table is not None else "(tabla sin nombre)"
|
|
|
|
lines = [
|
|
f"Tabla: {table}",
|
|
(
|
|
"Tarea: elegir las agregaciones (GROUP BY categorica x medidas numericas) y "
|
|
"los pivots MAS INFORMATIVOS para un analisis de grupos. Evita la explosion "
|
|
"combinatoria: NO combines todo contra todo, prioriza lo que mas informa."
|
|
),
|
|
f"Devuelve a lo sumo {max_aggs} agregaciones.",
|
|
"",
|
|
"Columnas categoricas candidatas para GROUP BY (col: cardinalidad, score):",
|
|
]
|
|
|
|
group_keys = candidates.get("group_keys") or []
|
|
for gk in group_keys:
|
|
if not isinstance(gk, dict) or not isinstance(gk.get("col"), str):
|
|
continue
|
|
lines.append(
|
|
f" - {gk['col']}: cardinalidad={_fmt_num(gk.get('cardinality'))}, "
|
|
f"score={_fmt_num(gk.get('score'))}"
|
|
)
|
|
|
|
measures = candidates.get("measures") or []
|
|
lines.append("")
|
|
lines.append("Medidas numericas disponibles (para sum/avg por grupo):")
|
|
lines.append(" " + ", ".join(str(m) for m in measures if isinstance(m, str)))
|
|
|
|
pivots = candidates.get("pivots") or []
|
|
if pivots:
|
|
lines.append("")
|
|
lines.append("Pivots candidatos (index x columns -> value):")
|
|
for p in pivots:
|
|
if not isinstance(p, dict):
|
|
continue
|
|
lines.append(
|
|
f" - index={p.get('index')}, columns={p.get('columns')}, "
|
|
f"value={p.get('value')}"
|
|
)
|
|
|
|
lines.append("")
|
|
lines.append(
|
|
"Usa SOLO columnas de las listas anteriores; no inventes nombres. Responde "
|
|
"SOLO con el JSON descrito en las instrucciones del sistema."
|
|
)
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _extract_json(text: str):
|
|
"""Extrae el primer bloque JSON (objeto o array) de la respuesta. PURA.
|
|
|
|
Localiza el bloque que empieza antes (el primer '{' o el primer '[') y, para ese
|
|
delimitador, hace json.loads del rango hasta su ultimo cierre. Tolera texto basura
|
|
alrededor y fences ```json. NUNCA lanza: ante cualquier fallo devuelve None.
|
|
|
|
Args:
|
|
text: respuesta cruda del LLM.
|
|
|
|
Returns:
|
|
El objeto/lista deserializado, o None si no se pudo parsear.
|
|
"""
|
|
if not text or not isinstance(text, str):
|
|
return None
|
|
|
|
opens = []
|
|
i_obj = text.find("{")
|
|
if i_obj != -1:
|
|
opens.append((i_obj, "{", "}"))
|
|
i_arr = text.find("[")
|
|
if i_arr != -1:
|
|
opens.append((i_arr, "[", "]"))
|
|
opens.sort()
|
|
|
|
for _, open_c, close_c in opens:
|
|
start = text.find(open_c)
|
|
end = text.rfind(close_c)
|
|
if start != -1 and end != -1 and end > start:
|
|
try:
|
|
return json.loads(text[start : end + 1])
|
|
except (ValueError, TypeError):
|
|
continue
|
|
return None
|
|
|
|
|
|
def _validate_aggregations(raw_aggs, group_cols: set, measure_set: set, max_aggs: int) -> list:
|
|
"""Filtra las agregaciones del LLM a las que usan SOLO columnas candidatas. PURA.
|
|
|
|
Descarta cualquier agregacion cuyo group_by no este en group_cols o que no tenga
|
|
al menos una medida valida. Recorta las medidas a las presentes en measure_set.
|
|
Limita el resultado a max_aggs entradas.
|
|
"""
|
|
out = []
|
|
if not isinstance(raw_aggs, list):
|
|
return out
|
|
for item in raw_aggs:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
gb = item.get("group_by")
|
|
if not isinstance(gb, str) or gb not in group_cols:
|
|
continue # columna inventada -> se descarta
|
|
raw_measures = item.get("measures")
|
|
if isinstance(raw_measures, str):
|
|
raw_measures = [raw_measures]
|
|
if not isinstance(raw_measures, list):
|
|
continue
|
|
measures = [m for m in raw_measures if isinstance(m, str) and m in measure_set]
|
|
if not measures:
|
|
continue # sin medidas validas -> agregacion inutil
|
|
why = item.get("why")
|
|
why = str(why) if why is not None else ""
|
|
out.append({"group_by": gb, "measures": measures, "why": why})
|
|
if len(out) >= max_aggs:
|
|
break
|
|
return out
|
|
|
|
|
|
def _validate_pivots(raw_pivots, group_cols: set, measure_set: set) -> list:
|
|
"""Filtra los pivots del LLM a los que usan SOLO columnas candidatas. PURA.
|
|
|
|
Descarta el pivot si index o columns no son columnas categoricas validas. Si el
|
|
value no es una medida valida, lo deja en None (un pivot de conteo sigue siendo util).
|
|
"""
|
|
out = []
|
|
if not isinstance(raw_pivots, list):
|
|
return out
|
|
for item in raw_pivots:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
idx = item.get("index")
|
|
cols = item.get("columns")
|
|
if not (isinstance(idx, str) and idx in group_cols):
|
|
continue
|
|
if not (isinstance(cols, str) and cols in group_cols):
|
|
continue
|
|
val = item.get("value")
|
|
if not (isinstance(val, str) and val in measure_set):
|
|
val = None
|
|
why = item.get("why")
|
|
why = str(why) if why is not None else ""
|
|
out.append({"index": idx, "columns": cols, "value": val, "why": why})
|
|
return out
|
|
|
|
|
|
def _fallback_aggregations(group_cols_sorted: list, measures: list, max_aggs: int) -> list:
|
|
"""Agregaciones deterministas: cada columna categorica x todas las medidas. PURA."""
|
|
out = []
|
|
for col in group_cols_sorted:
|
|
out.append(
|
|
{
|
|
"group_by": col,
|
|
"measures": list(measures),
|
|
"why": "selección cuantitativa (sin LLM)",
|
|
}
|
|
)
|
|
if len(out) >= max_aggs:
|
|
break
|
|
return out
|
|
|
|
|
|
def _fallback_pivots(cand_pivots: list) -> list:
|
|
"""Normaliza los pivots candidatos a la forma de salida (tal cual + why). PURA."""
|
|
out = []
|
|
if not isinstance(cand_pivots, list):
|
|
return out
|
|
for p in cand_pivots:
|
|
if not isinstance(p, dict):
|
|
continue
|
|
idx = p.get("index")
|
|
cols = p.get("columns")
|
|
if not (isinstance(idx, str) and isinstance(cols, str)):
|
|
continue
|
|
val = p.get("value")
|
|
if not isinstance(val, str):
|
|
val = None
|
|
out.append(
|
|
{
|
|
"index": idx,
|
|
"columns": cols,
|
|
"value": val,
|
|
"why": "selección cuantitativa (sin LLM)",
|
|
}
|
|
)
|
|
return out
|
|
|
|
|
|
def suggest_aggregations_llm(
|
|
profile: dict,
|
|
candidates: dict,
|
|
max_aggs: int = 4,
|
|
model: str = "claude-haiku-4-5-20251001",
|
|
) -> dict:
|
|
"""Elige las agregaciones y pivots mas informativos con UNA llamada al LLM.
|
|
|
|
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Toma el perfil de la tabla y
|
|
los candidatos cuantitativos (salida de select_groupby_keys) y deja que el LLM
|
|
seleccione/ordene las K agregaciones (GROUP BY categorica x medidas) y los pivots
|
|
mas utiles, con una razon corta cada uno, evitando la explosion combinatoria.
|
|
|
|
Privacidad/coste: solo viaja al LLM el resumen AGREGADO de los candidatos, nunca
|
|
filas crudas. Una sola llamada barata.
|
|
|
|
dict-no-throw con fallback determinista: NUNCA lanza. Si el LLM falla, el JSON no
|
|
parsea, o no produce seleccion valida -> construye la respuesta desde los candidatos
|
|
(group_keys x measures hasta max_aggs, pivots tal cual) con source="fallback". Las
|
|
columnas que el LLM invente (no presentes en los candidatos) se descartan.
|
|
|
|
Args:
|
|
profile: TableProfile del grupo eda. Solo se usa profile['table'] para nombrar
|
|
la tabla en el prompt; puede ir vacio.
|
|
candidates: salida de select_groupby_keys, con la forma
|
|
{group_keys:[{col,cardinality,score}], measures:[str],
|
|
pivots:[{index,columns,value}]}.
|
|
max_aggs: tope de agregaciones a devolver. Default 4. Valores <1 o no-int se
|
|
normalizan a 4.
|
|
model: id del modelo Anthropic. Default 'claude-haiku-4-5-20251001' (haiku,
|
|
coste bajo, ~2-3s).
|
|
|
|
Returns:
|
|
dict {status:"ok", source:"llm"|"fallback",
|
|
aggregations:[{group_by:str, measures:[str], why:str}],
|
|
pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}.
|
|
source=="llm" si el LLM produjo al menos una agregacion valida; en cualquier
|
|
otro caso "fallback". NUNCA lanza.
|
|
"""
|
|
if not isinstance(candidates, dict):
|
|
candidates = {}
|
|
if isinstance(max_aggs, bool) or not isinstance(max_aggs, int) or max_aggs < 1:
|
|
max_aggs = 4
|
|
|
|
group_cols, measures, measure_set, cand_pivots, group_keys = _candidate_view(candidates)
|
|
group_cols_sorted = _sorted_group_cols(group_keys)
|
|
|
|
# Sin material suficiente para agregar: no merece la pena llamar al LLM.
|
|
if not group_cols or not measures:
|
|
return {
|
|
"status": "ok",
|
|
"source": "fallback",
|
|
"aggregations": [],
|
|
"pivots": _fallback_pivots(cand_pivots),
|
|
"note": "sin candidatos suficientes para agregar",
|
|
}
|
|
|
|
prompt = _build_prompt(profile, candidates, max_aggs)
|
|
try:
|
|
text = ask_llm(prompt, model=model, system=_SYSTEM, echo=False)
|
|
except Exception: # noqa: BLE001 — degradacion: cualquier fallo de red/LLM.
|
|
text = ""
|
|
|
|
parsed = _extract_json(text)
|
|
if parsed is not None:
|
|
if isinstance(parsed, dict):
|
|
raw_aggs = parsed.get("aggregations")
|
|
raw_pivots = parsed.get("pivots")
|
|
elif isinstance(parsed, list):
|
|
raw_aggs = parsed
|
|
raw_pivots = None
|
|
else:
|
|
raw_aggs = None
|
|
raw_pivots = None
|
|
|
|
aggs = _validate_aggregations(raw_aggs, group_cols, measure_set, max_aggs)
|
|
if aggs:
|
|
pivots = _validate_pivots(raw_pivots, group_cols, measure_set)
|
|
if not pivots:
|
|
pivots = _fallback_pivots(cand_pivots)
|
|
return {
|
|
"status": "ok",
|
|
"source": "llm",
|
|
"aggregations": aggs,
|
|
"pivots": pivots,
|
|
"note": f"{len(aggs)} agregaciones y {len(pivots)} pivots seleccionados por el LLM",
|
|
}
|
|
|
|
# Fallback determinista.
|
|
note = (
|
|
"LLM no disponible; selección cuantitativa determinista"
|
|
if not text
|
|
else "LLM sin selección válida; selección cuantitativa determinista"
|
|
)
|
|
return {
|
|
"status": "ok",
|
|
"source": "fallback",
|
|
"aggregations": _fallback_aggregations(group_cols_sorted, measures, max_aggs),
|
|
"pivots": _fallback_pivots(cand_pivots),
|
|
"note": note,
|
|
}
|