"""Pure EDA helper: pick GROUP BY keys and measures from a TableProfile. Given a ``TableProfile`` of the ``eda`` group (the dict produced by, e.g., ``summarize_table_duckdb``), this function deterministically selects the most interesting categorical columns to group by (GROUP BY), the numeric measure columns to aggregate, and a couple of categorical x categorical pivot pairs. It is the quantitative backbone for the aggregation / OLAP chapter of an AutomaticEDA: a pure, deterministic ranking over the profile, with no I/O, no mutation of the input and no external dependencies (stdlib only). It never raises — a missing or malformed profile yields an empty, well-formed result. """ def select_groupby_keys( profile: dict, max_keys: int = 3, max_card: int = 20, max_measures: int = 4, ) -> dict: """Select GROUP BY keys, measures and pivot pairs from a TableProfile. Reads everything defensively (``.get(...)``, ``or []``, ``isinstance``) and never raises. With an empty/None profile it returns every list empty. Selection rules (deterministic): - **group_keys** (categorical columns to group by): candidates have ``inferred_type`` in ``("categorical", "boolean")``. Discarded if they are in ``profile['key_candidates']``, carry a ``possible_id`` / ``high_cardinality`` / ``constant`` flag, have ``distinct_count`` outside ``[2, max_card]``, or are all-null (``null_pct >= 0.999``). Each survivor gets ``score = card_score * balance_score`` where ``card_score`` keeps a plateau for moderate cardinality (2..12) and decays towards ``max_card``, and ``balance_score = 1 / imbalance`` (``categorical.imbalance`` if present, else approximated from ``mode_pct``, else a neutral default). The top ``max_keys`` by score (desc, ties by column order) are returned. - **measures** (numeric columns to aggregate): candidates have ``inferred_type`` in ``("numeric", "integer", "float")``. 
Discarded if id-like (``possible_id`` flag *and* ``unique_pct >= 0.99``) or constant (``numeric.std`` is ``0`` or ``None``). Ranked by informative dispersion: ``abs(cv)`` when available, else ``abs(std)``. The top ``max_measures`` **names** are returned. - **pivots**: up to 2 ``(group_keys[i].col, group_keys[j].col)`` pairs with ``i < j``, using the first measure as the aggregated value. Empty when fewer than 2 group keys were selected. Args: profile: TableProfile dict of the ``eda`` group. Relevant keys: ``columns`` (list[ColumnProfile]), ``key_candidates`` (list of column names or ``{name}`` dicts), ``n_rows``. Each ColumnProfile uses: ``name``, ``inferred_type``, ``distinct_count``, ``unique_pct`` (0..1), ``null_pct`` (0..1), ``flags`` (list[str]), ``numeric`` ({std, cv, ...}|None), ``categorical`` ({imbalance, mode_pct, ...}|None). max_keys: Maximum number of group-by keys to return. Default 3. max_card: Maximum cardinality (``distinct_count``) a categorical column may have to still qualify as a group key. Default 20. max_measures: Maximum number of measure names to return. Default 4. Returns: dict with: group_keys (list[{col, cardinality, score}], ordered by score desc), measures (list[str], numeric column names ordered by dispersion), pivots (list[{index, columns, value}], up to 2 pairs), note (str, short summary of what was chosen). 
""" if not isinstance(profile, dict): profile = {} try: max_keys = int(max_keys) except (TypeError, ValueError): max_keys = 3 try: max_card = int(max_card) except (TypeError, ValueError): max_card = 20 try: max_measures = int(max_measures) except (TypeError, ValueError): max_measures = 4 max_keys = max(max_keys, 0) max_card = max(max_card, 2) max_measures = max(max_measures, 0) columns = profile.get("columns") or [] if not isinstance(columns, (list, tuple)): columns = [] key_names = _key_candidate_names(profile.get("key_candidates")) group_keys = _select_group_keys(columns, key_names, max_keys, max_card) measures = _select_measures(columns, max_measures) pivots = _select_pivots(group_keys, measures) return { "group_keys": group_keys, "measures": measures, "pivots": pivots, "note": _build_note(group_keys, measures, pivots), } # --------------------------------------------------------------------------- # group_keys # --------------------------------------------------------------------------- _GROUP_TYPES = ("categorical", "boolean") _DISQUALIFYING_FLAGS = frozenset({"possible_id", "high_cardinality", "constant"}) _CARD_PLATEAU_HI = 12 # cardinalities 2..12 are all "moderate" (best). 
def _select_group_keys(columns, key_names, max_keys, max_card) -> list:
    """Rank categorical/boolean columns that are suitable GROUP BY keys.

    Args:
        columns: Sequence of ColumnProfile dicts. Non-dict entries are
            skipped.
        key_names: Set of column names declared as key candidates. These
            are excluded.
        max_keys: Maximum number of keys to return.
        max_card: Maximum ``distinct_count`` a column may have to qualify.

    Returns:
        list of ``{"col", "cardinality", "score"}`` dicts, highest score
        first, with ties broken by original column order.
    """
    scored = []
    for idx, col in enumerate(columns):
        if not isinstance(col, dict):
            continue
        if (col.get("inferred_type") or "") not in _GROUP_TYPES:
            continue
        name = col.get("name")
        if name is None:
            continue
        try:
            if name in key_names:
                continue
        except TypeError:
            # An unhashable name (e.g. a list) cannot be looked up in the
            # set and cannot be a real column name, so skip it.
            continue
        if _as_set(col.get("flags")) & _DISQUALIFYING_FLAGS:
            continue
        if _num(col.get("null_pct"), 0.0) >= 0.999:
            continue  # (almost) all-null column: nothing to group by.
        card = _num(col.get("distinct_count"), 0.0)
        # A chained range test also rejects NaN: every comparison with NaN
        # is False, so NaN would otherwise reach int(card) and raise.
        if not 2 <= card <= max_card:
            continue
        card_i = int(card)
        score = _card_score(card_i, max_card) * _balance_score(col.get("categorical"))
        scored.append((round(score, 6), idx, name, card_i))

    # Deterministic: higher score first, ties broken by original column order.
    scored.sort(key=lambda t: (-t[0], t[1]))
    return [
        {"col": name, "cardinality": card_i, "score": score}
        for score, _idx, name, card_i in scored[:max_keys]
    ]


def _card_score(card: int, max_card: int) -> float:
    """Score a cardinality: prefer moderate values.

    Returns 0.0 for ``card <= 1`` and 1.0 on the plateau
    ``2.._CARD_PLATEAU_HI``. Above the plateau the score decays linearly
    towards ``max_card``, with a floor of 0.1.
    """
    if card <= 1:
        return 0.0
    if card <= _CARD_PLATEAU_HI:
        return 1.0
    denom = max(max_card - _CARD_PLATEAU_HI, 1)
    over = card - _CARD_PLATEAU_HI
    return max(0.1, 1.0 - over / denom)


def _balance_score(categorical) -> float:
    """Return 1.0 for a perfectly balanced category, decaying with imbalance.

    Uses ``categorical.imbalance`` (max_count/min_count, >= 1) when it is a
    usable number. Otherwise it approximates the score as ``1 - mode_pct``
    (top-class dominance), clamped to [0, 1]. If neither is usable it
    returns a neutral 0.5, so the column stays selectable. NaN, bool and
    overflowing values are treated as missing. This keeps scores comparable
    and the ranking deterministic.
    """
    if isinstance(categorical, dict):
        imbalance = _real_number(categorical.get("imbalance"))
        if imbalance is not None and imbalance >= 1.0:
            return 1.0 / imbalance
        mode_pct = _real_number(categorical.get("mode_pct"))
        if mode_pct is not None:
            return _clamp(1.0 - mode_pct, 0.0, 1.0)
    return 0.5


# ---------------------------------------------------------------------------
# measures
# ---------------------------------------------------------------------------

_NUMERIC_TYPES = ("numeric", "integer", "float")


def _select_measures(columns, max_measures) -> list:
    """Rank numeric columns by informative dispersion (cv, else std).

    Args:
        columns: Sequence of ColumnProfile dicts. Non-dict entries are
            skipped.
        max_measures: Maximum number of names to return.

    Returns:
        list of column names, highest dispersion first, with ties broken by
        original column order.
    """
    scored = []
    for idx, col in enumerate(columns):
        if not isinstance(col, dict):
            continue
        if (col.get("inferred_type") or "") not in _NUMERIC_TYPES:
            continue
        name = col.get("name")
        if name is None:
            continue
        flags = _as_set(col.get("flags"))
        if "possible_id" in flags and _num(col.get("unique_pct"), 0.0) >= 0.99:
            continue  # sequential id, not a measure.
        numeric = col.get("numeric")
        if not isinstance(numeric, dict):
            continue  # no numeric summary -> unknown spread.
        std = _real_number(numeric.get("std"))
        if std is None or std == 0:
            continue  # constant, unknown or NaN spread -> not informative.
        # A NaN/invalid cv falls back to std instead of poisoning the sort.
        cv = _real_number(numeric.get("cv"))
        dispersion = abs(cv) if cv is not None else abs(std)
        scored.append((dispersion, idx, name))

    # Higher dispersion first, ties broken by original column order.
    scored.sort(key=lambda t: (-t[0], t[1]))
    return [name for _disp, _idx, name in scored[:max_measures]]


# ---------------------------------------------------------------------------
# pivots
# ---------------------------------------------------------------------------

def _select_pivots(group_keys, measures, max_pairs: int = 2) -> list:
    """Build up to ``max_pairs`` (cat_a, cat_b) pivot pairs from the keys.

    Pairs are ``(group_keys[i].col, group_keys[j].col)`` with ``i < j``,
    enumerated in key order. The first measure (or ``None``) is the value.
    """
    if not isinstance(group_keys, list) or len(group_keys) < 2 or max_pairs <= 0:
        return []
    value = measures[0] if measures else None
    cols = [g.get("col") for g in group_keys]
    pairs = []
    for i, index_col in enumerate(cols):
        for columns_col in cols[i + 1:]:
            if len(pairs) >= max_pairs:
                return pairs
            pairs.append({"index": index_col, "columns": columns_col, "value": value})
    return pairs


# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------

def _build_note(group_keys, measures, pivots) -> str:
    """Build a one-line Spanish summary of the selection (user-facing text)."""
    parts = []
    if group_keys:
        cols = ", ".join(str(g.get("col")) for g in group_keys)
        parts.append(f"{len(group_keys)} clave(s) de grupo: {cols}")
    else:
        parts.append("sin categóricas agrupables")
    if measures:
        parts.append(f"{len(measures)} medida(s): " + ", ".join(str(m) for m in measures))
    else:
        parts.append("sin medidas numéricas")
    if pivots:
        parts.append(f"{len(pivots)} pivot(s)")
    return "; ".join(parts) + "."


def _real_number(value):
    """Return ``value`` as a float when it is a usable real number, else None.

    Accepts non-bool ``int``/``float`` values. Three kinds of value are
    rejected:

    - bools: ``True`` is an ``int`` subclass, not a measurement;
    - NaN: it compares False both ways and breaks sort order;
    - ints too large to convert to float.

    Infinities are kept, matching the previous comparison semantics.
    """
    import math

    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        number = float(value)
    except OverflowError:
        return None
    return None if math.isnan(number) else number
def _key_candidate_names(key_candidates) -> set:
    """Normalize ``key_candidates`` to a set of column names.

    Accepts a list/tuple/set of names (strings) or of ``{"name": ...}`` /
    ``{"col": ...}`` dicts. Anything else is ignored, as are unhashable
    names, which cannot be real column names.
    """
    names = set()
    if not isinstance(key_candidates, (list, tuple, set, frozenset)):
        return names
    for entry in key_candidates:
        if isinstance(entry, str):
            names.add(entry)
        elif isinstance(entry, dict):
            nm = entry.get("name") or entry.get("col")
            if nm is None:
                continue
            try:
                names.add(nm)
            except TypeError:
                continue  # unhashable (e.g. a list): skip instead of raising.
    return names


def _as_set(flags) -> set:
    """Coerce a flags value into a set of flag names, never raising.

    A list/tuple/set/frozenset yields its hashable items. A single string
    counts as one flag. Any other value (None, numbers, ...) yields an
    empty set.
    """
    if isinstance(flags, str):
        return {flags}
    if not isinstance(flags, (list, tuple, set, frozenset)):
        return set()
    out = set()
    for flag in flags:
        try:
            out.add(flag)
        except TypeError:
            continue  # unhashable entries cannot be flag names.
    return out


def _num(value, default: float) -> float:
    """Best-effort float conversion with a fallback default.

    Returns ``default`` in any of these cases:

    - ``value`` is None;
    - ``value`` cannot be converted to float (including ints too large for
      a float);
    - the conversion yields NaN.

    NaN is mapped to the default because it compares False against every
    threshold and would otherwise slip through range checks.
    """
    import math

    if value is None:
        return default
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return default if math.isnan(number) else number


def _clamp(x: float, lo: float, hi: float) -> float:
    """Clamp ``x`` to the closed range ``[lo, hi]``."""
    if x < lo:
        return lo
    if x > hi:
        return hi
    return x