96da9e3015
Cuatro funciones nuevas del grupo eda que nutren el capítulo AGREGACION: - select_groupby_keys (pure): elige categóricas agrupables + numéricas medida desde el TableProfile. - groupby_stats_duckdb (impure): GROUP BY push-down en DuckDB (count/mean/median/std/min/max por grupo). - pivot_table_duckdb (impure): pivot A×B push-down, limitado a top filas/cols para no cortar. - suggest_aggregations_llm (impure): el LLM elige las agregaciones interesantes con fallback determinista. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
311 lines
11 KiB
Python
311 lines
11 KiB
Python
"""Pure EDA helper: pick GROUP BY keys and measures from a TableProfile.
|
|
|
|
Given a ``TableProfile`` of the ``eda`` group (the dict produced by, e.g.,
``summarize_table_duckdb``), this function deterministically selects the most
interesting categorical columns to group by (GROUP BY), the numeric measure
columns to aggregate, and a couple of categorical x categorical pivot pairs.

It is the quantitative backbone for the aggregation / OLAP chapter of an
AutomaticEDA: a pure, deterministic ranking over the profile, with no I/O, no
mutation of the input and no external dependencies (stdlib only). It never
raises — a missing or malformed profile yields an empty, well-formed result.
|
|
"""
|
|
|
|
|
|
def select_groupby_keys(
    profile: dict,
    max_keys: int = 3,
    max_card: int = 20,
    max_measures: int = 4,
) -> dict:
    """Select GROUP BY keys, measures and pivot pairs from a TableProfile.

    Reads everything defensively (``.get(...)``, ``or []``, ``isinstance``) and
    never raises. With an empty/None profile it returns every list empty.
    Limits that cannot be turned into an int (``None``, text, NaN, +/-inf)
    silently fall back to their defaults.

    Selection rules (deterministic):

    - **group_keys** (categorical columns to group by): candidates have
      ``inferred_type`` in ``("categorical", "boolean")``. Discarded if they are
      in ``profile['key_candidates']``, carry a ``possible_id`` /
      ``high_cardinality`` / ``constant`` flag, have ``distinct_count`` outside
      ``[2, max_card]``, or are all-null (``null_pct >= 0.999``). Each survivor
      gets ``score = card_score * balance_score`` where ``card_score`` keeps a
      plateau for moderate cardinality (2..12) and decays towards ``max_card``,
      and ``balance_score = 1 / imbalance`` (``categorical.imbalance`` if
      present, else approximated from ``mode_pct``, else a neutral default).
      The top ``max_keys`` by score (desc, ties by column order) are returned.

    - **measures** (numeric columns to aggregate): candidates have
      ``inferred_type`` in ``("numeric", "integer", "float")``. Discarded if
      id-like (``possible_id`` flag *and* ``unique_pct >= 0.99``) or constant
      (``numeric.std`` is ``0`` or ``None``). Ranked by informative dispersion:
      ``abs(cv)`` when available, else ``abs(std)``. The top ``max_measures``
      **names** are returned.

    - **pivots**: up to 2 ``(group_keys[i].col, group_keys[j].col)`` pairs with
      ``i < j``, using the first measure as the aggregated value. Empty when
      fewer than 2 group keys were selected.

    Args:
        profile: TableProfile dict of the ``eda`` group. Relevant keys:
            ``columns`` (list[ColumnProfile]), ``key_candidates`` (list of
            column names or ``{name}`` dicts), ``n_rows``. Each ColumnProfile
            uses: ``name``, ``inferred_type``, ``distinct_count``,
            ``unique_pct`` (0..1), ``null_pct`` (0..1), ``flags`` (list[str]),
            ``numeric`` ({std, cv, ...}|None), ``categorical``
            ({imbalance, mode_pct, ...}|None).
        max_keys: Maximum number of group-by keys to return. Default 3;
            negative values are treated as 0.
        max_card: Maximum cardinality (``distinct_count``) a categorical column
            may have to still qualify as a group key. Default 20; values
            below 2 are raised to 2.
        max_measures: Maximum number of measure names to return. Default 4;
            negative values are treated as 0.

    Returns:
        dict with:
            group_keys (list[{col, cardinality, score}], ordered by score desc),
            measures (list[str], numeric column names ordered by dispersion),
            pivots (list[{index, columns, value}], up to 2 pairs),
            note (str, short summary of what was chosen).
    """
    if not isinstance(profile, dict):
        profile = {}

    max_keys = _coerce_limit(max_keys, default=3, floor=0)
    max_card = _coerce_limit(max_card, default=20, floor=2)
    max_measures = _coerce_limit(max_measures, default=4, floor=0)

    columns = profile.get("columns") or []
    if not isinstance(columns, (list, tuple)):
        columns = []

    key_names = _key_candidate_names(profile.get("key_candidates"))

    group_keys = _select_group_keys(columns, key_names, max_keys, max_card)
    measures = _select_measures(columns, max_measures)
    pivots = _select_pivots(group_keys, measures)

    return {
        "group_keys": group_keys,
        "measures": measures,
        "pivots": pivots,
        "note": _build_note(group_keys, measures, pivots),
    }


def _coerce_limit(value, default: int, floor: int) -> int:
    """Turn a caller-supplied limit into an int that is at least ``floor``."""
    try:
        limit = int(value)
    except (TypeError, ValueError, OverflowError):
        # int(None)/int("x") -> TypeError/ValueError, int(nan) -> ValueError,
        # int(inf) -> OverflowError. All mean "unusable": use the default.
        limit = default
    return max(limit, floor)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# group_keys
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# ``inferred_type`` values that make a column a GROUP BY key candidate.
_GROUP_TYPES = ("categorical", "boolean")
# A column carrying any of these flags is never used as a group key.
_DISQUALIFYING_FLAGS = frozenset({"possible_id", "high_cardinality", "constant"})
_CARD_PLATEAU_HI = 12  # cardinalities 2..12 are all "moderate" (best).
|
|
|
|
|
|
def _select_group_keys(columns, key_names, max_keys, max_card) -> list:
    """Rank categorical/boolean columns suitable for GROUP BY.

    Args:
        columns: Sequence of column-profile dicts; non-dict entries are skipped.
        key_names: Set of column names to exclude (key candidates).
        max_keys: Maximum number of entries to return.
        max_card: Largest ``distinct_count`` a column may have to qualify.

    Returns:
        List of ``{"col", "cardinality", "score"}`` dicts, best score first,
        ties kept in original column order.
    """
    scored = []
    for idx, col in enumerate(columns):
        if not isinstance(col, dict):
            continue
        if (col.get("inferred_type") or "") not in _GROUP_TYPES:
            continue

        name = col.get("name")
        if name is None or name in key_names:
            continue

        if _as_set(col.get("flags")) & _DISQUALIFYING_FLAGS:
            continue

        # An (almost) entirely null column has nothing to group on.
        if _num(col.get("null_pct"), 0.0) >= 0.999:
            continue

        card = _num(col.get("distinct_count"), 0.0)
        # Positive range test on purpose: every comparison with NaN is False,
        # so a NaN count is rejected here instead of crashing int() below.
        if not (2 <= card <= max_card):
            continue
        card_i = int(card)

        score = _card_score(card_i, max_card) * _balance_score(col.get("categorical"))
        scored.append((round(score, 6), idx, name, card_i))

    # Deterministic: higher score first, ties broken by original column order.
    scored.sort(key=lambda t: (-t[0], t[1]))

    return [
        {"col": name, "cardinality": card_i, "score": score}
        for score, _idx, name, card_i in scored[:max_keys]
    ]
|
|
|
|
|
|
def _card_score(card: int, max_card: int) -> float:
    """Score a cardinality, favouring moderate values.

    Returns 0.0 for ``card <= 1``, 1.0 up to ``_CARD_PLATEAU_HI``, and above
    that a linear fade towards ``max_card`` that never drops below 0.1.
    """
    if card <= 1:
        return 0.0
    excess = card - _CARD_PLATEAU_HI
    if excess <= 0:
        return 1.0
    # Width of the decay zone; at least 1 so the division is always defined.
    span = max(max_card - _CARD_PLATEAU_HI, 1)
    return max(0.1, 1.0 - excess / span)
|
|
|
|
|
|
def _balance_score(categorical) -> float:
    """1.0 for a perfectly balanced category, decaying as imbalance grows.

    Uses ``categorical.imbalance`` (max_count/min_count, >= 1) when available;
    otherwise approximates from ``mode_pct`` (top-class dominance); otherwise a
    neutral default so the column is still selectable.

    Args:
        categorical: The column's categorical stats dict, or anything else
            (treated as "no stats").

    Returns:
        A score in ``[0.0, 1.0]``; 0.5 when nothing usable is present. A NaN
        ``imbalance`` or ``mode_pct`` counts as missing, so the result is
        never NaN.
    """
    if not isinstance(categorical, dict):
        return 0.5

    imbalance = categorical.get("imbalance")
    # NaN fails ``>= 1.0`` by itself, so it falls through to mode_pct.
    if isinstance(imbalance, (int, float)) and imbalance >= 1.0:
        return 1.0 / float(imbalance)

    mode_pct = categorical.get("mode_pct")
    # ``x == x`` is False only for NaN. A NaN score would make the caller's
    # ranking sort meaningless, so treat it like a missing value.
    if isinstance(mode_pct, (int, float)) and mode_pct == mode_pct:
        return _clamp(1.0 - float(mode_pct), 0.0, 1.0)

    return 0.5
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# measures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# ``inferred_type`` values that make a column a measure candidate.
_NUMERIC_TYPES = ("numeric", "integer", "float")
|
|
|
|
|
|
def _select_measures(columns, max_measures) -> list:
    """Rank numeric columns by informative dispersion (cv, else std).

    Args:
        columns: Sequence of column-profile dicts; non-dict entries are skipped.
        max_measures: Maximum number of names to return.

    Returns:
        Column names ordered by dispersion (highest first), ties kept in
        original column order. Columns that look like ids, or whose ``std`` is
        missing, NaN or zero, are left out.
    """
    scored = []
    for idx, col in enumerate(columns):
        if not isinstance(col, dict):
            continue
        if (col.get("inferred_type") or "") not in _NUMERIC_TYPES:
            continue

        name = col.get("name")
        if name is None:
            continue

        flags = _as_set(col.get("flags"))
        unique_pct = _num(col.get("unique_pct"), 0.0)
        if "possible_id" in flags and unique_pct >= 0.99:
            continue  # sequential id, not a measure.

        numeric = col.get("numeric")
        if not isinstance(numeric, dict):
            continue  # no numeric stats -> unknown spread.

        std = numeric.get("std")
        # ``std != std`` is True only for NaN: an undefined spread is as
        # uninformative as a missing one and must not reach the sort key.
        if not isinstance(std, (int, float)) or std != std or std == 0:
            continue  # constant or unknown spread -> not informative.

        cv = numeric.get("cv")
        if isinstance(cv, (int, float)) and cv == cv:
            dispersion = abs(float(cv))
        else:
            # cv missing or NaN: fall back to the absolute spread.
            dispersion = abs(float(std))

        scored.append((dispersion, idx, name))

    # Higher dispersion first, ties broken by original column order.
    scored.sort(key=lambda t: (-t[0], t[1]))
    return [name for _disp, _idx, name in scored[:max_measures]]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# pivots
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _select_pivots(group_keys, measures) -> list:
    """Build at most two pivot specs from the selected group keys.

    Pairs are taken in ``(i, j)`` order with ``i < j``; each spec is
    ``{"index", "columns", "value"}`` where ``value`` is the first measure, or
    ``None`` when there are no measures. Returns ``[]`` unless ``group_keys``
    is a list with at least two entries.
    """
    if not isinstance(group_keys, list):
        return []
    total = len(group_keys)
    if total < 2:
        return []

    agg_value = None if not measures else measures[0]

    # Lazy on purpose: pairs past the second one are never built.
    candidates = (
        {
            "index": group_keys[a].get("col"),
            "columns": group_keys[b].get("col"),
            "value": agg_value,
        }
        for a in range(total)
        for b in range(a + 1, total)
    )

    chosen = []
    for spec in candidates:
        chosen.append(spec)
        if len(chosen) >= 2:
            break
    return chosen
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_note(group_keys, measures, pivots) -> str:
    """Summarise the selection as one Spanish sentence.

    The sentence always has a group-key segment and a measure segment (each
    with a "none found" wording when empty), plus a pivot count when there are
    pivots. Segments are joined with ``"; "`` and end with a period.
    """
    if group_keys:
        key_cols = ", ".join(str(entry.get("col")) for entry in group_keys)
        keys_segment = f"{len(group_keys)} clave(s) de grupo: {key_cols}"
    else:
        keys_segment = "sin categóricas agrupables"

    if measures:
        measure_names = ", ".join(map(str, measures))
        measures_segment = f"{len(measures)} medida(s): " + measure_names
    else:
        measures_segment = "sin medidas numéricas"

    segments = [keys_segment, measures_segment]
    if pivots:
        segments.append(f"{len(pivots)} pivot(s)")

    return "; ".join(segments) + "."
|
|
|
|
|
|
def _key_candidate_names(key_candidates) -> set:
    """Normalize ``key_candidates`` (strings or ``{name}`` dicts) to a name set.

    Args:
        key_candidates: A list, tuple, set or frozenset whose entries are
            column names or dicts carrying the name under ``"name"`` (or
            ``"col"``). Anything else yields an empty set.

    Returns:
        The set of candidate column names. Entries of other types, dicts
        without a usable name, and unhashable names are skipped.
    """
    names = set()
    # Sets are accepted as well, matching what ``flags`` tolerates.
    if not isinstance(key_candidates, (list, tuple, set, frozenset)):
        return names

    for entry in key_candidates:
        if isinstance(entry, str):
            candidate = entry
        elif isinstance(entry, dict):
            candidate = entry.get("name") or entry.get("col")
        else:
            continue
        if candidate is None:
            continue
        try:
            names.add(candidate)
        except TypeError:
            # Unhashable name (e.g. a list): skip it rather than raise.
            continue
    return names
|
|
|
|
|
|
def _as_set(flags) -> set:
    """Coerce a flags value into a set, tolerating None / non-iterables.

    Args:
        flags: A list, tuple, set or frozenset of flags. Any other value
            (including a bare string) yields an empty set.

    Returns:
        A new set with the hashable elements of ``flags``; unhashable
        elements are dropped instead of raising.
    """
    if not isinstance(flags, (list, tuple, set, frozenset)):
        return set()
    try:
        return set(flags)
    except TypeError:
        # At least one element is unhashable (e.g. a dict): keep the rest.
        kept = set()
        for flag in flags:
            try:
                kept.add(flag)
            except TypeError:
                continue
        return kept
|
|
|
|
|
|
def _num(value, default: float) -> float:
    """Best-effort float conversion with a fallback default.

    Args:
        value: Anything; ``None`` and unconvertible values yield ``default``.
        default: Value returned when ``value`` cannot be converted.

    Returns:
        ``float(value)`` when the conversion succeeds, else ``default``.
    """
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large for a float (e.g. 10**400).
        return default
|
|
|
|
|
|
def _clamp(x: float, lo: float, hi: float) -> float:
    """Clip ``x`` to the range ``[lo, hi]``.

    The lower bound is checked first; a value that is neither below ``lo`` nor
    above ``hi`` is returned unchanged.
    """
    return lo if x < lo else (hi if x > hi else x)
|