feat(eda): funciones de agregación/OLAP para AutomaticEDA (groupby/pivot push-down + selección LLM)
Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION: - select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile. - groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo). - pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar. - suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,310 @@
|
||||
"""Pure EDA helper: pick GROUP BY keys and measures from a TableProfile.
|
||||
|
||||
Given a ``TableProfile`` of the ``eda`` group (the dict produced by, e.g.,
|
||||
``summarize_table_duckdb``), this function deterministically selects the most
|
||||
interesting categorical columns to group by (GROUP BY), the numeric measure
|
||||
columns to aggregate, and a couple of categorical x categorical pivot pairs.
|
||||
|
||||
It is the quantitative backbone for the aggregation / OLAP chapter of an
|
||||
AutomaticEDA: a pure, deterministic ranking over the profile, with no I/O, no
|
||||
mutation of the input and no external dependencies (stdlib only). It never
|
||||
raises — a missing or malformed profile yields an empty, well-formed result.
|
||||
"""
|
||||
|
||||
|
||||
def select_groupby_keys(
|
||||
profile: dict,
|
||||
max_keys: int = 3,
|
||||
max_card: int = 20,
|
||||
max_measures: int = 4,
|
||||
) -> dict:
|
||||
"""Select GROUP BY keys, measures and pivot pairs from a TableProfile.
|
||||
|
||||
Reads everything defensively (``.get(...)``, ``or []``, ``isinstance``) and
|
||||
never raises. With an empty/None profile it returns every list empty.
|
||||
|
||||
Selection rules (deterministic):
|
||||
|
||||
- **group_keys** (categorical columns to group by): candidates have
|
||||
``inferred_type`` in ``("categorical", "boolean")``. Discarded if they are
|
||||
in ``profile['key_candidates']``, carry a ``possible_id`` /
|
||||
``high_cardinality`` / ``constant`` flag, have ``distinct_count`` outside
|
||||
``[2, max_card]``, or are all-null (``null_pct >= 0.999``). Each survivor
|
||||
gets ``score = card_score * balance_score`` where ``card_score`` keeps a
|
||||
plateau for moderate cardinality (2..12) and decays towards ``max_card``,
|
||||
and ``balance_score = 1 / imbalance`` (``categorical.imbalance`` if
|
||||
present, else approximated from ``mode_pct``, else a neutral default).
|
||||
The top ``max_keys`` by score (desc, ties by column order) are returned.
|
||||
|
||||
- **measures** (numeric columns to aggregate): candidates have
|
||||
``inferred_type`` in ``("numeric", "integer", "float")``. Discarded if
|
||||
id-like (``possible_id`` flag *and* ``unique_pct >= 0.99``) or constant
|
||||
(``numeric.std`` is ``0`` or ``None``). Ranked by informative dispersion:
|
||||
``abs(cv)`` when available, else ``abs(std)``. The top ``max_measures``
|
||||
**names** are returned.
|
||||
|
||||
- **pivots**: up to 2 ``(group_keys[i].col, group_keys[j].col)`` pairs with
|
||||
``i < j``, using the first measure as the aggregated value. Empty when
|
||||
fewer than 2 group keys were selected.
|
||||
|
||||
Args:
|
||||
profile: TableProfile dict of the ``eda`` group. Relevant keys:
|
||||
``columns`` (list[ColumnProfile]), ``key_candidates`` (list of
|
||||
column names or ``{name}`` dicts), ``n_rows``. Each ColumnProfile
|
||||
uses: ``name``, ``inferred_type``, ``distinct_count``,
|
||||
``unique_pct`` (0..1), ``null_pct`` (0..1), ``flags`` (list[str]),
|
||||
``numeric`` ({std, cv, ...}|None), ``categorical``
|
||||
({imbalance, mode_pct, ...}|None).
|
||||
max_keys: Maximum number of group-by keys to return. Default 3.
|
||||
max_card: Maximum cardinality (``distinct_count``) a categorical column
|
||||
may have to still qualify as a group key. Default 20.
|
||||
max_measures: Maximum number of measure names to return. Default 4.
|
||||
|
||||
Returns:
|
||||
dict with:
|
||||
group_keys (list[{col, cardinality, score}], ordered by score desc),
|
||||
measures (list[str], numeric column names ordered by dispersion),
|
||||
pivots (list[{index, columns, value}], up to 2 pairs),
|
||||
note (str, short summary of what was chosen).
|
||||
"""
|
||||
if not isinstance(profile, dict):
|
||||
profile = {}
|
||||
|
||||
try:
|
||||
max_keys = int(max_keys)
|
||||
except (TypeError, ValueError):
|
||||
max_keys = 3
|
||||
try:
|
||||
max_card = int(max_card)
|
||||
except (TypeError, ValueError):
|
||||
max_card = 20
|
||||
try:
|
||||
max_measures = int(max_measures)
|
||||
except (TypeError, ValueError):
|
||||
max_measures = 4
|
||||
max_keys = max(max_keys, 0)
|
||||
max_card = max(max_card, 2)
|
||||
max_measures = max(max_measures, 0)
|
||||
|
||||
columns = profile.get("columns") or []
|
||||
if not isinstance(columns, (list, tuple)):
|
||||
columns = []
|
||||
|
||||
key_names = _key_candidate_names(profile.get("key_candidates"))
|
||||
|
||||
group_keys = _select_group_keys(columns, key_names, max_keys, max_card)
|
||||
measures = _select_measures(columns, max_measures)
|
||||
pivots = _select_pivots(group_keys, measures)
|
||||
|
||||
return {
|
||||
"group_keys": group_keys,
|
||||
"measures": measures,
|
||||
"pivots": pivots,
|
||||
"note": _build_note(group_keys, measures, pivots),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# group_keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GROUP_TYPES = ("categorical", "boolean")
|
||||
_DISQUALIFYING_FLAGS = frozenset({"possible_id", "high_cardinality", "constant"})
|
||||
_CARD_PLATEAU_HI = 12 # cardinalities 2..12 are all "moderate" (best).
|
||||
|
||||
|
||||
def _select_group_keys(columns, key_names, max_keys, max_card) -> list:
|
||||
"""Rank categorical/boolean columns suitable for GROUP BY."""
|
||||
scored = []
|
||||
for idx, col in enumerate(columns):
|
||||
if not isinstance(col, dict):
|
||||
continue
|
||||
if (col.get("inferred_type") or "") not in _GROUP_TYPES:
|
||||
continue
|
||||
|
||||
name = col.get("name")
|
||||
if name is None:
|
||||
continue
|
||||
if name in key_names:
|
||||
continue
|
||||
|
||||
flags = _as_set(col.get("flags"))
|
||||
if flags & _DISQUALIFYING_FLAGS:
|
||||
continue
|
||||
|
||||
if _num(col.get("null_pct"), 0.0) >= 0.999:
|
||||
continue
|
||||
|
||||
card = _num(col.get("distinct_count"), 0.0)
|
||||
if card < 2 or card > max_card:
|
||||
continue
|
||||
card_i = int(card)
|
||||
|
||||
score = _card_score(card_i, max_card) * _balance_score(col.get("categorical"))
|
||||
scored.append((round(score, 6), idx, name, card_i))
|
||||
|
||||
# Deterministic: higher score first, ties broken by original column order.
|
||||
scored.sort(key=lambda t: (-t[0], t[1]))
|
||||
|
||||
out = []
|
||||
for score, _idx, name, card_i in scored[:max_keys]:
|
||||
out.append({"col": name, "cardinality": card_i, "score": score})
|
||||
return out
|
||||
|
||||
|
||||
def _card_score(card: int, max_card: int) -> float:
|
||||
"""Prefer moderate cardinality; plateau at 2..12, decay towards max_card."""
|
||||
if card <= 1:
|
||||
return 0.0
|
||||
if card <= _CARD_PLATEAU_HI:
|
||||
return 1.0
|
||||
denom = max(max_card - _CARD_PLATEAU_HI, 1)
|
||||
over = card - _CARD_PLATEAU_HI
|
||||
return max(0.1, 1.0 - over / denom)
|
||||
|
||||
|
||||
def _balance_score(categorical) -> float:
|
||||
"""1.0 for a perfectly balanced category, decaying as imbalance grows.
|
||||
|
||||
Uses ``categorical.imbalance`` (max_count/min_count, >= 1) when available;
|
||||
otherwise approximates from ``mode_pct`` (top-class dominance); otherwise a
|
||||
neutral default so the column is still selectable.
|
||||
"""
|
||||
if isinstance(categorical, dict):
|
||||
imbalance = categorical.get("imbalance")
|
||||
if isinstance(imbalance, (int, float)) and imbalance >= 1.0:
|
||||
return 1.0 / float(imbalance)
|
||||
mode_pct = categorical.get("mode_pct")
|
||||
if isinstance(mode_pct, (int, float)):
|
||||
return _clamp(1.0 - float(mode_pct), 0.0, 1.0)
|
||||
return 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# measures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NUMERIC_TYPES = ("numeric", "integer", "float")
|
||||
|
||||
|
||||
def _select_measures(columns, max_measures) -> list:
|
||||
"""Rank numeric columns by informative dispersion (cv, else std)."""
|
||||
scored = []
|
||||
for idx, col in enumerate(columns):
|
||||
if not isinstance(col, dict):
|
||||
continue
|
||||
if (col.get("inferred_type") or "") not in _NUMERIC_TYPES:
|
||||
continue
|
||||
|
||||
name = col.get("name")
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
flags = _as_set(col.get("flags"))
|
||||
unique_pct = _num(col.get("unique_pct"), 0.0)
|
||||
if "possible_id" in flags and unique_pct >= 0.99:
|
||||
continue # sequential id, not a measure.
|
||||
|
||||
numeric = col.get("numeric")
|
||||
std = numeric.get("std") if isinstance(numeric, dict) else None
|
||||
if not isinstance(std, (int, float)) or std == 0:
|
||||
continue # constant or unknown spread -> not informative.
|
||||
|
||||
cv = numeric.get("cv") if isinstance(numeric, dict) else None
|
||||
if isinstance(cv, (int, float)):
|
||||
dispersion = abs(float(cv))
|
||||
else:
|
||||
dispersion = abs(float(std))
|
||||
|
||||
scored.append((dispersion, idx, name))
|
||||
|
||||
# Higher dispersion first, ties broken by original column order.
|
||||
scored.sort(key=lambda t: (-t[0], t[1]))
|
||||
return [name for _disp, _idx, name in scored[:max_measures]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pivots
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _select_pivots(group_keys, measures) -> list:
|
||||
"""Up to 2 (cat_a, cat_b) pairs from the chosen group keys."""
|
||||
if not isinstance(group_keys, list) or len(group_keys) < 2:
|
||||
return []
|
||||
value = measures[0] if measures else None
|
||||
pairs = []
|
||||
n = len(group_keys)
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
pairs.append({
|
||||
"index": group_keys[i].get("col"),
|
||||
"columns": group_keys[j].get("col"),
|
||||
"value": value,
|
||||
})
|
||||
if len(pairs) >= 2:
|
||||
return pairs
|
||||
return pairs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_note(group_keys, measures, pivots) -> str:
|
||||
"""One-line Spanish summary of the selection."""
|
||||
parts = []
|
||||
if group_keys:
|
||||
cols = ", ".join(str(g.get("col")) for g in group_keys)
|
||||
parts.append(f"{len(group_keys)} clave(s) de grupo: {cols}")
|
||||
else:
|
||||
parts.append("sin categóricas agrupables")
|
||||
if measures:
|
||||
parts.append(f"{len(measures)} medida(s): " + ", ".join(str(m) for m in measures))
|
||||
else:
|
||||
parts.append("sin medidas numéricas")
|
||||
if pivots:
|
||||
parts.append(f"{len(pivots)} pivot(s)")
|
||||
return "; ".join(parts) + "."
|
||||
|
||||
|
||||
def _key_candidate_names(key_candidates) -> set:
|
||||
"""Normalize ``key_candidates`` (strings or ``{name}`` dicts) to a name set."""
|
||||
names = set()
|
||||
if not isinstance(key_candidates, (list, tuple)):
|
||||
return names
|
||||
for entry in key_candidates:
|
||||
if isinstance(entry, str):
|
||||
names.add(entry)
|
||||
elif isinstance(entry, dict):
|
||||
nm = entry.get("name") or entry.get("col")
|
||||
if nm is not None:
|
||||
names.add(nm)
|
||||
return names
|
||||
|
||||
|
||||
def _as_set(flags) -> set:
|
||||
"""Coerce a flags value into a set, tolerating None / non-iterables."""
|
||||
if isinstance(flags, (list, tuple, set)):
|
||||
return set(flags)
|
||||
return set()
|
||||
|
||||
|
||||
def _num(value, default: float) -> float:
|
||||
"""Best-effort float conversion with a fallback default."""
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _clamp(x: float, lo: float, hi: float) -> float:
|
||||
"""Recorta x al rango [lo, hi]."""
|
||||
if x < lo:
|
||||
return lo
|
||||
if x > hi:
|
||||
return hi
|
||||
return x
|
||||
Reference in New Issue
Block a user