Compare commits

..

2 Commits

Author SHA1 Message Date
egutierrez 00cd5274bc feat(eda): capítulo GEOSPATIAL del AutomaticEDA (scatter geográfico + zona/país)
Capítulo nuevo chapters/geospatial.py (CHAPTER_VERSION 1.0.0). Cuando el dataset
tiene un par de coordenadas, dibuja un scatter geográfico en proyección
equirectangular (la escala respeta la latitud para no estirar la longitud) y
analiza la extensión: bounding box, centroide, span, conteo por zona/país,
hemisferios y una interpretación. Cuando NO hay coordenadas, build_geospatial
devuelve None y el capítulo se omite.

Sigue el contrato de capítulos (firma build_<id>(profile, ctx) -> Chapter|None,
lectura defensiva, nunca lanza) y el patrón de modelos/num_distr: delega el
cálculo a las primitivas puras del registry (detect_latlon_columns,
analyze_geo_extent, build_geo_scatter) y solo dibuja la figura matplotlib de
forma perezosa. Las coordenadas crudas llegan por ctx['geo_points'] o
ctx['raw_numeric'] (como modelos lee raw_numeric); sin ellas, degrada con un
bounding box aproximado de numeric.min/max y una nota honesta.

Anti-cortes: usa DataTable/KVTable/Figure/Markdown del modelo, que el paginador
parte sin cortar. Test self-contained con golden + 6 edges + anti-cut (nombres
largos + 2100 puntos en varias regiones renderizan a PDF y PPTX sin truncar).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-30 15:29:33 +02:00
egutierrez cd658cc703 feat(eda): primitivas geoespaciales del grupo eda (detección lat/lon + extensión + scatter)
Tres funciones puras nuevas del dominio datascience (tags eda + geospatial) que
sostienen el capítulo GEOSPATIAL del AutomaticEDA, delegadas a fn-constructor:

- detect_latlon_columns: identifica el par (lat, lon) por nombre de columna +
  rango de valores ([-90,90] / [-180,180]) desde profile['columns']. Devuelve
  {lat_col, lon_col, confidence, reason}. 9 tests.
- analyze_geo_extent: bbox, centroide, span haversine, conteo por zona/país
  (lookup offline con bounding boxes embebidos, KISS sin geopandas) y
  hemisferios. 7 tests.
- build_geo_scatter: prepara los puntos del scatter en orden [lon, lat] con
  downsampling determinista por paso fijo + aspect equirectangular 1/cos(lat)
  clampado. 6 tests.

Registradas en datascience/__init__.py. Todas pure, params_schema completo,
.md autosuficiente (Ejemplo + Cuando usarla + Gotchas).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-30 15:29:33 +02:00
26 changed files with 1891 additions and 2994 deletions
+6 -6
View File
@@ -25,7 +25,6 @@ from .describe_numeric import describe_numeric
from .summarize_categorical import summarize_categorical
from .infer_semantic_type import infer_semantic_type
from .column_quality_score import column_quality_score
from .select_groupby_keys import select_groupby_keys
from .render_eda_markdown import render_eda_markdown
from .detect_distribution_type import detect_distribution_type
from .spearman_corr import spearman_corr
@@ -37,8 +36,6 @@ from .infer_fk_containment_duckdb import infer_fk_containment_duckdb
from .build_join_graph import build_join_graph
from .association_matrix import association_matrix
from .correlation_matrix_duckdb import correlation_matrix_duckdb
from .pivot_table_duckdb import pivot_table_duckdb
from .groupby_stats_duckdb import groupby_stats_duckdb
from .pca_explained import pca_explained
from .kmeans_segments import kmeans_segments
from .isolation_forest_outliers import isolation_forest_outliers
@@ -47,6 +44,9 @@ from .trend_slope import trend_slope
from .run_eda_models import run_eda_models
from .project_clusters_2d import project_clusters_2d
from .describe_clusters_llm import describe_clusters_llm
from .detect_latlon_columns import detect_latlon_columns
from .analyze_geo_extent import analyze_geo_extent
from .build_geo_scatter import build_geo_scatter
from .eda_llm_insights import eda_llm_insights
from .build_eda_notebook import build_eda_notebook
from .decode_qr_image import decode_qr_image
@@ -85,8 +85,6 @@ __all__ = [
"build_join_graph",
"association_matrix",
"correlation_matrix_duckdb",
"pivot_table_duckdb",
"groupby_stats_duckdb",
"pca_explained",
"kmeans_segments",
"isolation_forest_outliers",
@@ -95,13 +93,15 @@ __all__ = [
"run_eda_models",
"project_clusters_2d",
"describe_clusters_llm",
"detect_latlon_columns",
"analyze_geo_extent",
"build_geo_scatter",
"eda_llm_insights",
"build_eda_notebook",
"describe_numeric",
"summarize_categorical",
"infer_semantic_type",
"column_quality_score",
"select_groupby_keys",
"render_eda_markdown",
"detect_distribution_type",
"pull_gsc_search_analytics",
@@ -0,0 +1,61 @@
---
name: analyze_geo_extent
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: pure
signature: "def analyze_geo_extent(lats: list, lons: list) -> dict"
description: "Calcula la extension geografica de una nube de coordenadas (lat/lon) y asigna cada punto a un pais/region mediante un lookup OFFLINE contra una tabla de bounding boxes embebida como constante. Devuelve bounding box, centroide, span de la diagonal (haversine), conteo por region (top-8 + Otros), reparto por hemisferios y una frase resumen en ES. Lectura defensiva: descarta pares None/NaN/fuera de rango y NUNCA lanza. Solo stdlib (math); sin geopandas/shapely. Las cajas de paises son rectangulos aproximados, no reverse-geocoding exacto."
tags: [eda, geospatial, geo, coordinates, bounding-box, haversine, datascience]
params:
- name: lats
desc: "Lista de latitudes en grados, rango valido [-90, 90]. Se empareja por indice con lons (gana la longitud minima comun si difieren). Cada valor puede ser None/NaN/no-numerico/fuera de rango: se lee defensivo y se descarta el par."
- name: lons
desc: "Lista de longitudes en grados, rango valido [-180, 180]. Paralela a lats, emparejada por indice. Valores None/NaN/no-numericos/fuera de rango se descartan junto con su par."
output: "Dict con el resumen geografico: {n_points=pares validos usados, bbox={lat_min,lat_max,lon_min,lon_max} o None, centroid={lat,lon}=media de lat/lon validos o None, span_km=distancia haversine (radio 6371 km) de la diagonal SO->NE del bbox, by_region=[{region,count}] descendente por count limitado a top-8 con el resto agregado en 'Otros', hemisphere={north,south,east,west} (ecuador->norte, meridiano 0->este), note=frase ES resumen}. Si no hay pares validos devuelve la forma cero: n_points 0, bbox None, centroid None, span_km 0.0, by_region [], hemisphere a ceros y note 'sin coordenadas validas'. Puntos que no caen en ninguna caja -> region 'Oceano/Otros'."
uses_functions: []
uses_types: []
returns: []
returns_optional: false
error_type: ""
imports: [math]
tested: true
tests: ["test_nube_en_espana", "test_dos_paises_distintos", "test_listas_vacias", "test_pares_invalidos_filtrados", "test_longitudes_desbalanceadas", "test_span_km_haversine_par_conocido", "test_no_lanza_con_entradas_raras"]
test_file_path: "python/functions/datascience/analyze_geo_extent_test.py"
file_path: "python/functions/datascience/analyze_geo_extent.py"
---
## Ejemplo
```python
import sys, os
sys.path.insert(0, os.path.join("python", "functions"))
from datascience.analyze_geo_extent import analyze_geo_extent
# Nube de puntos alrededor de Madrid + un punto en Paris.
lats = [40.4, 40.0, 41.0, 48.8]
lons = [-3.7, -3.5, -4.0, 2.3]
res = analyze_geo_extent(lats, lons)
print(res["n_points"]) # 4
print(res["by_region"]) # [{'region': 'España', 'count': 3}, {'region': 'Francia', 'count': 1}]
print(round(res["span_km"], 1)) # diagonal SO->NE del bbox en km
print(res["hemisphere"]) # {'north': 4, 'south': 0, 'east': 1, 'west': 3}
print(res["note"]) # los puntos se concentran en España (3 de 4)
```
## Cuando usarla
- Usala en el perfilado EDA (grupo `eda`) cuando una tabla tenga columnas de latitud y longitud y quieras un resumen geografico rapido: donde se concentran los puntos, cuanto territorio cubren y a que paises/regiones caen, sin montar geopandas ni un reverse-geocoder.
- Cuando necesites un capitulo `geospatial` del `AutomaticEDA`: alimenta el bbox + centroide para centrar un mapa, el `span_km` para elegir el zoom, y `by_region` para una tabla de conteos por pais.
- Cuando quieras detectar datos sucios de coordenadas (mezcla de hemisferios inesperada, puntos en `Oceano/Otros`, span enorme) antes de seguir el analisis.
## Gotchas
- Funcion pura, sin I/O ni red y determinista: mismas entradas -> misma salida. Lectura defensiva, NUNCA lanza; pares con None/NaN o fuera de rango ([-90,90] lat, [-180,180] lon) se descartan en silencio.
- El lookup de region es una **aproximacion rectangular**: cada pais/region es un bounding box, NO su frontera real. Un punto en el mar cerca de una costa, o en una esquina del rectangulo, puede asignarse a un pais vecino. No es reverse-geocoding exacto — para precision real hace falta un shapefile (fuera de scope por KISS).
- Cajas solapadas se resuelven por orden: gana la PRIMERA que contiene el punto. Los paises se listan antes que los continentes (fallback), y entre vecinos el mas estrecho/occidental va primero (Portugal antes que España, Chile antes que Argentina, EEUU contiguo antes que Canada). Un punto que no cae en ninguna caja -> `Oceano/Otros`.
- La tabla cubre ~24 paises grandes + 6 regiones continentales; paises pequeños o no listados caen a su continente o a `Oceano/Otros`. No incluye territorios insulares lejanos (Canarias, Hawaii, etc.).
- `span_km` es la diagonal del bounding box (esquina SO a NE), no la dispersion real de la nube ni el area; con un solo punto valido el bbox es degenerado y `span_km` es 0.0.
- El ecuador (`lat == 0`) cuenta como hemisferio norte y el meridiano 0 (`lon == 0`) como este, por convencion `>= 0`.
@@ -0,0 +1,209 @@
"""analyze_geo_extent — geographic extent of a cloud of coordinates (EDA `geospatial`).
Pure function: no I/O, no network, deterministic. Given two parallel lists of
latitudes and longitudes it derives the bounding box, centroid, diagonal span
(haversine), per-region counts and hemisphere split of the points, and assigns
each point to a country/region via an OFFLINE lookup against a table of
rectangular bounding boxes embedded as a constant (`_REGION_BBOXES`).
It never reads files, never hits the network and depends only on `math`. The
country boxes are deliberately coarse rectangles (a KISS approximation, NOT a
reverse-geocoder). Reading is defensive throughout and the function NEVER
raises: invalid pairs (None / NaN / out of range) are silently discarded and an
empty cloud yields a zeroed result the caller can skip.
"""
import math
# Earth mean radius in km used by the haversine formula.
_EARTH_RADIUS_KM = 6371.0
# How many distinct regions to surface in `by_region` before collapsing the
# remainder into a single "Otros" bucket.
_TOP_REGIONS = 8
# Offline region lookup: (name, lat_min, lat_max, lon_min, lon_max).
#
# Specific countries are listed FIRST and continental fallbacks LAST: each point
# is assigned to the FIRST box that contains it, so the more specific country box
# wins over the broad continent box. Boxes are coarse rectangles approximating
# the mainland extent of each region; overlapping neighbours are ordered so the
# narrower/more-western country claims its coastal points (e.g. Portugal before
# Spain, Chile before Argentina, the contiguous US before Canada).
_REGION_BBOXES = (
# --- countries (specific) ---
("Portugal", 36.9, 42.2, -9.6, -6.2),
("España", 36.0, 43.8, -9.4, 3.4),
("Francia", 41.3, 51.1, -5.2, 9.6),
("Reino Unido", 49.9, 58.7, -8.6, 1.8),
("Irlanda", 51.4, 55.4, -10.6, -5.9),
("Países Bajos", 50.7, 53.6, 3.3, 7.2),
("Bélgica", 49.5, 51.5, 2.5, 6.4),
("Suiza", 45.8, 47.8, 5.9, 10.5),
("Alemania", 47.3, 55.1, 5.9, 15.0),
("Italia", 36.6, 47.1, 6.6, 18.5),
("Marruecos", 27.7, 35.9, -13.2, -1.0),
("Egipto", 22.0, 31.7, 25.0, 35.0),
("Sudáfrica", -34.8, -22.1, 16.5, 32.9),
("China", 18.0, 53.6, 73.5, 135.1),
("Japón", 24.0, 45.6, 122.9, 145.9),
("India", 6.7, 35.5, 68.1, 97.4),
("Australia", -43.7, -10.0, 112.9, 153.7),
("México", 14.5, 32.7, -118.4, -86.7),
("Estados Unidos", 24.4, 49.4, -125.0, -66.9),
("Canadá", 41.7, 83.1, -141.0, -52.6),
("Chile", -55.9, -17.5, -75.6, -66.4),
("Argentina", -55.1, -21.8, -73.6, -53.6),
("Brasil", -33.8, 5.3, -74.0, -34.8),
("Rusia", 41.2, 77.0, 19.6, 180.0),
# --- continental fallbacks (broad) ---
("Europa", 34.0, 72.0, -25.0, 45.0),
("África", -35.0, 37.5, -18.0, 52.0),
("Asia", 5.0, 78.0, 26.0, 180.0),
("América del Norte", 7.0, 84.0, -168.0, -52.0),
("América del Sur", -56.0, 13.0, -82.0, -34.0),
("Oceanía", -50.0, 0.0, 110.0, 180.0),
)
def _coord(value, limit):
"""Coerce a coordinate to a valid float in [-limit, limit] or None.
bool is a subclass of int but never a real coordinate, so True/False are
treated as missing. NaN and out-of-range values are rejected.
"""
if value is None or isinstance(value, bool):
return None
try:
f = float(value)
except (TypeError, ValueError):
return None
# NaN is the only value that is not equal to itself.
if f != f or f < -limit or f > limit:
return None
return f
def _haversine_km(lat1, lon1, lat2, lon2):
"""Great-circle distance in km between two (lat, lon) points in degrees."""
rlat1, rlat2 = math.radians(lat1), math.radians(lat2)
dlat = math.radians(lat2 - lat1)
dlon = math.radians(lon2 - lon1)
a = math.sin(dlat / 2.0) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2.0) ** 2
return 2.0 * _EARTH_RADIUS_KM * math.asin(min(1.0, math.sqrt(a)))
def _region_of(lat, lon):
"""Return the name of the first embedded box containing (lat, lon)."""
for name, lat_min, lat_max, lon_min, lon_max in _REGION_BBOXES:
if lat_min <= lat <= lat_max and lon_min <= lon <= lon_max:
return name
return "Océano/Otros"
def _empty_result():
"""Result shape when there are no valid coordinate pairs."""
return {
"n_points": 0,
"bbox": None,
"centroid": None,
"span_km": 0.0,
"by_region": [],
"hemisphere": {"north": 0, "south": 0, "east": 0, "west": 0},
"note": "sin coordenadas validas",
}
def analyze_geo_extent(lats: list, lons: list) -> dict:
"""Summarise the geographic extent of a cloud of lat/lon coordinates.
Pairs `lats[i]` with `lons[i]` by index (over the common length when the two
lists differ in size), discards any pair where either value is None / NaN or
outside [-90, 90] (lat) / [-180, 180] (lon), and derives the bounding box,
centroid, diagonal span, per-region counts and hemisphere split. Each valid
point is matched to a country/region by an offline lookup against coarse
rectangular bounding boxes (`_REGION_BBOXES`).
Args:
lats: List of latitudes in degrees ([-90, 90]); read defensively.
lons: List of longitudes in degrees ([-180, 180]); read defensively.
Paired with `lats` by index; the shorter length wins when they differ.
Returns:
Dict with the geographic summary:
{n_points, bbox={lat_min,lat_max,lon_min,lon_max}, centroid={lat,lon},
span_km (haversine of the SW->NE bbox diagonal), by_region=[{region,count}]
(descending, top-8 with the rest folded into "Otros"),
hemisphere={north,south,east,west}, note (Spanish summary phrase)}.
With no valid pairs returns the zeroed shape: n_points 0, bbox None,
centroid None, span_km 0.0, empty by_region, zeroed hemisphere and the
note "sin coordenadas validas". Never raises.
"""
if not isinstance(lats, (list, tuple)) or not isinstance(lons, (list, tuple)):
return _empty_result()
valid = []
# zip already stops at the shorter list -> unbalanced lengths are handled.
for raw_lat, raw_lon in zip(lats, lons):
lat = _coord(raw_lat, 90.0)
lon = _coord(raw_lon, 180.0)
if lat is None or lon is None:
continue
valid.append((lat, lon))
if not valid:
return _empty_result()
n = len(valid)
lat_vals = [p[0] for p in valid]
lon_vals = [p[1] for p in valid]
lat_min, lat_max = min(lat_vals), max(lat_vals)
lon_min, lon_max = min(lon_vals), max(lon_vals)
centroid_lat = sum(lat_vals) / n
centroid_lon = sum(lon_vals) / n
# Diagonal span: SW corner (lat_min, lon_min) to NE corner (lat_max, lon_max).
span_km = _haversine_km(lat_min, lon_min, lat_max, lon_max)
# Hemisphere split: the equator/prime-meridian go to north/east respectively.
north = sum(1 for lat in lat_vals if lat >= 0.0)
south = n - north
east = sum(1 for lon in lon_vals if lon >= 0.0)
west = n - east
# Count points per region (offline bbox lookup).
counts = {}
for lat, lon in valid:
region = _region_of(lat, lon)
counts[region] = counts.get(region, 0) + 1
# Descending by count, then by name for a deterministic tie-break.
ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
by_region = [{"region": name, "count": count} for name, count in ranked[:_TOP_REGIONS]]
rest = sum(count for _, count in ranked[_TOP_REGIONS:])
if rest > 0:
by_region.append({"region": "Otros", "count": rest})
top_region, top_count = ranked[0]
note = (
"los puntos se concentran en {region} ({count} de {n})".format(
region=top_region, count=top_count, n=n
)
)
return {
"n_points": n,
"bbox": {
"lat_min": lat_min,
"lat_max": lat_max,
"lon_min": lon_min,
"lon_max": lon_max,
},
"centroid": {"lat": centroid_lat, "lon": centroid_lon},
"span_km": span_km,
"by_region": by_region,
"hemisphere": {"north": north, "south": south, "east": east, "west": west},
"note": note,
}
@@ -0,0 +1,126 @@
"""Tests para analyze_geo_extent."""
import math
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from analyze_geo_extent import analyze_geo_extent, _haversine_km
# Keys that a non-empty result dict must always contain.
_EXPECTED_KEYS = {
"n_points", "bbox", "centroid", "span_km",
"by_region", "hemisphere", "note",
}
def test_nube_en_espana():
"""Golden: nube de puntos alrededor de Madrid -> region top = España."""
# Cuatro puntos en torno a Madrid (lat ~40, lon ~-3.7), con algo de spread.
lats = [40.4, 40.0, 41.0, 39.5]
lons = [-3.7, -3.5, -4.0, -3.2]
res = analyze_geo_extent(lats, lons)
assert set(res.keys()) == _EXPECTED_KEYS
assert res["n_points"] == 4
# Todos caen en España -> by_region una sola entrada.
assert res["by_region"][0]["region"] == "España"
assert res["by_region"][0]["count"] == 4
# Centroide coherente: media de lat y lon.
assert math.isclose(res["centroid"]["lat"], sum(lats) / 4, rel_tol=1e-9)
assert math.isclose(res["centroid"]["lon"], sum(lons) / 4, rel_tol=1e-9)
# bbox correcto.
assert res["bbox"]["lat_min"] == 39.5
assert res["bbox"]["lat_max"] == 41.0
assert res["bbox"]["lon_min"] == -4.0
assert res["bbox"]["lon_max"] == -3.2
# Hay spread -> diagonal > 0.
assert res["span_km"] > 0.0
# Hemisferio norte (lat>0) y oeste (lon<0).
assert res["hemisphere"]["north"] == 4
assert res["hemisphere"]["south"] == 0
assert res["hemisphere"]["east"] == 0
assert res["hemisphere"]["west"] == 4
assert "España" in res["note"]
def test_dos_paises_distintos():
"""Golden: puntos en España y Francia -> by_region con 2 entradas."""
# Madrid (España) x2 y Paris (Francia) x1.
lats = [40.4, 40.0, 48.8]
lons = [-3.7, -3.5, 2.3]
res = analyze_geo_extent(lats, lons)
assert res["n_points"] == 3
regions = {entry["region"]: entry["count"] for entry in res["by_region"]}
assert regions == {"España": 2, "Francia": 1}
# Orden descendente por count: España (2) antes que Francia (1).
assert res["by_region"][0]["region"] == "España"
assert res["by_region"][0]["count"] == 2
# Madrid y Paris ambos hemisferio norte; Paris lon>0 -> 1 east, 2 west.
assert res["hemisphere"]["north"] == 3
assert res["hemisphere"]["east"] == 1
assert res["hemisphere"]["west"] == 2
def test_listas_vacias():
"""Edge: listas vacias -> n_points 0, bbox None, sin lanzar."""
res = analyze_geo_extent([], [])
assert res["n_points"] == 0
assert res["bbox"] is None
assert res["centroid"] is None
assert res["span_km"] == 0.0
assert res["by_region"] == []
assert res["hemisphere"] == {"north": 0, "south": 0, "east": 0, "west": 0}
assert res["note"] == "sin coordenadas validas"
def test_pares_invalidos_filtrados():
"""Edge: None / NaN / fuera de rango se descartan, no lanza."""
nan = float("nan")
lats = [40.4, None, nan, 91.0, -200.0, 40.0]
lons = [-3.7, -3.5, -3.0, 2.0, 5.0, -3.5]
# Validos: indices 0 y 5 (lat 91 fuera de rango, lon -200 fuera de rango,
# None y NaN descartados).
res = analyze_geo_extent(lats, lons)
assert res["n_points"] == 2
assert res["by_region"][0]["region"] == "España"
assert res["by_region"][0]["count"] == 2
def test_longitudes_desbalanceadas():
"""Edge: len(lats) != len(lons) usa el minimo comun sin lanzar."""
lats = [40.4, 40.0, 41.0, 39.5] # 4 elementos
lons = [-3.7, -3.5] # 2 elementos
res = analyze_geo_extent(lats, lons)
# Solo se emparejan los 2 primeros.
assert res["n_points"] == 2
assert res["bbox"]["lat_min"] == 40.0
assert res["bbox"]["lat_max"] == 40.4
def test_span_km_haversine_par_conocido():
"""Edge: span_km coincide con haversine de la diagonal del bbox."""
# Dos puntos: (0, 0) y (0, 1). bbox diagonal = mismos dos puntos.
res = analyze_geo_extent([0.0, 0.0], [0.0, 1.0])
# 1 grado de longitud en el ecuador ~ 111.19 km.
expected = _haversine_km(0.0, 0.0, 0.0, 1.0)
assert math.isclose(res["span_km"], expected, rel_tol=1e-9)
assert math.isclose(res["span_km"], 111.19, abs_tol=0.5)
def test_no_lanza_con_entradas_raras():
"""Edge: tipos no-lista o None devuelven la forma vacia sin lanzar."""
assert analyze_geo_extent(None, None)["n_points"] == 0
assert analyze_geo_extent("foo", "bar")["n_points"] == 0
# Strings dentro de las listas se descartan como invalidos.
res = analyze_geo_extent(["x", 40.0], [None, -3.5])
assert res["n_points"] == 1
@@ -1,592 +0,0 @@
"""Aggregation chapter (AGREGACION) — group analysis / OLAP of the EDA.
This chapter is the group-by / pivot ("OLAP") section of an AutomaticEDA report
and is meant to be present **whenever the dataset has at least one low-cardinality
categorical column to group by**. For the most interesting categoricals (chosen
by their cardinality/relevance, optionally with an LLM) it renders, as blocks the
core paginator never cuts:
1. **Per-group statistics** (split-apply-combine) — for each interesting
categorical key, the count of rows per group and, for each numeric measure,
its mean/median/std/min/max. One compact summary table (mean of every measure
per group) plus a per-measure detail table.
2. **Bar charts** — a vertical bar chart of a measure's mean per group, bars from
zero (Tufte Lie-Factor = 1).
3. **Pivot tables** — categorical A x categorical B -> aggregate of a measure,
limited to the top rows/cols so it fits a mobile page/slide, with a grouped
bar chart of the same pivot.
The raw data needed to aggregate is **not** in the TableProfile, so — exactly
like ``modelos`` reads its cluster projection from ``ctx`` — this chapter gets
the aggregation results in one of two ways and degrades honestly when neither is
available:
ctx keys this chapter consumes (all optional):
aggregations : dict — pre-computed results, used directly (offline / tests /
forward-compatible with a calculation phase). Shape::
{"groupby": [{"group_by": str, "measures": [str], "why": str,
"result": <groupby_stats_duckdb-shaped dict>}],
"pivots": [{"index": str, "columns": str, "value": str, "agg": str,
"why": str, "result": <pivot_table_duckdb-shaped dict>}]}
db_path, table : str — when ``aggregations`` is absent, the chapter selects
the interesting keys (``select_groupby_keys``), optionally asks an LLM
which to show (``suggest_aggregations_llm`` when ``run_agg_llm`` is True)
and computes the group-by/pivot results live via the push-down registry
functions ``groupby_stats_duckdb`` / ``pivot_table_duckdb``.
run_agg_llm : bool — when True (and ``db_path``/``table`` present), let the
LLM pick the interesting aggregations; otherwise the deterministic
quantitative selection is used.
agg_llm_model : str — model id for the optional LLM selection.
agg_max_keys, agg_max_card, agg_max_measures, agg_top_n : int — limits.
agg_insights : list — optional pre-computed micro-analysis entries
(``[{"title": str, "text": str}]``) rendered as an interpretation section.
Contract: build_<id>(profile, ctx) -> Chapter | None ; CHAPTER_VERSION = "x.y.z".
Reads everything defensively (``.get``) and never raises: anything missing
degrades to a note instead of aborting the chapter; the chapter returns ``None``
only when the dataset has no categorical column to group by.
"""
from __future__ import annotations
from .. import model
# Pure/impure registry functions (group ``eda``) this chapter composes. Imported
# defensively so the chapter still builds (degrading the affected part to a note)
# if a function is somehow unavailable / not indexed yet.
try:
from datascience.select_groupby_keys import select_groupby_keys
except Exception: # noqa: BLE001 — keep the chapter importable no matter what.
select_groupby_keys = None # type: ignore[assignment]
try:
from datascience.groupby_stats_duckdb import groupby_stats_duckdb
except Exception: # noqa: BLE001
groupby_stats_duckdb = None # type: ignore[assignment]
try:
from datascience.pivot_table_duckdb import pivot_table_duckdb
except Exception: # noqa: BLE001
pivot_table_duckdb = None # type: ignore[assignment]
try:
from datascience.suggest_aggregations_llm import suggest_aggregations_llm
except Exception: # noqa: BLE001
suggest_aggregations_llm = None # type: ignore[assignment]
CHAPTER_VERSION = "1.0.0"
CHAPTER_ID = "agregacion"
CHAPTER_TITLE = "Agregación por grupos"
# Tableau-10 palette — stable colours for the pivot's grouped-bar series.
_SERIES_COLORS = [
"#4e79a7", "#f28e2b", "#e15759", "#76b7b2", "#59a14f",
"#edc948", "#b07aa1", "#ff9da7", "#9c755f", "#bab0ac",
]
# Defaults for the live selection/aggregation (overridable via ctx).
_DEF_MAX_KEYS = 3
_DEF_MAX_CARD = 20
_DEF_MAX_MEASURES = 4
_DEF_TOP_N = 12
# --------------------------------------------------------------------------- #
# Formatting helpers (mirror the other chapters' defensive style).
# --------------------------------------------------------------------------- #
def _fmt_num(value, decimals: int = 3) -> str:
if value is None:
return ""
if isinstance(value, bool):
return "" if value else "no"
if isinstance(value, int):
return f"{value:,}".replace(",", ".")
if isinstance(value, float):
if value != value: # NaN
return "NaN"
if value in (float("inf"), float("-inf")):
return str(value)
text = f"{value:.{decimals}f}".rstrip("0").rstrip(".")
return text if text else "0"
return model._safe_str(value)
def _is_dict(v) -> bool:
return isinstance(v, dict)
def _measure_mean(group: dict, measure: str):
"""Pull the mean of one measure out of a groupby-result group entry."""
stats = group.get("stats") if _is_dict(group.get("stats")) else {}
ms = stats.get(measure) if _is_dict(stats.get(measure)) else {}
return ms.get("mean")
# --------------------------------------------------------------------------- #
# Plan + data resolution. Either a pre-computed ctx['aggregations'] is used
# verbatim, or the plan is selected and the results are computed live.
# --------------------------------------------------------------------------- #
def _resolve_candidates(profile: dict, ctx: dict) -> dict:
"""Return {group_keys, measures, pivots, note} of interesting columns."""
pre = ctx.get("agg_candidates")
if _is_dict(pre) and pre.get("group_keys") is not None:
return pre
if select_groupby_keys is not None:
try:
out = select_groupby_keys(
profile,
max_keys=int(ctx.get("agg_max_keys", _DEF_MAX_KEYS)),
max_card=int(ctx.get("agg_max_card", _DEF_MAX_CARD)),
max_measures=int(ctx.get("agg_max_measures", _DEF_MAX_MEASURES)),
)
if _is_dict(out):
return out
except Exception: # noqa: BLE001 — fall through to the inline fallback.
pass
return _inline_candidates(profile, ctx)
def _inline_candidates(profile: dict, ctx: dict) -> dict:
"""Minimal defensive selection when select_groupby_keys is unavailable."""
max_card = int(ctx.get("agg_max_card", _DEF_MAX_CARD))
max_keys = int(ctx.get("agg_max_keys", _DEF_MAX_KEYS))
max_measures = int(ctx.get("agg_max_measures", _DEF_MAX_MEASURES))
keys = profile.get("key_candidates") or []
group_keys, measures = [], []
for col in profile.get("columns") or []:
if not _is_dict(col):
continue
name = col.get("name")
it = col.get("inferred_type")
flags = col.get("flags") or []
dc = col.get("distinct_count")
if it in ("categorical", "boolean") and name not in keys:
if ("possible_id" not in flags and "high_cardinality" not in flags
and "constant" not in flags
and isinstance(dc, int) and 2 <= dc <= max_card):
group_keys.append({"col": name, "cardinality": dc, "score": 0.0})
elif it == "numeric":
num = col.get("numeric") or {}
if num.get("std") not in (None, 0) and not (
"possible_id" in flags and (col.get("unique_pct") or 0) >= 0.99):
measures.append(name)
group_keys = group_keys[:max_keys]
measures = measures[:max_measures]
pivots = []
if len(group_keys) >= 2:
pivots.append({"index": group_keys[0]["col"],
"columns": group_keys[1]["col"],
"value": measures[0] if measures else None})
return {"group_keys": group_keys, "measures": measures, "pivots": pivots,
"note": "selección cuantitativa básica"}
def _resolve_plan(profile: dict, ctx: dict, candidates: dict) -> dict:
"""Return {aggregations:[{group_by,measures,why}], pivots:[...], source}."""
group_keys = candidates.get("group_keys") or []
measures = candidates.get("measures") or []
if ctx.get("run_agg_llm") and suggest_aggregations_llm is not None:
try:
plan = suggest_aggregations_llm(
profile, candidates,
max_aggs=int(ctx.get("agg_max_keys", _DEF_MAX_KEYS)),
model=ctx.get("agg_llm_model", "claude-haiku-4-5-20251001"))
if _is_dict(plan) and plan.get("aggregations"):
return {"aggregations": plan.get("aggregations") or [],
"pivots": plan.get("pivots") or [],
"source": plan.get("source", "llm")}
except Exception: # noqa: BLE001 — fall back to the quantitative plan.
pass
aggregations = [{
"group_by": gk.get("col"),
"measures": measures,
"why": f"categórica de {_fmt_num(gk.get('cardinality'))} niveles",
} for gk in group_keys if _is_dict(gk) and gk.get("col")]
pivots = []
for pv in candidates.get("pivots") or []:
if _is_dict(pv) and pv.get("index") and pv.get("columns"):
pivots.append({"index": pv.get("index"), "columns": pv.get("columns"),
"value": pv.get("value") or (measures[0] if measures else None),
"agg": "mean", "why": "cruce de dos categóricas"})
return {"aggregations": aggregations, "pivots": pivots, "source": "quantitative"}
def _live_groupby(ctx: dict, group_by: str, measures: list, top_n: int):
"""Compute one group-by result live via the push-down registry function."""
db_path = ctx.get("db_path")
table = ctx.get("table")
if not db_path or not table or groupby_stats_duckdb is None:
return None
try:
out = groupby_stats_duckdb(db_path, table, group_by, list(measures or []),
top_n=top_n)
if _is_dict(out) and out.get("status") == "ok":
return out
except Exception: # noqa: BLE001
return None
return None
def _live_pivot(ctx: dict, index: str, columns: str, value, agg: str):
"""Compute one pivot live via the push-down registry function."""
db_path = ctx.get("db_path")
table = ctx.get("table")
if not db_path or not table or pivot_table_duckdb is None or not value:
return None
try:
out = pivot_table_duckdb(db_path, table, index, columns, value,
agg=agg or "mean")
if _is_dict(out) and out.get("status") == "ok":
return out
except Exception: # noqa: BLE001
return None
return None
# --------------------------------------------------------------------------- #
# Figure builders (lazy: matplotlib only imported when the renderer draws them).
# --------------------------------------------------------------------------- #
def _make_group_bars(group_by: str, measure: str, groups: list):
"""Vertical bars: mean of ``measure`` per group, bars from zero."""
labels, values = [], []
for g in groups:
if not _is_dict(g):
continue
mean = _measure_mean(g, measure)
if mean is None:
continue
labels.append(model._safe_str(g.get("key")))
values.append(float(mean))
if not labels:
return None
def _draw():
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(6.6, 3.6))
xs = list(range(len(labels)))
ax.bar(xs, values, color="#4e79a7", alpha=0.9, edgecolor="#2f4d6e",
linewidth=0.4)
ax.set_xticks(xs)
short = [(s[:18] + "") if len(s) > 19 else s for s in labels]
rot = 30 if max((len(s) for s in short), default=0) > 6 else 0
ax.set_xticklabels(short, rotation=rot, ha="right" if rot else "center",
fontsize=7)
ax.set_ylabel(f"media de {measure}", fontsize=8)
ax.set_xlabel(group_by, fontsize=8)
ax.set_title(f"Media de «{measure}» por «{group_by}»", fontsize=10)
ax.grid(axis="y", color="#dddddd", linewidth=0.6)
for spine in ("top", "right"):
ax.spines[spine].set_visible(False)
# Value labels above each bar.
vmax = max(values) if values else 0
for x, v in zip(xs, values):
ax.text(x, v + (abs(vmax) * 0.01 if vmax else 0.01),
_fmt_num(v, 2), ha="center", va="bottom", fontsize=6.5)
fig.tight_layout()
return fig
return _draw
def _make_pivot_bars(pivot: dict):
"""Grouped bars of a pivot: x = row_labels, one series per col_label."""
row_labels = pivot.get("row_labels") or []
col_labels = pivot.get("col_labels") or []
matrix = pivot.get("matrix") or []
if not row_labels or not col_labels or not matrix:
return None
def _draw():
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
n_rows = len(row_labels)
n_cols = len(col_labels)
fig, ax = plt.subplots(figsize=(6.8, 3.8))
total_w = 0.8
bar_w = total_w / max(n_cols, 1)
base = list(range(n_rows))
for j, clabel in enumerate(col_labels):
offs = [b - total_w / 2 + bar_w * (j + 0.5) for b in base]
vals = []
for i in range(n_rows):
cell = matrix[i][j] if (i < len(matrix) and j < len(matrix[i])) else None
vals.append(float(cell) if isinstance(cell, (int, float)) else 0.0)
color = _SERIES_COLORS[j % len(_SERIES_COLORS)]
ax.bar(offs, vals, width=bar_w, color=color, alpha=0.9,
label=model._safe_str(clabel))
ax.set_xticks(base)
short = [(s[:16] + "") if len(s) > 17 else s
for s in (model._safe_str(r) for r in row_labels)]
rot = 30 if max((len(s) for s in short), default=0) > 6 else 0
ax.set_xticklabels(short, rotation=rot, ha="right" if rot else "center",
fontsize=7)
ax.set_xlabel(model._safe_str(pivot.get("index")), fontsize=8)
ax.set_ylabel(f"{pivot.get('agg','mean')} de {pivot.get('value')}",
fontsize=8)
ax.set_title(f"{pivot.get('index')} × {pivot.get('columns')}", fontsize=10)
ax.grid(axis="y", color="#dddddd", linewidth=0.6)
ax.legend(title=model._safe_str(pivot.get("columns")), fontsize=6.5,
title_fontsize=7, frameon=True, framealpha=0.9, loc="best")
for spine in ("top", "right"):
ax.spines[spine].set_visible(False)
fig.tight_layout()
return fig
return _draw
def _group_bars_maker(group_by: str, measure: str, groups: list):
"""Bind per-aggregation args so the lazy closure is loop-safe."""
def _make():
return _make_group_bars(group_by, measure, groups)()
return _make
def _pivot_bars_maker(pivot: dict):
def _make():
return _make_pivot_bars(pivot)()
return _make
# --------------------------------------------------------------------------- #
# Section builders. Each returns a list of blocks (possibly empty).
# --------------------------------------------------------------------------- #
def _groupby_section(group_by: str, measures: list, result: dict, why: str) -> list:
"""Build the blocks for one group-by aggregation, or [] if unusable."""
if not _is_dict(result) or not result.get("groups"):
return []
groups = [g for g in result.get("groups") or [] if _is_dict(g)]
if not groups:
return []
eff_measures = result.get("measures") or measures or []
blocks = [model.Heading(text=f"Agrupado por «{group_by}»", level=2)]
intro = f"**{why}.** " if why else ""
intro += (f"{_fmt_num(result.get('n_groups') or len(groups))} grupos"
f"{' (top por tamaño)' if result.get('truncated') else ''}.")
blocks.append(model.Markdown(text=intro))
# Summary table: one row per group, count + mean of every measure.
header = ["Grupo", "n"] + [f"{m} (media)" for m in eff_measures]
rows = []
for g in groups:
row = [model._safe_str(g.get("key")), _fmt_num(g.get("n"))]
for m in eff_measures:
row.append(_fmt_num(_measure_mean(g, m), 2))
rows.append(row)
blocks.append(model.DataTable(
header=header, rows=rows, title=f"Resumen por «{group_by}»",
note="Conteo de filas y media de cada medida por grupo."))
if not eff_measures:
return blocks
# Primary measure: a bar chart + a detail table (mean/median/std/min/max).
primary = eff_measures[0]
bars = _make_group_bars(group_by, primary, groups)
if bars is not None:
blocks.append(model.Figure(
make=_group_bars_maker(group_by, primary, groups),
caption=f"Media de «{primary}» por «{group_by}» (barras desde cero)."))
det_header = ["Grupo", "n", "media", "mediana", "σ", "mín", "máx"]
det_rows = []
for g in groups:
stats = g.get("stats") if _is_dict(g.get("stats")) else {}
ms = stats.get(primary) if _is_dict(stats.get(primary)) else {}
det_rows.append([
model._safe_str(g.get("key")), _fmt_num(g.get("n")),
_fmt_num(ms.get("mean"), 2), _fmt_num(ms.get("median"), 2),
_fmt_num(ms.get("std"), 2), _fmt_num(ms.get("min"), 2),
_fmt_num(ms.get("max"), 2),
])
blocks.append(model.DataTable(
header=det_header, rows=det_rows,
title=f"Detalle de «{primary}» por «{group_by}»"))
return blocks
def _pivot_section(pivot_spec: dict, result: dict) -> list:
"""Build the blocks for one pivot table, or [] if unusable."""
if not _is_dict(result) or not result.get("row_labels"):
return []
row_labels = result.get("row_labels") or []
col_labels = result.get("col_labels") or []
matrix = result.get("matrix") or []
if not row_labels or not col_labels or not matrix:
return []
index = result.get("index") or pivot_spec.get("index")
columns = result.get("columns") or pivot_spec.get("columns")
value = result.get("value") or pivot_spec.get("value")
agg = result.get("agg") or pivot_spec.get("agg") or "mean"
why = pivot_spec.get("why") or ""
blocks = [model.Heading(text=f"Pivot: «{index}» × «{columns}»", level=2)]
intro = f"**{why}.** " if why else ""
intro += (f"{agg} de «{value}» cruzando «{index}» (filas) y «{columns}» "
f"(columnas).")
if result.get("truncated_rows") or result.get("truncated_cols"):
intro += " Limitado a las filas/columnas más frecuentes."
blocks.append(model.Markdown(text=intro))
header = [model._safe_str(index)] + [model._safe_str(c) for c in col_labels]
rows = []
for i, rlabel in enumerate(row_labels):
row = [model._safe_str(rlabel)]
cells = matrix[i] if i < len(matrix) else []
for j in range(len(col_labels)):
cell = cells[j] if j < len(cells) else None
row.append(_fmt_num(cell, 2))
rows.append(row)
blocks.append(model.DataTable(
header=header, rows=rows,
title=f"{agg} de «{value}»",
note=f"Cada celda es {agg} de «{value}» para esa combinación."))
fig_pivot = {"row_labels": row_labels, "col_labels": col_labels,
"matrix": matrix, "index": index, "columns": columns,
"value": value, "agg": agg}
if _make_pivot_bars(fig_pivot) is not None:
blocks.append(model.Figure(
make=_pivot_bars_maker(fig_pivot),
caption=f"{agg} de «{value}» por «{index}» y «{columns}» "
f"(barras agrupadas)."))
return blocks
def _insights_section(ctx: dict) -> list:
"""Optional pre-computed micro-analysis of the aggregations (SHOULD-11.4)."""
entries = ctx.get("agg_insights")
if not isinstance(entries, list) or not entries:
return []
blocks = [model.Heading(text="Interpretación de los grupos", level=2)]
for e in entries:
if not _is_dict(e):
continue
title = model._safe_str(e.get("title"))
text = model._safe_str(e.get("text"))
line = (f"**{title}.** " if title else "") + text
if line.strip():
blocks.append(model.Markdown(text=line))
return blocks if len(blocks) > 1 else []
# --------------------------------------------------------------------------- #
# Pre-computed path: ctx['aggregations'] already carries the results.
# --------------------------------------------------------------------------- #
def _sections_from_precomputed(agg: dict) -> list:
sections = []
for entry in agg.get("groupby") or []:
if not _is_dict(entry):
continue
sections += _groupby_section(
entry.get("group_by"), entry.get("measures") or [],
entry.get("result") or {}, entry.get("why") or "")
for entry in agg.get("pivots") or []:
if not _is_dict(entry):
continue
sections += _pivot_section(entry, entry.get("result") or {})
return sections
# --------------------------------------------------------------------------- #
# Live path: select keys, pick a plan, compute results via push-down functions.
# --------------------------------------------------------------------------- #
def _sections_live(profile: dict, ctx: dict, candidates: dict) -> list:
top_n = int(ctx.get("agg_top_n", _DEF_TOP_N))
plan = _resolve_plan(profile, ctx, candidates)
sections = []
for agg in plan.get("aggregations") or []:
if not _is_dict(agg) or not agg.get("group_by"):
continue
result = _live_groupby(ctx, agg.get("group_by"),
agg.get("measures") or [], top_n)
if result is not None:
sections += _groupby_section(agg.get("group_by"),
agg.get("measures") or [], result,
agg.get("why") or "")
for pv in plan.get("pivots") or []:
if not _is_dict(pv) or not pv.get("index") or not pv.get("columns"):
continue
result = _live_pivot(ctx, pv.get("index"), pv.get("columns"),
pv.get("value"), pv.get("agg") or "mean")
if result is not None:
sections += _pivot_section(pv, result)
return sections
# --------------------------------------------------------------------------- #
# Entry point.
# --------------------------------------------------------------------------- #
def _intro_blocks() -> list:
text = (
"Este capítulo analiza la tabla **por grupos** (split-apply-combine): "
"elige las columnas categóricas más informativas — por su cardinalidad "
"y relevancia, no todas contra todas, para no inflar comparaciones "
"espurias — y resume las variables numéricas dentro de cada grupo "
"(conteo, media, mediana, desviación). Las **tablas dinámicas** (pivot) "
"cruzan dos categóricas sobre una medida, y los **gráficos de barras** "
"(siempre desde cero) comparan los grupos de un vistazo."
)
return [model.Heading(text=CHAPTER_TITLE, level=1),
model.Markdown(text=text)]
def build_agregacion(profile: dict, ctx: dict):
"""Build the AGREGACION Chapter, or None if the dataset can't be grouped.
Args:
profile: the ``eda`` group TableProfile dict.
ctx: presentation context (see module docstring for the keys consumed).
Returns:
A ``model.Chapter`` with per-group stats, pivots and bar charts; or
``None`` when the dataset has no low-cardinality categorical column to
group by (the chapter does not apply).
"""
profile = profile or {}
ctx = ctx or {}
if not isinstance(profile, dict):
return None
# Pre-computed results take precedence (offline / tests / forward-compat).
pre = ctx.get("aggregations")
if _is_dict(pre) and (pre.get("groupby") or pre.get("pivots")):
sections = _sections_from_precomputed(pre)
if not sections:
return None
blocks = _intro_blocks() + sections + _insights_section(ctx)
return model.Chapter(id=CHAPTER_ID, title=CHAPTER_TITLE,
version=CHAPTER_VERSION, blocks=blocks)
# Live path: needs at least one categorical key to group by.
candidates = _resolve_candidates(profile, ctx)
if not _is_dict(candidates) or not (candidates.get("group_keys")):
return None # chapter does not apply: nothing to group by.
sections = _sections_live(profile, ctx, candidates)
if not sections:
# Applies (there are categorical keys) but no aggregation data is
# reachable: emit an honest note instead of fabricating numbers.
keys = ", ".join(model._safe_str((k or {}).get("col"))
for k in candidates.get("group_keys") or []
if _is_dict(k))
note = model.Note(
"No se pudo calcular la agregación: el capítulo necesita los datos "
"crudos. Pasa ctx['db_path'] + ctx['table'] (para el cálculo "
"push-down en DuckDB) o ctx['aggregations'] ya precalculado. "
f"Columnas categóricas candidatas: {keys or ''}.")
blocks = _intro_blocks() + [note] + _insights_section(ctx)
return model.Chapter(id=CHAPTER_ID, title=CHAPTER_TITLE,
version=CHAPTER_VERSION, blocks=blocks)
blocks = _intro_blocks() + sections + _insights_section(ctx)
return model.Chapter(id=CHAPTER_ID, title=CHAPTER_TITLE,
version=CHAPTER_VERSION, blocks=blocks)
@@ -1,256 +0,0 @@
"""Tests for the AGREGACION chapter — DoD: golden + edges + error/no-cut path.
Self-contained and deterministic: no DuckDB and no LLM. The aggregation results
are passed pre-computed via ``ctx['aggregations']`` (the same shape the push-down
registry functions ``groupby_stats_duckdb`` / ``pivot_table_duckdb`` produce), so
the chapter's rendering logic is exercised without touching disk or the network.
Live push-down + LLM selection are covered separately by the golden script.
Verifies:
- Golden: a profile with categoricals + numerics builds a Chapter with per-group
stats tables, a pivot table and bar-chart figures, and it renders to PDF AND
PPTX showing the group keys, values and pivot — nothing cut.
- Edges: a dataset with no low-cardinality categorical returns None; an empty
profile returns None; a profile that *could* be grouped but has no reachable
data degrades to an honest note instead of raising.
- No-cut: many groups (30) + a long interpretation paragraph survive intact in
the rendered PDF (table split by rows, text wrapped whole).
"""
import os
import re
import tempfile
from pptx import Presentation
from pypdf import PdfReader
from datascience.automatic_eda.chapters.agregacion import build_agregacion
from datascience.automatic_eda.model import Chapter
from datascience.render_automatic_eda_pdf import render_automatic_eda_pdf
from datascience.render_automatic_eda_pptx import render_automatic_eda_pptx
# --------------------------------------------------------------------------- #
# Synthetic fixtures.
# --------------------------------------------------------------------------- #
def _profile() -> dict:
"""A titanic-like profile: 2 categoricals + 2 numeric measures + 1 id."""
return {
"table": "titanic",
"source": "/data/titanic.csv",
"n_rows": 891,
"n_cols": 5,
"key_candidates": ["passenger_id"],
"columns": [
{"name": "passenger_id", "inferred_type": "numeric",
"unique_pct": 1.0, "flags": ["possible_id"],
"numeric": {"mean": 446.0, "std": 257.0}},
{"name": "sex", "inferred_type": "categorical", "distinct_count": 2,
"flags": [], "categorical": {"n_distinct": 2, "imbalance": 0.1,
"top": [{"value": "male", "count": 577}]}},
{"name": "pclass", "inferred_type": "categorical", "distinct_count": 3,
"flags": [], "categorical": {"n_distinct": 3, "imbalance": 0.2}},
{"name": "fare", "inferred_type": "numeric", "flags": [],
"numeric": {"mean": 32.2, "std": 49.7, "cv": 1.54}},
{"name": "age", "inferred_type": "numeric", "flags": [],
"numeric": {"mean": 29.7, "std": 14.5, "cv": 0.49}},
],
}
def _groupby_result(group_by: str, keys_n: list) -> dict:
"""A groupby_stats_duckdb-shaped result for `fare` and `age`."""
groups = []
for i, (key, n) in enumerate(keys_n):
groups.append({
"key": key, "n": n,
"stats": {
"fare": {"mean": 20.0 + i * 15, "median": 10.0 + i * 8,
"std": 40.0 + i, "min": 0.0, "max": 512.3},
"age": {"mean": 28.0 + i, "median": 27.0 + i, "std": 14.0,
"min": 0.42, "max": 80.0},
},
})
return {"status": "ok", "group_by": group_by, "measures": ["fare", "age"],
"aggs": ["count", "mean", "median", "std", "min", "max"],
"n_groups": len(groups), "truncated": False, "groups": groups}
def _pivot_result() -> dict:
return {"status": "ok", "index": "sex", "columns": "pclass", "value": "fare",
"agg": "mean", "row_labels": ["male", "female"],
"col_labels": ["1", "2", "3"],
"matrix": [[62.0, 19.0, 12.0], [110.0, 22.0, 15.0]],
"truncated_rows": False, "truncated_cols": False}
def _ctx_precomputed() -> dict:
return {
"aggregations": {
"groupby": [
{"group_by": "sex", "measures": ["fare", "age"],
"why": "sexo del pasajero",
"result": _groupby_result("sex", [("male", 577), ("female", 314)])},
{"group_by": "pclass", "measures": ["fare", "age"],
"why": "clase del billete",
"result": _groupby_result(
"pclass", [("3", 491), ("1", 216), ("2", 184)])},
],
"pivots": [
{"index": "sex", "columns": "pclass", "value": "fare",
"agg": "mean", "why": "tarifa por sexo y clase",
"result": _pivot_result()},
],
},
"agg_insights": [
{"title": "Tarifa por sexo",
"text": "Las mujeres pagaron de media casi el doble que los hombres."},
],
}
def _pdf_text(path: str) -> str:
txt = "".join((pg.extract_text() or "") for pg in PdfReader(path).pages)
return re.sub(r"\s+", " ", txt)
def _pptx_text(path: str) -> str:
prs = Presentation(path)
parts = []
for sl in prs.slides:
for sh in sl.shapes:
if sh.has_text_frame:
parts.append(sh.text_frame.text)
if sh.has_table:
tb = sh.table
for r in range(len(tb.rows)):
for c in range(len(tb.columns)):
parts.append(tb.cell(r, c).text)
return re.sub(r"\s+", " ", " ".join(parts))
# --------------------------------------------------------------------------- #
# Golden: builds a Chapter and renders to both formats.
# --------------------------------------------------------------------------- #
def test_golden_chapter_blocks_present():
ch = build_agregacion(_profile(), _ctx_precomputed())
assert isinstance(ch, Chapter)
assert ch.id == "agregacion"
kinds = [b.kind for b in ch.blocks]
assert "heading" in kinds
assert kinds.count("data_table") >= 3 # 2 group summaries + pivot (+details)
assert "figure" in kinds # at least one bar chart.
# Headings mention the group keys and the pivot.
htext = " ".join(b.text for b in ch.blocks if b.kind == "heading")
assert "sex" in htext and "pclass" in htext and "Pivot" in htext
def test_golden_render_pdf():
ch = build_agregacion(_profile(), _ctx_precomputed())
with tempfile.TemporaryDirectory() as d:
out = os.path.join(d, "agg.pdf")
res = render_automatic_eda_pdf([ch], out, {"write_manifest": False})
assert res["path"] == out and os.path.exists(out)
assert res["n_pages"] >= 1
txt = _pdf_text(out)
assert "Agregación por grupos" in txt
assert "male" in txt and "female" in txt # group + pivot labels.
assert "Pivot" in txt
assert "mediana" in txt # per-measure detail.
assert "casi el doble" in txt # interpretation kept.
def test_golden_render_pptx():
ch = build_agregacion(_profile(), _ctx_precomputed())
with tempfile.TemporaryDirectory() as d:
out = os.path.join(d, "agg.pptx")
res = render_automatic_eda_pptx([ch], out, {"write_manifest": False})
assert res["path"] == out and os.path.exists(out)
assert res["n_slides"] >= 1
txt = _pptx_text(out)
assert "male" in txt and "pclass" in txt
assert "Pivot" in txt or "sex" in txt
# --------------------------------------------------------------------------- #
# Edges.
# --------------------------------------------------------------------------- #
def test_edge_no_categorical_returns_none():
# Only numerics + an id: nothing to group by -> chapter does not apply.
prof = {
"table": "t", "n_rows": 100, "key_candidates": ["id"],
"columns": [
{"name": "id", "inferred_type": "numeric", "unique_pct": 1.0,
"flags": ["possible_id"], "numeric": {"std": 10.0}},
{"name": "x", "inferred_type": "numeric", "flags": [],
"numeric": {"mean": 1.0, "std": 2.0}},
],
}
assert build_agregacion(prof, {}) is None
def test_edge_empty_profile_returns_none():
assert build_agregacion({}, {}) is None
assert build_agregacion(None, None) is None
def test_edge_high_cardinality_only_returns_none():
# The single categorical is id-like (high cardinality) -> not groupable.
prof = {
"table": "t", "n_rows": 100, "key_candidates": ["uuid"],
"columns": [
{"name": "uuid", "inferred_type": "categorical", "distinct_count": 100,
"flags": ["high_cardinality", "possible_id"]},
{"name": "x", "inferred_type": "numeric", "flags": [],
"numeric": {"mean": 1.0, "std": 2.0}},
],
}
assert build_agregacion(prof, {}) is None
def test_live_without_data_degrades_to_note():
# Has a categorical to group by but no db_path / no precomputed results:
# must NOT raise and must emit an honest note (chapter still applies).
prof = {
"table": "t", "n_rows": 100, "key_candidates": [],
"columns": [
{"name": "grp", "inferred_type": "categorical", "distinct_count": 3,
"flags": [], "categorical": {"n_distinct": 3}},
{"name": "v", "inferred_type": "numeric", "flags": [],
"numeric": {"mean": 1.0, "std": 2.0}},
],
}
ch = build_agregacion(prof, {})
assert isinstance(ch, Chapter)
notes = [b.text for b in ch.blocks if b.kind == "note"]
assert any("datos crudos" in n for n in notes)
# --------------------------------------------------------------------------- #
# No-cut: many groups + long text survive intact in the PDF.
# --------------------------------------------------------------------------- #
def test_anti_corte_muchos_grupos_y_texto_largo():
keys_n = [(f"grupo_{i:02d}", 30 - (i % 5)) for i in range(30)]
long_text = " ".join(f"palabra{i}" for i in range(120))
ctx = {
"aggregations": {
"groupby": [
{"group_by": "cat", "measures": ["fare"], "why": "muchos niveles",
"result": _groupby_result("cat", keys_n)},
],
"pivots": [],
},
"agg_insights": [{"title": "Nota larga", "text": long_text}],
}
ch = build_agregacion(_profile(), ctx)
with tempfile.TemporaryDirectory() as d:
out = os.path.join(d, "big.pdf")
res = render_automatic_eda_pdf([ch], out, {"write_manifest": False})
assert res["path"] == out
assert res["n_pages"] > 1 # 30-row table + figure spill across pages.
txt = _pdf_text(out)
# First and last group labels both survive (table not truncated).
assert "grupo_00" in txt and "grupo_29" in txt
# First, middle and last words of the long paragraph all present.
for i in (0, 60, 119):
assert f"palabra{i}" in txt
@@ -0,0 +1,477 @@
"""Geospatial chapter (GEOSPATIAL) for AutomaticEDA.
When the dataset carries a coordinate pair (latitude/longitude), this chapter
draws the points on a **geographic scatter** in an equirectangular projection
(scaled so degrees of longitude are not stretched at the data's latitude) and
analyses the **zone / country** the points fall in: bounding box, centroid,
geographic span, and a per-region count. When there is **no** coordinate pair the
chapter returns ``None`` — exactly the user requirement.
Detection and the heavy lifting are delegated to pure ``eda``-group registry
functions, never reimplemented here:
- ``detect_latlon_columns`` — finds the (lat, lon) column pair by name + value
range from the ``profile['columns']`` metadata.
- ``analyze_geo_extent`` — bbox, centroid, haversine span, per-region counts and
hemisphere from the raw coordinate arrays.
- ``build_geo_scatter`` — deterministically down-sampled points + bbox + the
aspect ratio for the equirectangular projection. This chapter only draws the
matplotlib figure from that prepared data (same split as ``num_distr`` does
with ``build_boxplot_stats``).
The raw coordinate arrays are **not** in a standard TableProfile (it stores only
per-column aggregates), so — exactly like ``modelos`` reads ``raw_numeric`` from
``ctx`` — this chapter looks for the coordinates in ``ctx`` (or ``profile``) and
degrades honestly when they are absent: it still detects the columns and shows an
approximate bounding box derived from the per-column ``numeric.min/max``, with a
note that the raw points are needed for the map.
ctx keys this chapter consumes (all optional):
geo_points : dict — ``{"lats": [...], "lons": [...]}`` raw coordinate arrays.
Used directly when present (forward-compatible with a calculation phase
that samples them from the table).
raw_numeric : dict — ``{col: [values]}`` raw numeric columns; when present
and ``geo_points`` is not, the detected lat/lon columns are read from it.
run_geo_llm : bool — when True, call ``ask_llm`` for a one-line narrative of
where the points concentrate (otherwise a derived note is used).
geo_llm_model : str — model id for the optional live LLM call.
Contract: build_<id>(profile, ctx) -> Chapter | None ; CHAPTER_VERSION = "x.y.z".
Reads everything defensively (``.get``) and never raises.
"""
from __future__ import annotations
import math
from .. import model
# Pure registry functions (group ``eda``) delegated to. Imported defensively so
# the chapter stays importable (degrading gracefully) if one is unavailable.
try:
from datascience.detect_latlon_columns import detect_latlon_columns
except Exception: # noqa: BLE001 — keep the chapter importable no matter what.
detect_latlon_columns = None # type: ignore[assignment]
try:
from datascience.analyze_geo_extent import analyze_geo_extent
except Exception: # noqa: BLE001
analyze_geo_extent = None # type: ignore[assignment]
try:
from datascience.build_geo_scatter import build_geo_scatter
except Exception: # noqa: BLE001
build_geo_scatter = None # type: ignore[assignment]
CHAPTER_VERSION = "1.0.0"
CHAPTER_ID = "geospatial"
CHAPTER_TITLE = "Análisis geoespacial"
# --------------------------------------------------------------------------- #
# Formatting helpers (mirror the other chapters' defensive style).
# --------------------------------------------------------------------------- #
def _fmt_num(value, decimals: int = 4) -> str:
if value is None:
return ""
if isinstance(value, bool):
return "" if value else "no"
if isinstance(value, int):
return f"{value:,}".replace(",", ".")
if isinstance(value, float):
if value != value: # NaN
return "NaN"
if value in (float("inf"), float("-inf")):
return str(value)
text = f"{value:.{decimals}f}".rstrip("0").rstrip(".")
return text if text else "0"
return model._safe_str(value)
def _fmt_coord(value, decimals: int = 4) -> str:
"""Format a coordinate degree value, defensively."""
try:
return f"{float(value):.{decimals}f}°"
except (TypeError, ValueError):
return model._safe_str(value)
def _fmt_km(value) -> str:
if value is None:
return ""
try:
v = float(value)
except (TypeError, ValueError):
return model._safe_str(value)
if v >= 100:
return f"{v:,.0f} km".replace(",", ".")
return f"{v:.1f} km"
def _is_dict(v) -> bool:
return isinstance(v, dict)
def _clean_floats(seq) -> list:
"""Return a list of floats from an arbitrary sequence (drop None/NaN)."""
out = []
if not isinstance(seq, (list, tuple)):
return out
for v in seq:
try:
f = float(v)
except (TypeError, ValueError):
out.append(None)
continue
out.append(f if f == f else None) # NaN -> None
return out
# --------------------------------------------------------------------------- #
# Resolve the (lat, lon) columns and the raw coordinate arrays.
# --------------------------------------------------------------------------- #
def _detect_columns(profile: dict) -> dict:
"""Detect the lat/lon column pair from the profile metadata, or {}."""
cols = profile.get("columns")
if not isinstance(cols, list) or not cols or detect_latlon_columns is None:
return {}
try:
det = detect_latlon_columns(cols)
except Exception: # noqa: BLE001 — never break the chapter.
return {}
return det if _is_dict(det) else {}
def _resolve_coords(profile: dict, ctx: dict, detected: dict):
"""Return (lats, lons, source_label).
Order: ctx/profile['geo_points'] (explicit arrays) → ctx/profile
['raw_numeric'] keyed by the detected lat/lon column names → (None, None).
"""
gp = ctx.get("geo_points") or profile.get("geo_points")
if _is_dict(gp):
lats = gp.get("lats")
if lats is None:
lats = gp.get("lat")
lons = gp.get("lons")
if lons is None:
lons = gp.get("lon")
if lats and lons:
return list(lats), list(lons), "geo_points"
lat_col = (detected or {}).get("lat_col")
lon_col = (detected or {}).get("lon_col")
if lat_col and lon_col:
raw = ctx.get("raw_numeric") or profile.get("raw_numeric")
if _is_dict(raw):
lats = raw.get(lat_col)
lons = raw.get(lon_col)
if lats and lons:
return list(lats), list(lons), "raw_numeric"
return None, None, "none"
def _column_by_name(profile: dict, name):
if not name:
return None
for col in profile.get("columns") or []:
if isinstance(col, dict) and col.get("name") == name:
return col
return None
def _bbox_from_profile(profile: dict, detected: dict):
"""Approximate bbox from the per-column numeric.min/max (no raw points)."""
lat_c = _column_by_name(profile, (detected or {}).get("lat_col"))
lon_c = _column_by_name(profile, (detected or {}).get("lon_col"))
lat_n = lat_c.get("numeric") if _is_dict(lat_c) else None
lon_n = lon_c.get("numeric") if _is_dict(lon_c) else None
if not _is_dict(lat_n) or not _is_dict(lon_n):
return None
try:
return {
"lat_min": float(lat_n.get("min")),
"lat_max": float(lat_n.get("max")),
"lon_min": float(lon_n.get("min")),
"lon_max": float(lon_n.get("max")),
}
except (TypeError, ValueError):
return None
# --------------------------------------------------------------------------- #
# Figure builder (lazy: matplotlib only imported when the renderer draws it).
# --------------------------------------------------------------------------- #
def _make_geo_scatter(scatter: dict, lat_col: str, lon_col: str):
"""Return a zero-arg callable drawing the geographic scatter, or None."""
points = scatter.get("points") or []
if not points:
return None
bbox = scatter.get("bbox") if _is_dict(scatter.get("bbox")) else {}
aspect = scatter.get("aspect") or 1.0
pad = scatter.get("pad") if _is_dict(scatter.get("pad")) else {}
n_total = scatter.get("n_total")
n_shown = scatter.get("n_shown")
def _draw():
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
xs = [p[0] for p in points if isinstance(p, (list, tuple)) and len(p) >= 2]
ys = [p[1] for p in points if isinstance(p, (list, tuple)) and len(p) >= 2]
fig, ax = plt.subplots(figsize=(6.6, 5.0))
# More points -> smaller markers + lower alpha so dense clouds read as
# density without saturating the page with ink (Tufte).
n = max(len(xs), 1)
size = 18 if n <= 200 else (8 if n <= 1000 else 4)
alpha = 0.75 if n <= 200 else (0.5 if n <= 1000 else 0.35)
ax.scatter(xs, ys, s=size, c="#2a6f97", alpha=alpha, linewidths=0,
zorder=3)
# Bounding box rectangle for orientation.
if bbox:
try:
lo_x, hi_x = float(bbox["lon_min"]), float(bbox["lon_max"])
lo_y, hi_y = float(bbox["lat_min"]), float(bbox["lat_max"])
ax.plot([lo_x, hi_x, hi_x, lo_x, lo_x],
[lo_y, lo_y, hi_y, hi_y, lo_y],
color="#e15759", linewidth=1.0, linestyle="--",
alpha=0.8, zorder=4, label="Bounding box")
px = float(pad.get("lon", 0.0) or 0.0)
py = float(pad.get("lat", 0.0) or 0.0)
ax.set_xlim(lo_x - px, hi_x + px)
ax.set_ylim(lo_y - py, hi_y + py)
except (TypeError, ValueError, KeyError):
pass
# Equirectangular: scale Y/X so longitude is not stretched at this
# latitude (integridad de proyección, Tufte). aspect = 1/cos(lat).
try:
ax.set_aspect(float(aspect))
except (TypeError, ValueError):
pass
ax.set_xlabel(f"Longitud ({lon_col})", fontsize=8)
ax.set_ylabel(f"Latitud ({lat_col})", fontsize=8)
ax.tick_params(labelsize=7)
ax.grid(color="#e6e6e6", linewidth=0.5, zorder=0)
title = "Distribución geográfica de las coordenadas"
if n_shown is not None and n_total is not None and n_shown < n_total:
title += f"\n(mostrando {n_shown:,} de {n_total:,} puntos)".replace(",", ".")
ax.set_title(title, fontsize=10)
ax.legend(loc="best", fontsize=7, frameon=True, framealpha=0.9)
fig.tight_layout()
return fig
return _draw
# --------------------------------------------------------------------------- #
# Section builders.
# --------------------------------------------------------------------------- #
def _intro_block(detected: dict, lat_col: str, lon_col: str) -> list:
conf = (detected or {}).get("confidence")
reason = model._safe_str((detected or {}).get("reason"))
conf_txt = ""
if conf is not None:
try:
conf_txt = f" (confianza {float(conf) * 100:.0f}%)"
except (TypeError, ValueError):
conf_txt = ""
text = (
"Este dataset contiene **coordenadas geográficas**: se identificó el par "
f"**latitud = «{lat_col}»** y **longitud = «{lon_col}»**{conf_txt}. La "
"detección combina el nombre de la columna y el rango de sus valores "
"(latitud en [90, 90], longitud en [180, 180])."
)
if reason:
text += f"\n\n*Criterio de detección:* {reason}."
return [model.Heading(text=CHAPTER_TITLE, level=1),
model.Markdown(text=text)]
def _extent_blocks(extent: dict) -> list:
"""KVTable with bbox/centroid/span + DataTable with the per-region counts."""
if not _is_dict(extent) or not extent.get("n_points"):
return []
blocks = []
bbox = extent.get("bbox") if _is_dict(extent.get("bbox")) else {}
centroid = extent.get("centroid") if _is_dict(extent.get("centroid")) else {}
hemi = extent.get("hemisphere") if _is_dict(extent.get("hemisphere")) else {}
rows = [("Puntos con coordenadas", _fmt_num(extent.get("n_points")))]
if bbox:
rows.append(("Latitud (mín. / máx.)",
f"{_fmt_coord(bbox.get('lat_min'))} a "
f"{_fmt_coord(bbox.get('lat_max'))}"))
rows.append(("Longitud (mín. / máx.)",
f"{_fmt_coord(bbox.get('lon_min'))} a "
f"{_fmt_coord(bbox.get('lon_max'))}"))
if centroid:
rows.append(("Centroide",
f"{_fmt_coord(centroid.get('lat'))}, "
f"{_fmt_coord(centroid.get('lon'))}"))
if extent.get("span_km") is not None:
rows.append(("Extensión (diagonal)", _fmt_km(extent.get("span_km"))))
if hemi:
n, s = hemi.get("north"), hemi.get("south")
e, w = hemi.get("east"), hemi.get("west")
rows.append(("Hemisferios",
f"N {_fmt_num(n)} / S {_fmt_num(s)} · "
f"E {_fmt_num(e)} / O {_fmt_num(w)}"))
blocks.append(model.KVTable(rows=rows, title="Extensión geográfica"))
by_region = extent.get("by_region")
if isinstance(by_region, list) and by_region:
total = sum(r.get("count", 0) for r in by_region if _is_dict(r)) or 0
rrows = []
for r in by_region:
if not _is_dict(r):
continue
cnt = r.get("count", 0)
pct = (cnt / total) if total else None
pct_txt = f"{pct * 100:.1f}%" if pct is not None else ""
rrows.append([model._safe_str(r.get("region")), _fmt_num(cnt),
pct_txt])
if rrows:
blocks.append(model.DataTable(
header=["Zona / país", "Puntos", "% del total"], rows=rrows,
title="Distribución por zona",
note="Asignación aproximada por bounding box de cada región "
"(no es reverse-geocoding exacto de fronteras)."))
return blocks
def _narrative_block(profile: dict, ctx: dict, extent: dict) -> list:
"""A one-line narrative of where the points concentrate.
Uses the derived ``note`` from analyze_geo_extent by default; optionally
calls an LLM (ctx['run_geo_llm']) for a richer one-liner.
"""
note = model._safe_str((extent or {}).get("note"))
if ctx.get("run_geo_llm"):
by_region = (extent or {}).get("by_region") or []
bbox = (extent or {}).get("bbox") or {}
try:
from core.ask_llm import ask_llm
prompt = (
"Eres un analista de datos. En UNA frase en español, describe "
"dónde se concentran geográficamente estos puntos. Sé concreto "
"y no inventes precisión que los datos no tienen.\n"
f"Conteo por zona: {by_region}\nBounding box: {bbox}."
)
out = ask_llm(prompt,
model=ctx.get("geo_llm_model",
"claude-haiku-4-5-20251001"),
echo=False)
if out and isinstance(out, str) and out.strip():
note = out.strip()
except Exception: # noqa: BLE001 — degrade to the derived note.
pass
if not note:
return []
return [model.Markdown(text=f"**Interpretación.** {note}")]
def _no_points_block(profile: dict, detected: dict) -> list:
"""Degrade honestly when the raw coordinate arrays are not available."""
blocks = []
bbox = _bbox_from_profile(profile, detected)
if bbox:
rows = [
("Latitud (mín. / máx.)",
f"{_fmt_coord(bbox.get('lat_min'))} a "
f"{_fmt_coord(bbox.get('lat_max'))}"),
("Longitud (mín. / máx.)",
f"{_fmt_coord(bbox.get('lon_min'))} a "
f"{_fmt_coord(bbox.get('lon_max'))}"),
]
blocks.append(model.KVTable(
rows=rows, title="Extensión geográfica (aproximada)"))
blocks.append(model.Note(
"No se incluyeron las coordenadas crudas en el contexto, por lo que el "
"mapa y el análisis por zona no se han dibujado. El bounding box "
"mostrado se deriva de los mínimos y máximos por columna. Para el "
"scatter geográfico completo, pasa los arrays en "
"ctx['geo_points'] = {'lats': [...], 'lons': [...]} o las columnas en "
"ctx['raw_numeric']."))
return blocks
# --------------------------------------------------------------------------- #
# Entry point.
# --------------------------------------------------------------------------- #
def build_geospatial(profile: dict, ctx: dict):
"""Build the GEOSPATIAL Chapter, or None if the dataset has no coordinates.
Args:
profile: the ``eda`` group TableProfile dict.
ctx: presentation context; may carry ``geo_points``/``raw_numeric`` with
the raw coordinate arrays and the ``run_geo_llm`` flag.
Returns:
A ``model.Chapter`` with the geographic scatter + zone/country analysis,
or ``None`` when no latitude/longitude column pair is detected.
"""
profile = profile or {}
ctx = ctx or {}
if not isinstance(profile, dict):
return None
detected = _detect_columns(profile)
lats, lons, source = _resolve_coords(profile, ctx, detected)
has_detection = bool((detected or {}).get("lat_col") and
(detected or {}).get("lon_col"))
has_points = bool(lats and lons)
if not has_detection and not has_points:
return None # chapter does not apply: no coordinates in this dataset.
# Labels for axes / intro. When only raw arrays were given (no detection),
# fall back to generic names.
lat_col = (detected or {}).get("lat_col") or "lat"
lon_col = (detected or {}).get("lon_col") or "lon"
blocks = _intro_block(detected, lat_col, lon_col)
if has_points:
clean_lats = _clean_floats(lats)
clean_lons = _clean_floats(lons)
# Zone / country analysis.
extent = {}
if analyze_geo_extent is not None:
try:
extent = analyze_geo_extent(clean_lats, clean_lons) or {}
except Exception: # noqa: BLE001
extent = {}
# The geographic scatter figure (its own page/slide).
scatter = {}
if build_geo_scatter is not None:
try:
scatter = build_geo_scatter(clean_lats, clean_lons) or {}
except Exception: # noqa: BLE001
scatter = {}
maker = _make_geo_scatter(scatter, lat_col, lon_col) if scatter else None
if maker is not None:
blocks.append(model.Figure(
make=maker,
caption="Cada punto es una observación situada por sus "
"coordenadas; el recuadro rojo es el bounding box. La "
"escala respeta la latitud (proyección equirectangular)."))
else:
blocks.append(model.Note(
"No se pudo construir el scatter geográfico a partir de las "
"coordenadas proporcionadas."))
blocks += _extent_blocks(extent)
blocks += _narrative_block(profile, ctx, extent)
else:
# Columns detected but no raw points available — degrade honestly.
blocks += _no_points_block(profile, detected)
if not blocks:
return None
return model.Chapter(id=CHAPTER_ID, title=CHAPTER_TITLE,
version=CHAPTER_VERSION, blocks=blocks)
@@ -0,0 +1,245 @@
"""Tests for the GEOSPATIAL chapter — DoD: golden + edges + anti-cut.
Self-contained: builds synthetic TableProfiles (no DuckDB) so the suite is fast
and deterministic. The raw coordinate arrays are passed through ``ctx`` exactly
as the chapter's contract documents (``ctx['geo_points']`` / ``ctx['raw_numeric']``).
Verifies that the chapter detects the lat/lon pair, draws the geographic scatter
figure, analyses the zone/country (bounding box + per-region counts), returns
None when there are no coordinates, degrades honestly when the raw points are
absent, and that a profile with long column names + many points + several
regions renders to PDF and PPTX without cutting any text (long content wraps, it
is never truncated).
"""
import os
import re
import tempfile
from pypdf import PdfReader
from pptx import Presentation
from datascience.automatic_eda.chapters.geospatial import (
build_geospatial,
CHAPTER_VERSION,
)
from datascience.automatic_eda import build_document, render_pdf, render_pptx
# --------------------------------------------------------------------------- #
# Synthetic data helpers
# --------------------------------------------------------------------------- #
def _grid(lat0: float, lon0: float, n: int, spread: float = 1.0):
"""A small deterministic cloud of n points around (lat0, lon0)."""
lats, lons = [], []
for i in range(n):
# deterministic pseudo-spread, no randomness.
f = (i % 11) / 11.0 - 0.5
g = (i % 7) / 7.0 - 0.5
lats.append(lat0 + f * spread)
lons.append(lon0 + g * spread)
return lats, lons
def _profile_with_coords(lat_name="lat", lon_name="lon", lats=None, lons=None):
"""A profile carrying a lat/lon column pair with valid ranges."""
lats = lats if lats is not None else [40.4, 41.0, 39.8, 40.1]
lons = lons if lons is not None else [-3.7, -3.6, -4.0, -3.9]
return {
"table": "lugares",
"columns": [
{"name": lat_name, "inferred_type": "numeric",
"numeric": {"min": min(lats), "max": max(lats),
"mean": sum(lats) / len(lats)}},
{"name": lon_name, "inferred_type": "numeric",
"numeric": {"min": min(lons), "max": max(lons),
"mean": sum(lons) / len(lons)}},
{"name": "valor", "inferred_type": "numeric",
"numeric": {"min": 0, "max": 100, "mean": 50}},
],
}
def _ctx_points(lats, lons):
return {"geo_points": {"lats": lats, "lons": lons}}
def _kinds(chapter):
return [getattr(b, "kind", None) for b in chapter.blocks]
def _tables(chapter):
return [b for b in chapter.blocks if getattr(b, "kind", None) == "data_table"]
def _figures(chapter):
return [b for b in chapter.blocks if getattr(b, "kind", None) == "figure"]
# --------------------------------------------------------------------------- #
# Golden
# --------------------------------------------------------------------------- #
def test_golden_estructura_y_version():
lats, lons = [40.4, 41.0, 39.8, 40.1], [-3.7, -3.6, -4.0, -3.9]
ch = build_geospatial(_profile_with_coords(lats=lats, lons=lons),
_ctx_points(lats, lons))
assert ch is not None
assert ch.id == "geospatial"
assert ch.version == CHAPTER_VERSION
kinds = _kinds(ch)
# intro heading + markdown + scatter figure + extent kv + per-region table.
assert "heading" in kinds
assert "markdown" in kinds
assert "figure" in kinds, "falta el scatter geográfico"
assert "kv_table" in kinds, "falta la tabla de extensión"
def test_golden_detecta_columnas_y_nombra_ejes():
lats, lons = _grid(40.4, -3.7, 30, spread=0.8)
prof = _profile_with_coords("latitude", "longitude", lats, lons)
ch = build_geospatial(prof, _ctx_points(lats, lons))
intro = [b for b in ch.blocks if b.kind == "markdown"][0].text
assert "latitude" in intro and "longitude" in intro
def test_golden_figura_es_perezosa_y_dibujable():
lats, lons = _grid(40.4, -3.7, 50, spread=0.6)
ch = build_geospatial(_profile_with_coords(lats=lats, lons=lons),
_ctx_points(lats, lons))
fig_block = _figures(ch)[0]
assert fig_block.make is not None and fig_block.fig is None # lazy
fig = fig_block.make() # must draw without raising
assert fig is not None
import matplotlib.pyplot as plt
plt.close(fig)
def test_golden_analisis_por_zona_espana():
lats, lons = _grid(40.4, -3.7, 40, spread=0.5) # Madrid area
ch = build_geospatial(_profile_with_coords(lats=lats, lons=lons),
_ctx_points(lats, lons))
tables = _tables(ch)
region_tbl = [t for t in tables if "zona" in (t.title or "").lower()]
assert region_tbl, "falta la tabla por zona/país"
flat = " ".join(" ".join(str(c) for c in r) for r in region_tbl[0].rows)
# Spain-area points must resolve to a Spain/European region, not empty.
assert region_tbl[0].rows
assert any(c for c in (region_tbl[0].rows[0]))
def test_golden_raw_numeric_source():
"""Coordinates can also come from ctx['raw_numeric'] keyed by detected cols."""
lats, lons = _grid(48.85, 2.35, 25, spread=0.4) # Paris area
prof = _profile_with_coords("lat", "lon", lats, lons)
ctx = {"raw_numeric": {"lat": lats, "lon": lons}}
ch = build_geospatial(prof, ctx)
assert ch is not None
assert _figures(ch), "el scatter debe construirse desde raw_numeric"
# --------------------------------------------------------------------------- #
# Edges
# --------------------------------------------------------------------------- #
def test_edge_sin_coordenadas_devuelve_none():
prof = {
"table": "ventas",
"columns": [
{"name": "precio", "inferred_type": "numeric",
"numeric": {"min": 0, "max": 1000}},
{"name": "categoria", "inferred_type": "text"},
],
}
assert build_geospatial(prof, {}) is None
def test_edge_none_y_vacio_no_rompen():
assert build_geospatial(None, None) is None
assert build_geospatial({}, {}) is None
assert build_geospatial({"columns": []}, {}) is None
assert build_geospatial("not a dict", {}) is None
def test_edge_nombre_lat_pero_rango_invalido_no_aplica():
"""A column named 'lat' whose values are out of [-90,90] is NOT a coordinate."""
prof = {
"table": "x",
"columns": [
{"name": "lat", "inferred_type": "numeric",
"numeric": {"min": 1000, "max": 9999}},
{"name": "lon", "inferred_type": "numeric",
"numeric": {"min": 1000, "max": 9999}},
],
}
assert build_geospatial(prof, {}) is None
def test_edge_columnas_detectadas_sin_puntos_degrada():
"""Detected lat/lon but no raw arrays -> honest note + approx bbox, no crash."""
prof = _profile_with_coords(lats=[40.0, 41.0], lons=[-3.0, -4.0])
ch = build_geospatial(prof, {}) # no geo_points / raw_numeric
assert ch is not None
assert not _figures(ch), "sin puntos no debe dibujarse el scatter"
notes = [b for b in ch.blocks if b.kind == "note"]
assert notes and "coordenadas crudas" in notes[0].text
def test_edge_coordenadas_con_nan_se_filtran():
lats = [40.4, float("nan"), 41.0, None, 39.8]
lons = [-3.7, -3.6, float("nan"), -3.9, -4.0]
ch = build_geospatial(_profile_with_coords(lats=[39.8, 41.0],
lons=[-4.0, -3.6]),
_ctx_points(lats, lons))
assert ch is not None # must not raise on NaN/None
# --------------------------------------------------------------------------- #
# Anti-cut: long names + many points + several regions render without truncation
# --------------------------------------------------------------------------- #
def _multiregion_points(per: int = 700):
"""Points spread across Spain, France and the USA to fill the region table."""
lats, lons = [], []
for (la, lo) in ((40.4, -3.7), (48.85, 2.35), (39.0, -98.0)):
gl, gn = _grid(la, lo, per, spread=2.0)
lats += gl
lons += gn
return lats, lons
def test_anticut_pdf_y_pptx_no_truncan():
lat_name = "latitud_geografica_del_punto_de_observacion_registrado"
lon_name = "longitud_geografica_del_punto_de_observacion_registrado"
lats, lons = _multiregion_points(700)
prof = _profile_with_coords(lat_name, lon_name, lats, lons)
ctx = {"geo_points": {"lats": lats, "lons": lons}}
full = build_document(prof, ctx)
assert any(c.id == "geospatial" for c in full)
chapters = [c for c in full if c.id == "geospatial"]
with tempfile.TemporaryDirectory() as d:
pdf = os.path.join(d, "g.pdf")
pptx = os.path.join(d, "g.pptx")
rp = render_pdf(chapters, pdf, {"title": "EDA"})
rx = render_pptx(chapters, pptx, {"title": "EDA"})
assert os.path.exists(pdf) and os.path.exists(pptx)
assert (rp or {}).get("n_pages", 0) >= 1
# PDF: the long lat column name survives whole (wraps, not cut) and there
# is no truncation marker in this chapter.
pdf_txt = "".join((pg.extract_text() or "") for pg in PdfReader(pdf).pages)
assert "" not in pdf_txt and "..." not in pdf_txt
norm = re.sub(r"\s+", "", pdf_txt)
assert lat_name in norm, "el nombre largo de la columna se cortó en el PDF"
# PPTX: long name present in some shape/cell, untruncated.
allt = []
for s in Presentation(pptx).slides:
for sh in s.shapes:
if sh.has_text_frame:
allt.append(sh.text_frame.text)
if sh.has_table:
for row in sh.table.rows:
for c in row.cells:
allt.append(c.text)
joined = re.sub(r"\s+", "", "\n".join(allt))
assert lat_name in joined, "el nombre largo de la columna se cortó en el PPTX"
@@ -0,0 +1,68 @@
---
name: build_geo_scatter
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: pure
signature: "def build_geo_scatter(lats: list, lons: list, max_points: int = 2000) -> dict"
description: "Prepara los datos de un scatter geografico en proyeccion equirectangular para el grupo eda. Empareja lats/lons por indice, descarta pares None/NaN/inf/bool o fuera de rango (lat en [-90,90], lon en [-180,180]) y aplica downsampling DETERMINISTA por paso fijo (pairs[::step]) cuando hay mas pares validos que max_points, para no saturar el PDF/PPTX en moviles. Devuelve los puntos en orden [lon, lat] listos para ax.scatter, el bbox, el aspect 1/cos(centroid_lat) clampado a [0.3,5.0] y un pad sugerido (~5% del rango con suelo minimo). Lectura defensiva; NUNCA lanza ni dibuja: el capitulo se encarga de matplotlib."
tags: [eda, geospatial, datascience, scatter, map, downsample, equirectangular, profiling]
params:
- name: lats
desc: "Lista (o tupla) de latitudes en grados, paralela a lons. Se empareja por indice. Un valor None, NaN, infinito, bool o fuera de [-90,90] descarta ese par. Lectura defensiva."
- name: lons
desc: "Lista (o tupla) de longitudes en grados, paralela a lats. Un valor None, NaN, infinito, bool o fuera de [-180,180] descarta ese par."
- name: max_points
desc: "Tope de puntos a devolver (default 2000). Si los pares validos superan el tope, se hace downsampling determinista por paso fijo step=ceil(n_total/max_points) tomando pairs[::step] (NO aleatorio, reproducible). Un valor no entero o <=0 desactiva el downsampling."
output: "Dict listo para dibujar: {points: [[lon, lat], ...] en orden x=lon/y=lat para ax.scatter; n_total: pares validos antes del downsample (int); n_shown: puntos devueltos tras el downsample (int); downsampled: bool (n_shown<n_total); bbox: {lat_min, lat_max, lon_min, lon_max} o None si no hay puntos; aspect: 1/cos(centroid_lat) clampado a [0.3,5.0] para no estirar la proyeccion equirectangular; pad: {lon, lat} ~5% del rango respectivo con suelo minimo 0.01 grados}. Si no hay pares validos: points=[], n_total=0, n_shown=0, downsampled=False, bbox=None, aspect=1.0, pad={lon:0.0, lat:0.0}."
uses_functions: []
uses_types: []
returns: []
returns_optional: false
error_type: ""
imports: []
tested: true
tests: ["test_geo_scatter_nube_espana", "test_downsampling_determinista_y_reproducible", "test_listas_vacias_no_lanza", "test_un_solo_punto_pad_minimo_y_aspect_finito", "test_filtra_none_nan_y_fuera_de_rango", "test_latitud_alta_aspect_clamped"]
test_file_path: "python/functions/datascience/build_geo_scatter_test.py"
file_path: "python/functions/datascience/build_geo_scatter.py"
---
## Ejemplo
```python
import sys, os
sys.path.insert(0, os.path.join("python", "functions"))
from datascience.build_geo_scatter import build_geo_scatter
# Nube de coordenadas (lat, lon) alrededor de Madrid:
lats = [40.0, 41.0, 39.0, 40.5]
lons = [-3.7, -3.0, -4.0, -3.5]
geo = build_geo_scatter(lats, lons, max_points=2000)
print(geo["points"][0]) # [-3.7, 40.0] -> orden [x=lon, y=lat]
print(geo["bbox"]) # {'lat_min': 39.0, 'lat_max': 41.0, 'lon_min': -4.0, 'lon_max': -3.0}
print(round(geo["aspect"], 3)) # 1.308 -> ensancha el eje x en latitudes medias
print(geo["pad"]) # {'lon': 0.05, 'lat': 0.1} -> margen ~5%
# El capitulo dibuja con matplotlib (esta funcion NO dibuja):
# xs = [p[0] for p in geo["points"]]; ys = [p[1] for p in geo["points"]]
# ax.scatter(xs, ys); ax.set_aspect(geo["aspect"])
# ax.set_xlim(geo["bbox"]["lon_min"] - geo["pad"]["lon"], geo["bbox"]["lon_max"] + geo["pad"]["lon"])
# ax.set_ylim(geo["bbox"]["lat_min"] - geo["pad"]["lat"], geo["bbox"]["lat_max"] + geo["pad"]["lat"])
```
## Cuando usarla
- Usala antes de dibujar un scatter geografico (mapa de puntos en proyeccion equirectangular) en el capitulo geospatial de `AutomaticEDA`: limpia los pares de coordenadas, los reduce a un tamano razonable para el PDF/PPTX y te da bbox, aspect y pad listos para fijar los ejes.
- Cuando tengas dos columnas de lat/lon ya extraidas y quieras un punto de entrada determinista (mismo dataset -> mismo dibujo) que no sature el documento en moviles.
- Cuando necesites el aspect correcto para que un grado de longitud no se vea estirado respecto a uno de latitud (integridad visual, Tufte) sin calcularlo a mano.
## Gotchas
- Funcion pura, sin I/O y determinista. NO dibuja: solo PREPARA los datos; el capitulo se encarga de matplotlib. Lectura defensiva: pares con None/NaN/inf/bool o coordenadas fuera de rango se descartan en silencio y NUNCA lanza.
- El downsampling es DETERMINISTA por paso fijo (`step = ceil(n_total / max_points)`, `pairs[::step]`), NO aleatorio: la misma entrada produce siempre la misma salida (reproducible en tests). El primer punto mostrado es siempre el primer par valido. No es un muestreo uniforme aleatorio — es un barrido regular del orden de entrada.
- `points` va en orden `[lon, lat]` (x, y), no `[lat, lon]`: pasalo directo a `ax.scatter(xs, ys)` sin invertir. Confundir el orden espeja el mapa.
- `aspect = 1/cos(centroid_lat)` se clampa a `[0.3, 5.0]`. En latitudes altas `cos -> 0` y el valor real explota: por encima de ~78 grados el aspect queda fijado en 5.0. Si el centroide cae justo en un polo (`+-90`) se usa el clamp en vez de dividir por cero.
- `pad` es ~5% del rango de cada eje con un suelo minimo de `0.01` grados: con un solo punto o todos iguales (rango 0) el pad cae al suelo para que el punto no quede en una linea. En el caso sin puntos validos el pad es `{lon:0.0, lat:0.0}` y `bbox` es `None`.
- `bbox`, `aspect` y `pad` se calculan sobre los puntos YA mostrados (tras el downsample), de modo que los ejes encajan exactamente con lo que se dibuja.
@@ -0,0 +1,153 @@
"""build_geo_scatter — prepare points for a geographic scatter (EDA `geospatial`).
Pure function: no I/O, deterministic. Takes two parallel lists of latitudes and
longitudes and returns the data a caller needs to draw a geographic scatter in an
equirectangular projection: cleaned points in [lon, lat] order, a bounding box, a
projection aspect ratio and a suggested axis padding.
It NEVER draws anything (no matplotlib) — the chapter that consumes this output is
responsible for the rendering. Reading is defensive throughout and the function
NEVER raises: malformed pairs (None, NaN, infinity or out-of-range coordinates)
are silently dropped and an empty/valid result is always returned.
To keep the rendered PDF/PPTX light on phones, when the number of valid pairs
exceeds `max_points` the points are down-sampled DETERMINISTICALLY by a fixed
step (`pairs[::step]`), never randomly, so the result is reproducible.
"""
import math
# Minimum axis padding (in degrees) so a single point or a zero-range cloud is
# never drawn glued to the axis border (it would collapse to a line).
_MIN_PAD = 0.01
# Aspect ratio clamp. 1/cos(lat) blows up near the poles; clamp keeps the render
# sane (Tufte: do not let the projection stretch the cloud out of proportion).
_ASPECT_MIN = 0.3
_ASPECT_MAX = 5.0
def _coord(value):
"""Coerce to a finite float defensively; return None for invalid coordinates.
bool is a subclass of int, but a real latitude/longitude is never a bool, so
True/False are treated as missing instead of coercing to 1.0/0.0. NaN and
+/-infinity are never valid coordinates either.
"""
if value is None or isinstance(value, bool):
return None
try:
coord = float(value)
except (TypeError, ValueError):
return None
if math.isnan(coord) or math.isinf(coord):
return None
return coord
def build_geo_scatter(lats: list, lons: list, max_points: int = 2000) -> dict:
"""Prepare the data for a geographic scatter in equirectangular projection.
Pairs `lats` and `lons` by index, drops invalid pairs, optionally
down-samples deterministically, and derives the geometry (bbox, aspect, pad)
a caller needs to draw the cloud. No raw rendering is performed.
Args:
lats: List (or tuple) of latitudes in degrees. Paired by index with
`lons`. A value that is None, NaN, infinite, bool or outside
[-90, 90] discards that pair. Read defensively.
lons: List (or tuple) of longitudes in degrees, parallel to `lats`. A
value outside [-180, 180] (or None/NaN/inf/bool) discards that pair.
max_points: Cap on the number of points returned. When the number of
valid pairs exceeds this cap, the points are down-sampled by a fixed
step `ceil(n_total / max_points)` taking `pairs[::step]` — DETERMINISTIC,
not random, so the output is reproducible. A non-positive or non-int
value disables down-sampling.
Returns:
Dict ready for a caller's ax.scatter:
{points: [[lon, lat], ...] (x=lon, y=lat order), n_total: valid pairs
before down-sampling, n_shown: points returned, downsampled: bool,
bbox: {lat_min, lat_max, lon_min, lon_max} or None, aspect: 1/cos(centroid
lat) clamped to [0.3, 5.0], pad: {lon, lat} ~5% of each range with a small
floor}. When there are no valid pairs returns points=[], n_total=0,
n_shown=0, downsampled=False, bbox=None, aspect=1.0, pad={lon:0.0, lat:0.0}.
"""
pairs = [] # each item is (lon, lat) — already in [x, y] order
if isinstance(lats, (list, tuple)) and isinstance(lons, (list, tuple)):
n = min(len(lats), len(lons))
for i in range(n):
lat = _coord(lats[i])
lon = _coord(lons[i])
if lat is None or lon is None:
continue
if lat < -90.0 or lat > 90.0:
continue
if lon < -180.0 or lon > 180.0:
continue
pairs.append((lon, lat))
n_total = len(pairs)
if n_total == 0:
return {
"points": [],
"n_total": 0,
"n_shown": 0,
"downsampled": False,
"bbox": None,
"aspect": 1.0,
"pad": {"lon": 0.0, "lat": 0.0},
}
# Deterministic down-sampling by a fixed step. Reproducible: same input ->
# same output, no randomness.
if (
isinstance(max_points, int)
and not isinstance(max_points, bool)
and max_points > 0
and n_total > max_points
):
step = math.ceil(n_total / max_points)
sampled = pairs[::step]
else:
sampled = pairs
points = [[lon, lat] for (lon, lat) in sampled]
n_shown = len(points)
downsampled = n_shown < n_total
lons_s = [p[0] for p in sampled]
lats_s = [p[1] for p in sampled]
lon_min, lon_max = min(lons_s), max(lons_s)
lat_min, lat_max = min(lats_s), max(lats_s)
bbox = {
"lat_min": lat_min,
"lat_max": lat_max,
"lon_min": lon_min,
"lon_max": lon_max,
}
# Aspect for an equirectangular projection: stretch the x axis by 1/cos(lat)
# at the cloud centroid so a degree of longitude reads at its real width.
centroid_lat = sum(lats_s) / len(lats_s)
cos_lat = math.cos(math.radians(centroid_lat))
if cos_lat < 1e-12: # centroid at (or numerically at) a pole
aspect = _ASPECT_MAX
else:
aspect = 1.0 / cos_lat
aspect = max(_ASPECT_MIN, min(_ASPECT_MAX, aspect))
# Padding ~5% of each range, with a small floor so a zero-range cloud (single
# point / all identical) still gets a non-zero margin.
pad_lon = max(0.05 * (lon_max - lon_min), _MIN_PAD)
pad_lat = max(0.05 * (lat_max - lat_min), _MIN_PAD)
return {
"points": points,
"n_total": n_total,
"n_shown": n_shown,
"downsampled": downsampled,
"bbox": bbox,
"aspect": aspect,
"pad": {"lon": pad_lon, "lat": pad_lat},
}
@@ -0,0 +1,140 @@
"""Tests para build_geo_scatter."""
import math
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from build_geo_scatter import build_geo_scatter
# Keys that a non-empty result dict must always contain.
_EXPECTED_KEYS = {
"points", "n_total", "n_shown", "downsampled", "bbox", "aspect", "pad",
}
def test_geo_scatter_nube_espana():
"""Golden: nube en Espana -> points en orden [lon, lat], bbox, aspect>1, pad 5%."""
# Cuatro puntos alrededor de Madrid (lat ~40, lon negativo).
lats = [40.0, 41.0, 39.0, 40.5]
lons = [-3.7, -3.0, -4.0, -3.5]
r = build_geo_scatter(lats, lons)
assert set(r.keys()) == _EXPECTED_KEYS
# points en orden [x=lon, y=lat]: primer elemento lon (negativo), segundo lat (~40).
assert r["points"] == [[-3.7, 40.0], [-3.0, 41.0], [-4.0, 39.0], [-3.5, 40.5]]
for lon, lat in r["points"]:
assert lon < 0.0 # longitudes de Espana son negativas
assert 36.0 < lat < 44.0 # latitudes peninsulares
# Sin downsampling: 4 < 2000.
assert r["n_total"] == 4
assert r["n_shown"] == 4
assert r["downsampled"] is False
# bbox correcto.
assert r["bbox"] == {
"lat_min": 39.0, "lat_max": 41.0,
"lon_min": -4.0, "lon_max": -3.0,
}
# aspect = 1/cos(centroid_lat); centroid = 40.125 -> ~1.31 > 1.
centroid_lat = (40.0 + 41.0 + 39.0 + 40.5) / 4.0
expected_aspect = 1.0 / math.cos(math.radians(centroid_lat))
assert r["aspect"] > 1.0
assert abs(r["aspect"] - expected_aspect) < 1e-9
assert abs(r["aspect"] - 1.305) < 0.02 # cos(40) ~ 0.77
# pad 5% del rango (lon_range=1.0 -> 0.05 ; lat_range=2.0 -> 0.1).
assert abs(r["pad"]["lon"] - 0.05) < 1e-9
assert abs(r["pad"]["lat"] - 0.10) < 1e-9
def test_downsampling_determinista_y_reproducible():
"""Golden: 5000 puntos, max_points=2000 -> n_shown<=2000, downsampled, reproducible."""
lats = [40.0 + (i % 100) * 0.01 for i in range(5000)]
lons = [-3.0 - (i % 100) * 0.01 for i in range(5000)]
r1 = build_geo_scatter(lats, lons, max_points=2000)
assert r1["n_total"] == 5000
assert r1["n_shown"] <= 2000
assert r1["downsampled"] is True
# step = ceil(5000/2000) = 3 -> len(pairs[::3]) = 1667.
assert r1["n_shown"] == 1667
# Determinista: dos llamadas con la misma entrada dan exactamente lo mismo.
r2 = build_geo_scatter(lats, lons, max_points=2000)
assert r1 == r2
assert r1["points"] == r2["points"]
# El primer punto del downsample es el primer par valido (step parte de 0).
assert r1["points"][0] == [lons[0], lats[0]]
def test_listas_vacias_no_lanza():
"""Edge: listas vacias / None -> points [] sin lanzar."""
r = build_geo_scatter([], [])
assert r["points"] == []
assert r["n_total"] == 0
assert r["n_shown"] == 0
assert r["downsampled"] is False
assert r["bbox"] is None
assert r["aspect"] == 1.0
assert r["pad"] == {"lon": 0.0, "lat": 0.0}
# None como entrada tampoco lanza.
assert build_geo_scatter(None, None)["points"] == []
assert build_geo_scatter([40.0], None)["n_total"] == 0
assert build_geo_scatter(None, [-3.0])["n_total"] == 0
def test_un_solo_punto_pad_minimo_y_aspect_finito():
"""Edge: un solo punto -> pad minimo no cero, bbox degenerado, aspect finito."""
r = build_geo_scatter([40.0], [-3.7])
assert r["n_total"] == 1
assert r["n_shown"] == 1
assert r["points"] == [[-3.7, 40.0]]
assert r["downsampled"] is False
assert r["bbox"] == {
"lat_min": 40.0, "lat_max": 40.0,
"lon_min": -3.7, "lon_max": -3.7,
}
# rango 0 -> pad cae al floor minimo (no cero).
assert r["pad"]["lon"] == 0.01
assert r["pad"]["lat"] == 0.01
# aspect finito y dentro del clamp.
assert math.isfinite(r["aspect"])
assert 0.3 <= r["aspect"] <= 5.0
def test_filtra_none_nan_y_fuera_de_rango():
"""Edge: pares con None/NaN/fuera de rango se descartan por indice."""
nan = float("nan")
inf = float("inf")
# i=0 i=1 i=2 i=3 i=4 i=5 i=6
lats = [40.0, None, nan, 200.0, 41.0, 39.0, inf]
lons = [-3.0, -3.5, -3.6, -3.7, 999.0, -4.0, -2.0]
r = build_geo_scatter(lats, lons)
# Validos solo i=0 (40,-3.0) e i=5 (39,-4.0):
# i=1 lat None, i=2 lat NaN, i=3 lat 200 fuera de rango,
# i=4 lon 999 fuera de rango, i=6 lat inf.
assert r["n_total"] == 2
assert r["points"] == [[-3.0, 40.0], [-4.0, 39.0]]
assert r["bbox"] == {
"lat_min": 39.0, "lat_max": 40.0,
"lon_min": -4.0, "lon_max": -3.0,
}
def test_latitud_alta_aspect_clamped():
"""Edge: latitudes ~85 -> aspect clamped <= 5.0."""
r = build_geo_scatter([85.0, 85.0, 84.0], [10.0, 11.0, 9.0])
# cos(~84.7) ~ 0.093 -> 1/0.093 ~ 10.7 -> clamp a 5.0.
assert r["aspect"] <= 5.0
assert r["aspect"] == 5.0
assert math.isfinite(r["aspect"])
@@ -0,0 +1,67 @@
---
name: detect_latlon_columns
id: detect_latlon_columns_py_datascience
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: pure
signature: "def detect_latlon_columns(columns: list, samples: dict | None = None) -> dict"
description: "Detecta un par (latitud, longitud) entre las columnas de un TableProfile del grupo eda combinando heuristica de nombre (latitude/longitude/lat/lon/lng + x/y debiles) con validacion de rango obligatoria (latitud en [-90,90], longitud en [-180,180]). Lee defensivamente con .get; NUNCA lanza. Usa el sub-bloque numeric.min/max o, si falta, la lista de samples opcional. Devuelve SIEMPRE un dict {lat_col, lon_col, confidence, reason}; si no hay par valido, las columnas van a None y confidence a 0.0."
tags: [eda, geospatial, profiling, latlon, coordinates, detection, datascience]
params:
- name: columns
desc: "Lista de dicts ColumnProfile (el campo `columns` de un TableProfile del grupo eda). Cada dict se lee con .get; solo `name` (str) es obligatorio. Se consultan `inferred_type` (p.ej. 'numeric') y el sub-dict `numeric` con `min`/`max` (floats) para validar el rango. Entradas no-dict o sin name se ignoran sin lanzar."
- name: samples
desc: "Opcional {nombre_columna: [valores...]} para validar el rango cuando una columna no trae numeric.min/max. Los valores nulos se ignoran; si algun valor no nulo no es numerico la columna no se considera coordenada. Si es None u omitido, solo se usa el bloque numeric."
output: "Dict SIEMPRE presente con la forma {lat_col: str|None, lon_col: str|None, confidence: float en [0,1], reason: str en espanol}. En exito, lat_col y lon_col nombran columnas distintas; confidence ~1.0 para par con nombre fuerte (latitude/longitude/lat/lon/lng) + rango valido y ~0.7 para par debil (x/y) + rango. En fallo, ambas columnas None, confidence 0.0 y reason explica por que (sin columnas, nombre sin match, rango fuera de bounds, falta uno de los dos ejes...)."
uses_functions: []
uses_types: []
returns: []
returns_optional: false
error_type: ""
imports: []
tested: true
tests: ["test_par_latitude_longitude_fuerte", "test_par_lat_lon_abreviado", "test_par_x_y_debil_con_rango_valido", "test_nombre_lat_lon_pero_rango_fuera_no_detecta", "test_par_fuerte_prevalece_sobre_debil", "test_entradas_vacias_o_invalidas_no_lanzan", "test_solo_latitud_sin_longitud_no_detecta", "test_deteccion_por_samples_cuando_falta_numeric", "test_samples_fuera_de_rango_descarta"]
test_file_path: "python/functions/datascience/detect_latlon_columns_test.py"
file_path: "python/functions/datascience/detect_latlon_columns.py"
---
## Ejemplo
```python
import sys, os
sys.path.insert(0, os.path.join("python", "functions"))
from datascience.detect_latlon_columns import detect_latlon_columns
# Columnas tal y como vienen en profile['columns'] de un TableProfile del grupo eda:
columns = [
{"name": "id", "inferred_type": "numeric", "numeric": {"min": 1, "max": 9999}},
{"name": "latitude", "inferred_type": "numeric", "numeric": {"min": -45.0, "max": 45.0}},
{"name": "longitude", "inferred_type": "numeric", "numeric": {"min": -120.0, "max": 120.0}},
]
res = detect_latlon_columns(columns)
print(res["lat_col"], res["lon_col"], res["confidence"])
# latitude longitude 1.0
# Sin bloque numeric, validando el rango con samples:
cols2 = [{"name": "lat"}, {"name": "lon"}]
samples = {"lat": [10.5, 20.0, 30.25], "lon": [-40.0, 50.5, 60.0]}
print(detect_latlon_columns(cols2, samples)["lat_col"]) # lat
```
## Cuando usarla
- Usala al perfilar una tabla en `AutomaticEDA` para decidir si tiene geometria de puntos: cuando `detect_latlon_columns` devuelve un par con `confidence` alta, el capitulo geospatial puede dibujar un mapa, calcular un bounding box o proponer un cluster espacial.
- Antes de un analisis geoespacial (alpha shape, convex hull, joins por proximidad) para localizar automaticamente que columnas son la latitud y la longitud sin pedirlo al usuario.
- Cuando recibas un `TableProfile` del grupo `eda` y quieras enrutar columnas a sub-analisis por tipo semantico: este es el detector del par lat/lon, complementario a `infer_semantic_type`.
## Gotchas
- Funcion pura, sin I/O y determinista. Lectura defensiva con `.get`: NUNCA lanza. Cualquier input malformado (None, no-lista, entradas no-dict, claves ausentes) devuelve el dict de fallo con `lat_col`/`lon_col` en None y `confidence` 0.0.
- **El nombre solo no basta**: una columna `latitude` cuyo rango se sale de `[-90, 90]` se descarta (no es coordenada real). Igual para `longitude` fuera de `[-180, 180]`. La validacion de rango es obligatoria.
- El rango de latitud `[-90, 90]` es un subconjunto del de longitud `[-180, 180]`, por eso el nombre es necesario para desambiguar cual eje es cual; una columna numerica en `[-90, 90]` sin nombre que sugiera lat/lon no se detecta.
- Los nombres genericos `x`/`y` (y `x_coord`/`y_coord`) son candidatos **debiles**: solo forman par si el rango encaja y existe la otra mitad (un `x`/`lon` para la `y`, un `y`/`lat` para la `x`). Un `y` suelto sin pareja devuelve None.
- Requiere AMBOS ejes para considerar exito. Si solo encuentra latitud o solo longitud, devuelve el dict de fallo (no media coordenada).
- `samples` solo se consulta cuando falta `numeric.min`/`numeric.max`. Si una columna trae el bloque numeric, ese manda aunque pases samples para ella.
- El matching de nombre es por subcadena normalizada (se quitan `_`, `-` y espacios), asi que nombres como `plate` (contiene "lat") podrian marcarse como candidatos por nombre — pero solo pasarian si su rango cae en `[-90, 90]` y hay una longitud pareja, filtro que en la practica descarta los falsos positivos.
@@ -0,0 +1,198 @@
"""detect_latlon_columns — detect a (latitude, longitude) column pair in an EDA profile.
Pure function: no I/O, deterministic. Takes the `columns` list of a TableProfile
(group `eda`) and decides whether two of its columns form a geographic coordinate
pair (latitude + longitude), combining a name heuristic with a value-range check.
The detection is intentionally conservative: a name hint alone is never enough. A
column is only accepted as latitude/longitude if its numeric range fits inside the
valid coordinate bounds ([-90, 90] for latitude, [-180, 180] for longitude). When
the `numeric` sub-block is absent the optional `samples` argument is used instead.
Reading is fully defensive (.get throughout) and the function NEVER raises: any
malformed input (None, non-list, non-dict entries, missing keys) simply yields a
no-pair result {"lat_col": None, "lon_col": None, "confidence": 0.0, "reason": ...}.
"""
import re
# Collapse the separators a column name may use (snake_case, kebab-case, spaces)
# so that "y_coord", "y-coord" and "y coord" all normalize to the same token.
_SEP_RE = re.compile(r"[\s_\-]+")
# Name-match strengths: a strong, unambiguous coordinate name vs a weak generic
# axis name (x / y) that only counts when the range also fits and a partner exists.
_STRONG = 0.6
_WEAK = 0.3
_RANGE_BONUS = 0.4 # added once the mandatory range validation passes
def _normalize(name):
"""Lowercase a column name and strip separator chars (_, -, whitespace)."""
if not isinstance(name, str):
return ""
return _SEP_RE.sub("", name.strip().lower())
def _num(value):
"""Coerce to float defensively; return None for None/bool/non-numeric."""
# bool is a subclass of int; a coordinate value is never a real bool, so treat
# True/False as missing instead of silently coercing to 1.0/0.0.
if value is None or isinstance(value, bool):
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def _lat_name_strength(nn):
"""Strength of a normalized name as a latitude candidate (0=no match)."""
if not nn:
return 0.0
# "lat", "latitude", "latitud" all contain the "lat" stem.
if "lat" in nn:
return _STRONG
# Weak generic axis name: only useful when paired with an x/lon partner.
if nn in ("y", "ycoord", "ycoordinate", "ycoordinates"):
return _WEAK
return 0.0
def _lon_name_strength(nn):
"""Strength of a normalized name as a longitude candidate (0=no match)."""
if not nn:
return 0.0
# "lon", "long", "longitude", "longitud" share the "lon" stem; "lng" is separate.
if "lon" in nn or "lng" in nn:
return _STRONG
if nn in ("x", "xcoord", "xcoordinate", "xcoordinates"):
return _WEAK
return 0.0
def _col_range(col, sample_values):
"""Return (min, max) floats for a column, or (None, None) if not numeric.
Prefers the `numeric` sub-block min/max (the output of describe_numeric); falls
back to the provided sample list. A column is only treated as numeric when both
extremes are derivable: from the numeric block, or from samples whose every
non-null value coerces to a number.
"""
if isinstance(col, dict):
numeric = col.get("numeric")
if isinstance(numeric, dict):
mn = _num(numeric.get("min"))
mx = _num(numeric.get("max"))
if mn is not None and mx is not None:
return mn, mx
# Fall back to samples when the numeric block is missing or incomplete.
if isinstance(sample_values, (list, tuple)):
non_null = [v for v in sample_values if v is not None]
if non_null:
coerced = [_num(v) for v in non_null]
# Any non-numeric sample means we cannot trust the column as numeric.
if all(c is not None for c in coerced):
return min(coerced), max(coerced)
return None, None
def _no_pair(reason):
"""Canonical empty result: no coordinate pair detected."""
return {"lat_col": None, "lon_col": None, "confidence": 0.0, "reason": reason}
def detect_latlon_columns(columns: list, samples: dict | None = None) -> dict:
"""Detect a (latitude, longitude) column pair from an eda TableProfile.
Combines a name heuristic (latitude/longitude/lat/lon/lng + weak x/y) with a
mandatory range validation: the chosen latitude must sit in [-90, 90] and the
longitude in [-180, 180]. A name hint whose range does not fit is discarded.
Both sides are required for success; if only one is found, no pair is returned.
Args:
columns: List of ColumnProfile dicts (the `columns` of a TableProfile).
Each dict is read defensively with .get; only `name` is required.
`numeric.min` / `numeric.max` (and optionally `inferred_type`) are used
for the range check when present.
samples: Optional {column_name: [values...]} used to validate the range
when a column lacks `numeric.min`/`numeric.max`. If None/omitted, only
the `numeric` sub-block is consulted.
Returns:
Always a dict {"lat_col": str|None, "lon_col": str|None,
"confidence": float, "reason": str}. On success lat_col and lon_col name
the detected pair (distinct columns) and confidence is in [0, 1]: a pair
validated by a strong name on both sides scores ~1.0, a weak x/y pair ~0.7.
On failure both columns are None and confidence is 0.0.
"""
if not isinstance(columns, (list, tuple)) or len(columns) == 0:
return _no_pair("sin columnas que inspeccionar")
sample_map = samples if isinstance(samples, dict) else {}
# (column_name, confidence) for each side. Confidence already includes the
# range bonus because membership in the list implies the range was validated.
lat_candidates = []
lon_candidates = []
for col in columns:
if not isinstance(col, dict):
continue
name = col.get("name")
if not isinstance(name, str) or not name:
continue
nn = _normalize(name)
lat_strength = _lat_name_strength(nn)
lon_strength = _lon_name_strength(nn)
if lat_strength == 0.0 and lon_strength == 0.0:
continue # name gives no coordinate hint; skip.
mn, mx = _col_range(col, sample_map.get(name))
is_numeric = mn is not None and mx is not None
if not is_numeric:
continue # range cannot be validated -> not a coordinate.
if lat_strength > 0.0 and mn >= -90.0 and mx <= 90.0:
lat_candidates.append((name, lat_strength + _RANGE_BONUS))
if lon_strength > 0.0 and mn >= -180.0 and mx <= 180.0:
lon_candidates.append((name, lon_strength + _RANGE_BONUS))
if not lat_candidates and not lon_candidates:
return _no_pair("ninguna columna sugiere latitud ni longitud por nombre+rango")
if not lat_candidates:
return _no_pair("no se encontro columna de latitud valida (nombre+rango en [-90,90])")
if not lon_candidates:
return _no_pair("no se encontro columna de longitud valida (nombre+rango en [-180,180])")
# Pick the distinct pair with the highest combined confidence. First match wins
# on ties to keep the result deterministic by input order.
best = None # (combined, lat_name, lon_name, lat_c, lon_c)
for lat_name, lat_c in lat_candidates:
for lon_name, lon_c in lon_candidates:
if lat_name == lon_name:
continue # a column cannot be both axes of the same pair.
combined = (lat_c + lon_c) / 2.0
if best is None or combined > best[0]:
best = (combined, lat_name, lon_name, lat_c, lon_c)
if best is None:
return _no_pair("solo una columna sirve para ambos ejes; no hay par lat/lon distinto")
combined, lat_name, lon_name, lat_c, lon_c = best
confidence = max(0.0, min(1.0, combined))
lat_label = "fuerte" if lat_c >= 0.9 else "debil"
lon_label = "fuerte" if lon_c >= 0.9 else "debil"
reason = (
f"par lat='{lat_name}' (nombre {lat_label}) / lon='{lon_name}' "
f"(nombre {lon_label}) con rango valido"
)
return {
"lat_col": lat_name,
"lon_col": lon_name,
"confidence": confidence,
"reason": reason,
}
@@ -0,0 +1,141 @@
"""Tests para detect_latlon_columns."""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from detect_latlon_columns import detect_latlon_columns
# Keys that every result dict (success or failure) must expose.
_EXPECTED_KEYS = {"lat_col", "lon_col", "confidence", "reason"}
def _col(name, mn=None, mx=None, inferred="numeric"):
"""Build a minimal ColumnProfile-like dict for the tests."""
col = {"name": name, "inferred_type": inferred}
if mn is not None or mx is not None:
col["numeric"] = {"min": mn, "max": mx}
return col
def test_par_latitude_longitude_fuerte():
"""Golden: nombres latitude/longitude con rango valido -> par con confianza alta."""
columns = [
_col("id", mn=1, mx=9999, inferred="numeric"),
_col("latitude", mn=-45.0, mx=45.0),
_col("longitude", mn=-120.0, mx=120.0),
]
res = detect_latlon_columns(columns)
assert set(res.keys()) == _EXPECTED_KEYS
assert res["lat_col"] == "latitude"
assert res["lon_col"] == "longitude"
# Nombre fuerte (0.6) + rango (0.4) en ambos lados -> 1.0.
assert abs(res["confidence"] - 1.0) < 1e-9
assert "rango valido" in res["reason"]
def test_par_lat_lon_abreviado():
"""Golden: nombres abreviados lat/lon tambien se detectan como fuertes."""
columns = [
_col("lat", mn=40.0, mx=43.0),
_col("lon", mn=-4.0, mx=-1.0),
_col("precio", mn=0.0, mx=500.0),
]
res = detect_latlon_columns(columns)
assert res["lat_col"] == "lat"
assert res["lon_col"] == "lon"
assert abs(res["confidence"] - 1.0) < 1e-9
def test_par_x_y_debil_con_rango_valido():
"""Edge: x/y genericos solo cuentan como par debil cuando el rango encaja."""
columns = [
_col("y_coord", mn=-10.0, mx=10.0), # debil latitud
_col("x_coord", mn=-150.0, mx=150.0), # debil longitud
]
res = detect_latlon_columns(columns)
assert res["lat_col"] == "y_coord"
assert res["lon_col"] == "x_coord"
# Nombre debil (0.3) + rango (0.4) -> 0.7 en ambos lados.
assert abs(res["confidence"] - 0.7) < 1e-9
def test_nombre_lat_lon_pero_rango_fuera_no_detecta():
"""Edge: nombre lat/lon con rango fuera de bounds -> NO es coordenada."""
columns = [
_col("latitude", mn=-200.0, mx=200.0), # fuera de [-90, 90]
_col("longitude", mn=-120.0, mx=120.0), # valido, pero sin par lat
]
res = detect_latlon_columns(columns)
assert res["lat_col"] is None
assert res["lon_col"] is None
assert res["confidence"] == 0.0
assert isinstance(res["reason"], str) and res["reason"]
def test_par_fuerte_prevalece_sobre_debil():
"""Edge: con candidatos fuertes y debiles, gana el par de mayor confianza."""
columns = [
_col("latitude", mn=-45.0, mx=45.0), # fuerte lat
_col("y", mn=-30.0, mx=30.0), # debil lat
_col("longitude", mn=-120.0, mx=120.0), # fuerte lon
_col("x", mn=-100.0, mx=100.0), # debil lon
]
res = detect_latlon_columns(columns)
assert res["lat_col"] == "latitude"
assert res["lon_col"] == "longitude"
assert abs(res["confidence"] - 1.0) < 1e-9
def test_entradas_vacias_o_invalidas_no_lanzan():
"""Edge: sin columnas / vacio / no-lista / entradas no-dict -> dict None sin lanzar."""
for bad in ([], None, "no soy lista", 42, [1, 2, 3], [{}], [{"foo": "bar"}]):
res = detect_latlon_columns(bad)
assert set(res.keys()) == _EXPECTED_KEYS
assert res["lat_col"] is None
assert res["lon_col"] is None
assert res["confidence"] == 0.0
assert isinstance(res["reason"], str)
def test_solo_latitud_sin_longitud_no_detecta():
"""Edge: solo hay latitud valida, falta la longitud -> sin par."""
columns = [
_col("latitude", mn=-45.0, mx=45.0),
_col("temperatura", mn=-5.0, mx=40.0),
]
res = detect_latlon_columns(columns)
assert res["lat_col"] is None
assert res["lon_col"] is None
assert res["confidence"] == 0.0
def test_deteccion_por_samples_cuando_falta_numeric():
"""Edge: sin bloque numeric, el rango se valida con samples."""
columns = [
{"name": "lat"}, # sin numeric ni inferred_type
{"name": "lon"},
]
samples = {
"lat": [10.5, 20.0, None, 30.25], # todos dentro de [-90, 90]
"lon": [-40.0, 50.5, 60.0], # todos dentro de [-180, 180]
}
res = detect_latlon_columns(columns, samples)
assert res["lat_col"] == "lat"
assert res["lon_col"] == "lon"
assert abs(res["confidence"] - 1.0) < 1e-9
def test_samples_fuera_de_rango_descarta():
"""Edge: samples fuera de bounds invalidan la columna pese al nombre fuerte."""
columns = [{"name": "lat"}, {"name": "lon"}]
samples = {
"lat": [10.0, 95.0], # 95 > 90 -> latitud invalida
"lon": [-40.0, 50.0],
}
res = detect_latlon_columns(columns, samples)
assert res["lat_col"] is None
assert res["lon_col"] is None
assert res["confidence"] == 0.0
@@ -1,87 +0,0 @@
---
name: groupby_stats_duckdb
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: impure
signature: "def groupby_stats_duckdb(db_path: str, table: str, group_by: str, measures: list, aggs: list = None, top_n: int = 15) -> dict"
description: "Agregaciones GROUP BY con push-down SQL en DuckDB: para cada measure numerica calcula mean/median/std/min/max por grupo (split-apply-combine en el motor), trayendo solo una fila por grupo. Nucleo de un capitulo de agregacion/OLAP de un EDA. count = tamanio del grupo, independiente de measures."
tags: [eda, groupby, aggregation, olap, duckdb, datascience, push-down, split-apply-combine]
uses_functions: [duckdb_query_readonly_py_infra]
uses_types: []
returns: []
returns_optional: false
error_type: "error_go_core"
imports: []
params:
- name: db_path
desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base. Path inexistente -> {status:'error'} sin lanzar."
- name: table
desc: "Nombre de la tabla. Se interpola citado con dobles comillas (soporta nombres con espacios; las comillas internas se escapan)."
- name: group_by
desc: "Columna por la que agrupar. Se interpola citada. Sus valores distintos son las claves de los grupos."
- name: measures
desc: "Lista de columnas numericas a agregar. Lista vacia es valida: cada grupo trae solo su tamanio `n` y `stats` vacio."
- name: aggs
desc: "Lista de agregaciones. None (default) = ['count','mean','median','std','min','max']. Validas: count (tamanio del grupo, va a `n`), mean->avg, median, std->stddev_samp, min, max (estas cinco por measure). Agg desconocido -> error."
- name: top_n
desc: "Maximo de grupos a devolver, ordenados por tamanio de grupo descendente (default 15). Internamente se piden top_n+1 para detectar truncado."
output: "dict. En exito {status:'ok', group_by, measures:[...], aggs:[...], n_groups:int, truncated:bool, groups:[{key:<valor grupo>, n:int, stats:{<measure>:{mean,median,std,min,max}}}], note:str}. Las estadisticas son float o None (p.ej. std de un grupo de 1 fila -> NULL -> None). En error {status:'error', error:str} (no lanza)."
tested: true
tests: ["agrega por grupo con valores conocidos", "db inexistente devuelve error sin lanzar", "measures vacias agrega solo count", "columna con espacio agrupa bien"]
test_file_path: "python/functions/datascience/groupby_stats_duckdb_test.py"
file_path: "python/functions/datascience/groupby_stats_duckdb.py"
---
## Ejemplo
```python
import duckdb
from datascience import groupby_stats_duckdb
# Cargar el titanic en una tabla DuckDB de prueba.
db = "/tmp/titanic.duckdb"
con = duckdb.connect(db)
con.execute(
"CREATE TABLE titanic AS "
"SELECT * FROM read_csv_auto('https://raw.githubusercontent.com/"
"datasciencedojo/datasets/master/titanic.csv')"
)
con.close()
# Agrupar por sexo midiendo edad y tarifa.
res = groupby_stats_duckdb(db, "titanic", "Sex", ["Age", "Fare"])
print(res["status"]) # ok
print(res["n_groups"]) # 2 (male, female)
for g in res["groups"]:
print(g["key"], g["n"], round(g["stats"]["Fare"]["mean"], 2))
# female 314 44.48
# male 577 25.52
```
## Cuando usarla
Cuando en un EDA necesitas el clasico split-apply-combine: "para cada categoria de X,
¿cuanto vale en media/mediana/desviacion/min/max la metrica Y?". Es el nucleo de un
capitulo de agregacion/OLAP. Usala antes de pintar barras o boxplots por grupo, para
detectar segmentos con comportamiento distinto, o para resumir una tabla grande sin
traer las filas a RAM: todo el GROUP BY ocurre push-down en el motor de DuckDB y solo
viaja una fila por grupo. `top_n` te deja quedarte con los grupos mas poblados.
## Gotchas
- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica). La
tabla debe existir ya en el `.db` (no carga CSV; para eso crea la tabla antes).
- Identificadores (tabla, group_by, measures) se interpolan citados con dobles comillas
y escapando las internas: soporta nombres con espacios y evita inyeccion. No pases
expresiones SQL como group_by/measure — solo nombres de columna.
- `count` es el tamanio del grupo (`COUNT(*)`), independiente de las measures: se
refleja en el campo `n` de cada grupo, NO como clave dentro de `stats`. Las claves de
`stats[measure]` son las measure-aggs efectivas (mean/median/std/min/max menos count).
- `std` usa `stddev_samp` (muestral, n-1): un grupo con una sola fila da `NULL` -> `None`.
Las measures pueden contener NULLs; cada agregada los ignora segun la semantica de DuckDB.
- `truncated:True` indica que habia mas grupos que `top_n` (se devolvieron los `top_n`
mayores por tamanio). Sube `top_n` si necesitas todos los grupos.
- Si `measures` esta vacio, cada grupo trae solo `n` y `stats == {}` (valido, util para
un simple conteo por categoria).
@@ -1,184 +0,0 @@
"""groupby_stats_duckdb — agregaciones GROUP BY con push-down SQL en DuckDB.
Funcion impura: lee de disco a traves de DuckDB (via la primitiva read-only
`duckdb_query_readonly` del grupo `duckdb`). Pertenece al grupo de capacidad `eda`.
Ejecuta un `GROUP BY <group_by>` en el motor de DuckDB (split-apply-combine con
push-down) calculando, para cada columna numerica de `measures`, las agregaciones
pedidas (mean/median/std/min/max). Solo trae al cliente una fila por grupo, nunca
las filas crudas: apto para tablas grandes. Es el nucleo de un capitulo de
agregacion/OLAP de un EDA.
Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y
devuelve {status:'error', error:str}.
"""
from infra import duckdb_query_readonly
# Mapeo agg -> funcion agregada SQL de DuckDB. `count` se trata aparte: es
# COUNT(*) (tamanio del grupo), independiente de las measures.
_AGG_SQL = {
"mean": "avg",
"median": "median",
"std": "stddev_samp",
"min": "min",
"max": "max",
}
# Aggs por defecto cuando aggs=None. count primero (tamanio del grupo) + las
# cinco estadisticas por measure.
_DEFAULT_AGGS = ["count", "mean", "median", "std", "min", "max"]
def _quote_ident(ident: str) -> str:
"""Cita un identificador SQL con dobles comillas, escapando las internas.
Soporta nombres con espacios o caracteres especiales y evita inyeccion: dentro
de un identificador entrecomillado el unico caracter peligroso es la propia
comilla doble, que se duplica ("") segun el estandar SQL. DuckDB no admite
parametros posicionales para nombres de tabla/columna, asi que esta es la via
segura de interpolarlos.
"""
return '"' + str(ident).replace('"', '""') + '"'
def groupby_stats_duckdb(
db_path: str,
table: str,
group_by: str,
measures: list,
aggs: list = None,
top_n: int = 15,
) -> dict:
"""GROUP BY con agregaciones por measure, todo push-down en DuckDB.
Args:
db_path: ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la
base. Un path inexistente devuelve {status:'error', ...} sin lanzar.
table: nombre de la tabla. Se interpola citado con dobles comillas (soporta
nombres con espacios).
group_by: columna por la que agrupar. Se interpola citada.
measures: lista de columnas numericas a agregar. Lista vacia es valida:
cada grupo trae solo su tamanio `n` y `stats` vacio.
aggs: lista de agregaciones a calcular. None (default) =
["count", "mean", "median", "std", "min", "max"]. Valores validos:
count (tamanio del grupo, va a `n`), mean, median, std, min, max
(estas cinco se calculan por cada measure). Un agg desconocido devuelve
error.
top_n: numero maximo de grupos a devolver, ordenados por tamanio de grupo
descendente (default 15). Se pide top_n+1 internamente para detectar si
habia mas grupos y marcar `truncated`.
Returns:
dict. En exito:
{status:'ok',
group_by:str,
measures:[...],
aggs:[...], # las efectivas (incluye count si se pidio)
n_groups:int, # nº de grupos devueltos (<= top_n)
truncated:bool, # True si habia mas de top_n grupos
groups:[{key:<valor grupo>, n:int,
stats:{<measure>:{mean,median,std,min,max}}}, ...],
note:str}
Las estadisticas son float o None (p.ej. stddev_samp de un grupo de una
sola fila -> NULL -> None). En error (sin lanzar): {status:'error', error:str}.
"""
try:
# 1. Validar entradas.
if not isinstance(table, str) or table == "":
return {"status": "error", "error": "table must be a non-empty string"}
if not isinstance(group_by, str) or group_by == "":
return {"status": "error", "error": "group_by must be a non-empty string"}
if measures is None:
measures = []
if not isinstance(measures, list):
return {"status": "error", "error": "measures must be a list"}
for m in measures:
if not isinstance(m, str) or m == "":
return {
"status": "error",
"error": f"invalid measure identifier: {m!r}",
}
if aggs is None:
aggs = list(_DEFAULT_AGGS)
if not isinstance(aggs, list) or len(aggs) == 0:
return {
"status": "error",
"error": "aggs must be a non-empty list or None",
}
for a in aggs:
if a != "count" and a not in _AGG_SQL:
return {
"status": "error",
"error": f"unknown agg {a!r}; valid: count, "
+ ", ".join(_AGG_SQL),
}
if not isinstance(top_n, int) or isinstance(top_n, bool) or top_n < 1:
return {"status": "error", "error": "top_n must be a positive int"}
# 2. Aggs por measure = todas menos count (count es el tamanio del grupo,
# se mapea siempre a la columna `n`).
measure_aggs = [a for a in aggs if a != "count"]
# 3. Construir el SELECT. grp y n primero; luego un termino por measure x agg
# con alias posicional (m{idx}_{agg}) para no chocar con nombres de columna
# que lleven espacios o caracteres raros.
select_terms = [f"{_quote_ident(group_by)} AS grp", "COUNT(*) AS n"]
agg_index = [] # (measure_name, agg_name, alias)
for mi, m in enumerate(measures):
for a in measure_aggs:
alias = f"m{mi}_{a}"
fn = _AGG_SQL[a]
select_terms.append(f"{fn}({_quote_ident(m)}) AS {alias}")
agg_index.append((m, a, alias))
# Pedimos top_n+1 grupos para detectar truncado (habia mas que top_n).
sql = (
f"SELECT {', '.join(select_terms)} "
f"FROM {_quote_ident(table)} "
f"GROUP BY {_quote_ident(group_by)} "
f"ORDER BY n DESC "
f"LIMIT {top_n + 1}"
)
# 4. Ejecutar push-down. sandbox=True (default) basta: la tabla ya existe en
# el .db, no necesitamos read_csv/read_blob ni acceso al filesystem.
result = duckdb_query_readonly(db_path, sql, max_rows=top_n + 1)
if result.get("status") != "ok":
return {
"status": "error",
"error": "groupby query failed: "
+ str(result.get("error", "unknown")),
}
rows = result.get("rows", [])
truncated = len(rows) > top_n
if truncated:
rows = rows[:top_n]
# 5. Reconstruir la estructura por grupo.
groups = []
for row in rows:
stats = {m: {} for m in measures}
for (m, a, alias) in agg_index:
stats[m][a] = row.get(alias)
groups.append(
{"key": row.get("grp"), "n": row.get("n"), "stats": stats}
)
return {
"status": "ok",
"group_by": group_by,
"measures": list(measures),
"aggs": list(aggs),
"n_groups": len(groups),
"truncated": truncated,
"groups": groups,
"note": f"GROUP BY {group_by}: top {len(groups)} grupos por tamanio sobre "
f"{len(measures)} measure(s)",
}
except Exception as e: # noqa: BLE001
return {"status": "error", "error": str(e)}
@@ -1,106 +0,0 @@
"""Tests para groupby_stats_duckdb."""
import os
import sys
import duckdb
# Permitir importar funciones del registry (from infra import ..., from datascience import ...).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions"))
from datascience.groupby_stats_duckdb import groupby_stats_duckdb
def _make_db(tmp_path, rows):
"""Crea una DuckDB con tabla t(g VARCHAR, x DOUBLE) e inserta `rows`."""
db = os.path.join(str(tmp_path), "t.duckdb")
con = duckdb.connect(db)
con.execute("CREATE TABLE t(g VARCHAR, x DOUBLE)")
con.executemany("INSERT INTO t VALUES (?, ?)", rows)
con.close()
return db
def test_agrega_por_grupo_con_valores_conocidos(tmp_path):
# Grupo a: [10, 20, 30] -> n=3, mean=20, min=10, max=30, median=20, std=10.
# Grupo b: [5, 15] -> n=2, mean=10, median=10.
# Grupo c: [100] -> n=1, mean=100, std=None (1 sola fila).
rows = [
("a", 10.0), ("a", 20.0), ("a", 30.0),
("b", 5.0), ("b", 15.0),
("c", 100.0),
]
db = _make_db(tmp_path, rows)
res = groupby_stats_duckdb(db, "t", "g", ["x"])
assert res["status"] == "ok", res
assert res["n_groups"] == 3
assert res["truncated"] is False
assert res["aggs"] == ["count", "mean", "median", "std", "min", "max"]
by_key = {g["key"]: g for g in res["groups"]}
assert set(by_key) == {"a", "b", "c"}
# Grupo a: comprobacion manual de mean/min/max/median/std.
sa = by_key["a"]["stats"]["x"]
assert by_key["a"]["n"] == 3
assert abs(sa["mean"] - 20.0) < 1e-9
assert abs(sa["min"] - 10.0) < 1e-9
assert abs(sa["max"] - 30.0) < 1e-9
assert abs(sa["median"] - 20.0) < 1e-9
assert "std" in sa and sa["std"] is not None
assert abs(sa["std"] - 10.0) < 1e-9 # stddev_samp([10,20,30]) = 10
# Grupo b: mean y median conocidas.
sb = by_key["b"]["stats"]["x"]
assert by_key["b"]["n"] == 2
assert abs(sb["mean"] - 10.0) < 1e-9
assert abs(sb["median"] - 10.0) < 1e-9
assert "median" in sb and "std" in sb
# Grupo c: una sola fila -> std None (stddev_samp NULL), mean/min/max definidos.
sc = by_key["c"]["stats"]["x"]
assert by_key["c"]["n"] == 1
assert abs(sc["mean"] - 100.0) < 1e-9
assert sc["std"] is None
def test_db_inexistente_devuelve_error_sin_lanzar(tmp_path):
db = os.path.join(str(tmp_path), "no_existe.duckdb")
res = groupby_stats_duckdb(db, "t", "g", ["x"])
assert res["status"] == "error", res
assert isinstance(res["error"], str) and res["error"]
def test_measures_vacias_agrega_solo_count(tmp_path):
rows = [("a", 1.0), ("a", 2.0), ("b", 3.0)]
db = _make_db(tmp_path, rows)
res = groupby_stats_duckdb(db, "t", "g", [])
assert res["status"] == "ok", res
by_key = {g["key"]: g for g in res["groups"]}
assert by_key["a"]["n"] == 2
assert by_key["b"]["n"] == 1
# Sin measures, stats por grupo es un dict vacio (valido).
assert by_key["a"]["stats"] == {}
assert by_key["b"]["stats"] == {}
def test_columna_con_espacio_agrupa_bien(tmp_path):
# Tabla con nombres de columna con espacios -> prueba el quoting con dobles
# comillas tanto en group_by como en la measure.
db = os.path.join(str(tmp_path), "space.duckdb")
con = duckdb.connect(db)
con.execute('CREATE TABLE t("my col" VARCHAR, "the val" DOUBLE)')
con.executemany(
'INSERT INTO t VALUES (?, ?)',
[("x", 1.0), ("x", 3.0), ("y", 10.0)],
)
con.close()
res = groupby_stats_duckdb(db, "t", "my col", ["the val"])
assert res["status"] == "ok", res
by_key = {g["key"]: g for g in res["groups"]}
assert by_key["x"]["n"] == 2
assert abs(by_key["x"]["stats"]["the val"]["mean"] - 2.0) < 1e-9
assert by_key["y"]["n"] == 1
assert abs(by_key["y"]["stats"]["the val"]["mean"] - 10.0) < 1e-9
@@ -1,92 +0,0 @@
---
name: pivot_table_duckdb
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: impure
signature: "def pivot_table_duckdb(db_path: str, table: str, index: str, columns: str, value: str, agg: str = 'mean', top_rows: int = 10, top_cols: int = 8) -> dict"
description: "Pivot table (index x columns -> agg(value)) calculada con push-down SQL en DuckDB (GROUP BY en el motor, sin traer filas a RAM) y recortada a las top_rows filas y top_cols columnas con mas observaciones para que quepa entera en un PDF movil / slide PPTX sin cortarse. Version push-down para tablas grandes de la funcion pura `pivot` (que pivota list[dict] en memoria)."
tags: [eda, pivot, duckdb, aggregate, datascience, push-down, report]
uses_functions: [duckdb_query_readonly_py_infra]
uses_types: []
returns: []
returns_optional: false
error_type: "error_go_core"
imports: []
params:
- name: db_path
desc: "Ruta al archivo DuckDB. Debe existir; el modo read_only NO crea la base."
- name: table
desc: "Nombre de la tabla a pivotar. Se interpola citado con dobles comillas (DuckDB no admite parametros para identificadores)."
- name: index
desc: "Columna cuyos valores forman las filas de la pivot (eje vertical)."
- name: columns
desc: "Columna cuyos valores forman las columnas de la pivot (eje horizontal)."
- name: value
desc: "Columna numerica a agregar en cada celda. Ignorada cuando agg='count'."
- name: agg
desc: "Funcion de agregacion: mean, sum, count, min, max, median. mean->avg(), count->COUNT(*). Otro valor devuelve {status:'error'}."
- name: top_rows
desc: "Numero maximo de filas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de index). Default 10."
- name: top_cols
desc: "Numero maximo de columnas a conservar, elegidas por mayor numero de observaciones (suma de COUNT(*) por valor de columns). Default 8."
output: "dict. En exito {status:'ok', index, columns, value, agg, row_labels:[...], col_labels:[...], matrix:[[...]], truncated_rows:bool, truncated_cols:bool, note:str}. matrix tiene len(row_labels) filas y cada fila len(col_labels) celdas (valor agregado o None si la combinacion no existe). truncated_* indica si hubo mas filas/columnas que el top. En error {status:'error', error:str} (no lanza)."
tested: true
tests: ["pivot mean labels y celda conocida", "pivot trunca a top rows y top cols", "pivot count no necesita value real", "pivot db inexistente devuelve error sin lanzar", "pivot agg invalido devuelve error"]
test_file_path: "python/functions/datascience/pivot_table_duckdb_test.py"
file_path: "python/functions/datascience/pivot_table_duckdb.py"
---
## Ejemplo
```python
import duckdb
from datascience import pivot_table_duckdb
# Tabla DuckDB de prueba estilo titanic: sex x pclass -> mean(fare).
db = "/tmp/pivot_demo.duckdb"
con = duckdb.connect(db)
con.execute(
"CREATE TABLE titanic AS SELECT * FROM (VALUES "
"('male',1,211.3),('female',1,151.5),('male',3,7.9),"
"('female',3,16.7),('male',1,52.0),('female',2,41.6)"
") t(sex, pclass, fare)"
)
con.close()
res = pivot_table_duckdb(db, "titanic", index="sex", columns="pclass", value="fare", agg="mean")
print(res["status"]) # ok
print(res["row_labels"]) # ['female', 'male'] (orden por nº de observaciones desc; empate -> etiqueta)
print(res["col_labels"]) # [1, 3, 2] (pclass=1 tiene 3 obs, pclass=3 -> 2, pclass=2 -> 1)
print(res["matrix"]) # [[151.5, 16.7, 41.6], [131.65, 7.9, None]] (male/pclass=2 no existe -> None)
```
## Cuando usarla
Cuando quieres una pivot table (`index` x `columns` -> `agg(value)`) de una tabla
DuckDB con MUCHAS filas y necesitas que el resultado quepa entero en un informe: un
PDF abierto en el movil o un slide PPTX, donde una matriz de 50x30 se cortaria. La
agregacion se hace push-down en el motor (no traes las filas a RAM) y el resultado se
limita a las `top_rows` x `top_cols` combinaciones con mas observaciones. Encaja en el
flujo `eda` para resumir el cruce de dos categoricas (sexo x clase, region x producto)
contra una metrica. Para pivotar un `list[dict]` ya cargado en memoria usa la funcion
pura `pivot_py_datascience`; esta es la version push-down sobre disco.
## Gotchas
- Funcion impura: lee un archivo DuckDB del disco (read_only, nunca lo modifica).
- Recorta a `top_rows` x `top_cols` por numero de observaciones (suma de `COUNT(*)`),
NO por magnitud del valor agregado. Si habia mas filas/columnas, `truncated_rows` /
`truncated_cols` quedan en True y esas combinaciones NO aparecen en la matriz.
- Las celdas sin datos (combinacion `index` x `columns` que no existe en la tabla) se
rellenan con `None`, no con 0: distinguir "cero medido" de "sin observaciones".
- `agg='count'` cuenta filas por celda con `COUNT(*)` e ignora `value` (puedes pasar
cualquier nombre de columna). Para el resto de aggs, `value` debe ser una columna
numerica real o la query fallara con `{status:'error'}`.
- `agg` solo admite mean, sum, count, min, max, median; cualquier otro valor devuelve
`{status:'error'}` sin tocar la base.
- Orden de `row_labels` / `col_labels`: por numero de observaciones descendente, con
desempate estable por etiqueta. No es orden alfabetico ni el de aparicion.
- La query se ejecuta con `sandbox=False` en `duckdb_query_readonly` (uso interno
confiable: el SQL lo construye esta funcion, no un cliente externo).
@@ -1,176 +0,0 @@
"""pivot_table_duckdb — pivot table (index x columns -> agg(value)) con push-down SQL.
Funcion impura: lee de disco a traves de DuckDB reusando la primitiva read-only del
grupo `duckdb` (`duckdb_query_readonly`). Pertenece al grupo de capacidad `eda`
(exploratory data analysis).
A diferencia de la funcion pura `pivot` (que pivota un `list[dict]` ya cargado en
memoria), esta version empuja la agregacion al motor de DuckDB (push-down): el
GROUP BY lo resuelve DuckDB y solo se traen los valores agregados, nunca las filas
crudas. Esto la hace apta para tablas grandes.
Ademas reduce el resultado a las `top_rows` filas y `top_cols` columnas con mas
observaciones, de modo que la pivot quepa entera en un PDF movil / slide PPTX sin
cortarse. Marca `truncated_rows`/`truncated_cols` cuando hubo recorte.
Estilo dict-no-throw del grupo duckdb: nunca lanza; captura cualquier error y
devuelve {status:'error', error:str}.
"""
from collections import defaultdict
from infra import duckdb_query_readonly
# Funciones de agregacion permitidas y su nombre en SQL DuckDB.
# mean -> avg; el resto mapea directo. count se trata aparte (COUNT(*), sin value).
_AGG_SQL = {
"mean": "avg",
"sum": "sum",
"count": "count",
"min": "min",
"max": "max",
"median": "median",
}
def _quote_ident(ident: str) -> str:
"""Cita un identificador SQL con dobles comillas, escapando `"` -> `""`.
DuckDB no admite parametros posicionales para nombres de tabla/columna, asi que
hay que interpolarlos. El quoting con `"` y el doblado de comillas internas evita
que un nombre rompa la sentencia (mismo patron que correlation_matrix_duckdb).
"""
return '"' + str(ident).replace('"', '""') + '"'
def pivot_table_duckdb(
db_path: str,
table: str,
index: str,
columns: str,
value: str,
agg: str = "mean",
top_rows: int = 10,
top_cols: int = 8,
) -> dict:
"""Pivot table push-down en DuckDB, recortada a top_rows x top_cols.
Construye una pivot (filas = valores de `index`, columnas = valores de `columns`,
celda = `agg(value)`) agregando en el motor de DuckDB, y la reduce a las filas y
columnas con mas observaciones para que quepa en un PDF / slide.
Args:
db_path: ruta al archivo DuckDB. Debe existir (read_only NO crea la base).
table: nombre de la tabla a pivotar.
index: columna cuyos valores forman las filas de la pivot.
columns: columna cuyos valores forman las columnas de la pivot.
value: columna numerica a agregar. Ignorada cuando agg="count".
agg: funcion de agregacion. Una de: "mean", "sum", "count", "min", "max",
"median". mean se traduce a avg(); count a COUNT(*).
top_rows: numero maximo de filas a conservar, elegidas por mayor numero de
observaciones (suma de COUNT(*) por valor de index). Default 10.
top_cols: numero maximo de columnas a conservar, elegidas por mayor numero de
observaciones (suma de COUNT(*) por valor de columns). Default 8.
Returns:
dict. En exito:
{status:'ok',
index, columns, value, agg,
row_labels:[...], # valores de index, en orden de freq desc
col_labels:[...], # valores de columns, en orden de freq desc
matrix:[[...], ...], # len == len(row_labels); cada fila
# len == len(col_labels); celda = agg o None
truncated_rows:bool, truncated_cols:bool,
note:str}
En error (sin lanzar): {status:'error', error:str}.
"""
try:
if not isinstance(agg, str) or agg not in _AGG_SQL:
return {
"status": "error",
"error": "invalid agg "
+ repr(agg)
+ "; allowed: "
+ ", ".join(sorted(_AGG_SQL)),
}
# Paso 1 (push-down): agregar (index, columns) -> agg(value) + COUNT(*).
if agg == "count":
agg_expr = "COUNT(*)"
else:
agg_expr = f"{_AGG_SQL[agg]}({_quote_ident(value)})"
sql = (
f"SELECT {_quote_ident(index)} AS r, "
f"{_quote_ident(columns)} AS c, "
f"{agg_expr} AS v, "
f"COUNT(*) AS n "
f"FROM {_quote_ident(table)} "
f"GROUP BY {_quote_ident(index)}, {_quote_ident(columns)}"
)
# max_rows alto: queremos todos los grupos (index x columns) para elegir el
# top con criterio global. sandbox=False igual que correlation_matrix_duckdb,
# porque db_path es una ruta interna de confianza.
result = duckdb_query_readonly(
db_path, sql, max_rows=1_000_000, sandbox=False
)
if result.get("status") != "ok":
return {
"status": "error",
"error": "pivot query failed: "
+ str(result.get("error", "unknown")),
}
# Paso 2 (en python): contar observaciones por fila y por columna, y guardar
# el valor agregado de cada celda (r, c).
row_obs: dict = defaultdict(int)
col_obs: dict = defaultdict(int)
cell: dict = {}
for row in result.get("rows", []):
r = row.get("r")
c = row.get("c")
n = row.get("n") or 0
row_obs[r] += n
col_obs[c] += n
cell[(r, c)] = row.get("v")
def _top(obs: dict, limit: int):
# Orden: mas observaciones primero; desempate estable por etiqueta.
ranked = sorted(obs.items(), key=lambda kv: (-kv[1], str(kv[0])))
selected = [label for label, _ in ranked[:limit]]
return selected, len(ranked) > limit
row_labels, truncated_rows = _top(row_obs, top_rows)
col_labels, truncated_cols = _top(col_obs, top_cols)
# Paso 3: materializar la matriz; None donde la combinacion no existe.
matrix = [
[cell.get((r, c)) for c in col_labels] for r in row_labels
]
note = (
f"pivot {agg}({value}) reducida a {len(row_labels)}x{len(col_labels)} "
"(top por observaciones) para caber en PDF/slide"
)
if agg == "count":
note = (
f"pivot count(*) reducida a {len(row_labels)}x{len(col_labels)} "
"(top por observaciones) para caber en PDF/slide"
)
return {
"status": "ok",
"index": index,
"columns": columns,
"value": value,
"agg": agg,
"row_labels": row_labels,
"col_labels": col_labels,
"matrix": matrix,
"truncated_rows": truncated_rows,
"truncated_cols": truncated_cols,
"note": note,
}
except Exception as e: # noqa: BLE001
return {"status": "error", "error": str(e)}
@@ -1,115 +0,0 @@
"""Tests para pivot_table_duckdb."""
import os
import sys
import duckdb
# Permitir importar funciones del registry (from infra import ..., from datascience import ...).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "functions"))
from datascience.pivot_table_duckdb import pivot_table_duckdb
def _make_db(tmp_name: str) -> str:
"""Crea una DuckDB con dos categoricas (a, b) y un valor numerico conocido.
Filas:
a='x', b='y', val=10
a='x', b='y', val=20 -> mean(x,y) = 15, count(x,y) = 2
a='x', b='z', val=5 -> mean(x,z) = 5
a='w', b='y', val=100 -> mean(w,y) = 100
Observaciones por a: x=3, w=1. Por b: y=3, z=1.
La combinacion (w, z) no existe -> celda None.
"""
db = os.path.join("/tmp", tmp_name)
if os.path.exists(db):
os.remove(db)
con = duckdb.connect(db)
con.execute("CREATE TABLE t (a VARCHAR, b VARCHAR, val DOUBLE)")
con.execute(
"INSERT INTO t VALUES "
"('x','y',10),('x','y',20),('x','z',5),('w','y',100)"
)
con.close()
return db
def test_pivot_mean_labels_y_celda_conocida():
db = _make_db("pivot_test_mean.duckdb")
res = pivot_table_duckdb(db, "t", index="a", columns="b", value="val", agg="mean")
assert res["status"] == "ok", res
# Filas ordenadas por observaciones desc: x (3) antes que w (1).
assert res["row_labels"] == ["x", "w"], res["row_labels"]
# Columnas ordenadas por observaciones desc: y (3) antes que z (1).
assert res["col_labels"] == ["y", "z"], res["col_labels"]
# matrix[0][0] = mean(a='x', b='y') = (10 + 20) / 2 = 15.
assert abs(res["matrix"][0][0] - 15.0) < 1e-9, res["matrix"]
# matrix[0][1] = mean(a='x', b='z') = 5.
assert abs(res["matrix"][0][1] - 5.0) < 1e-9, res["matrix"]
# matrix[1][0] = mean(a='w', b='y') = 100.
assert abs(res["matrix"][1][0] - 100.0) < 1e-9, res["matrix"]
# (w, z) no existe -> None.
assert res["matrix"][1][1] is None, res["matrix"]
# Sin truncado con los defaults (top_rows=10, top_cols=8).
assert res["truncated_rows"] is False
assert res["truncated_cols"] is False
# La matriz es rectangular consistente con las etiquetas.
assert len(res["matrix"]) == len(res["row_labels"])
for fila in res["matrix"]:
assert len(fila) == len(res["col_labels"])
def test_pivot_trunca_a_top_rows_y_top_cols():
db = _make_db("pivot_test_trunc.duckdb")
res = pivot_table_duckdb(
db, "t", index="a", columns="b", value="val", agg="mean",
top_rows=1, top_cols=1,
)
assert res["status"] == "ok", res
# Solo la fila/columna mas frecuente sobrevive.
assert res["row_labels"] == ["x"], res["row_labels"]
assert res["col_labels"] == ["y"], res["col_labels"]
assert res["matrix"] == [[15.0]], res["matrix"]
# Habia mas de 1 fila y mas de 1 columna -> truncado en ambos ejes.
assert res["truncated_rows"] is True
assert res["truncated_cols"] is True
def test_pivot_count_no_necesita_value_real():
db = _make_db("pivot_test_count.duckdb")
# value apunta a una columna real pero count(*) la ignora; tambien valdria un
# nombre cualquiera. Verificamos que count funciona igualmente.
res = pivot_table_duckdb(
db, "t", index="a", columns="b", value="val", agg="count"
)
assert res["status"] == "ok", res
assert res["row_labels"] == ["x", "w"]
assert res["col_labels"] == ["y", "z"]
# count(a='x', b='y') = 2 observaciones.
assert res["matrix"][0][0] == 2, res["matrix"]
# count(a='x', b='z') = 1.
assert res["matrix"][0][1] == 1, res["matrix"]
# count(a='w', b='y') = 1.
assert res["matrix"][1][0] == 1, res["matrix"]
# (w, z) no existe -> None.
assert res["matrix"][1][1] is None, res["matrix"]
def test_pivot_db_inexistente_devuelve_error_sin_lanzar():
res = pivot_table_duckdb(
"/nonexistent/path/does_not_exist.duckdb",
"t", index="a", columns="b", value="val", agg="mean",
)
assert res["status"] == "error", res
assert isinstance(res["error"], str)
def test_pivot_agg_invalido_devuelve_error():
db = _make_db("pivot_test_badagg.duckdb")
res = pivot_table_duckdb(
db, "t", index="a", columns="b", value="val", agg="stddev"
)
assert res["status"] == "error", res
assert "invalid agg" in res["error"]
@@ -1,158 +0,0 @@
---
id: select_groupby_keys_py_datascience
name: select_groupby_keys
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: pure
signature: "def select_groupby_keys(profile: dict, max_keys: int = 3, max_card: int = 20, max_measures: int = 4) -> dict"
description: "Elige deterministicamente las columnas categoricas mas interesantes para GROUP BY, las numericas medida y pares pivote a partir de un TableProfile del grupo eda. Respaldo cuantitativo para el capitulo de agregacion/OLAP de un EDA. Funcion pura, no muta el input, nunca lanza."
tags: [eda, aggregation, groupby, olap, profiling, datascience]
uses_functions: []
uses_types: []
returns: []
returns_optional: false
error_type: ""
imports: []
example: |
from datascience import select_groupby_keys
profile = {
"n_rows": 891,
"key_candidates": ["passenger_id"],
"columns": [
{"name": "sex", "inferred_type": "categorical", "distinct_count": 2,
"unique_pct": 0.002, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 1.8}, "numeric": None},
{"name": "pclass", "inferred_type": "categorical", "distinct_count": 3,
"unique_pct": 0.003, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 2.5}, "numeric": None},
{"name": "fare", "inferred_type": "numeric", "distinct_count": 200,
"unique_pct": 0.2, "null_pct": 0.0, "flags": [],
"numeric": {"std": 49.7, "cv": 1.54}, "categorical": None},
],
}
select_groupby_keys(profile)
# {"group_keys": [{"col": "sex", ...}, {"col": "pclass", ...}],
# "measures": ["fare"],
# "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}],
# "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s)."}
tested: true
tests:
- "test_titanic_picks_good_cats_excludes_id_and_constant"
- "test_titanic_measures_exclude_id_constant_and_keep_numerics"
- "test_titanic_generates_one_pivot"
- "test_empty_profile_returns_all_empty_and_does_not_crash"
- "test_none_profile_does_not_crash"
- "test_only_numerics_yields_empty_group_keys_and_no_pivots"
- "test_high_cardinality_and_max_card_are_excluded"
- "test_max_keys_limits_group_keys"
- "test_three_keys_cap_pivots_to_two"
- "test_does_not_mutate_input"
test_file_path: "python/functions/datascience/select_groupby_keys_test.py"
file_path: "python/functions/datascience/select_groupby_keys.py"
params:
- name: profile
desc: >
TableProfile dict del grupo eda (p.ej. salida de summarize_table_duckdb).
Se lee de forma defensiva (.get / or [] / isinstance). Claves usadas:
columns (list[ColumnProfile]), key_candidates (list de nombres de columna
o dicts {name}), n_rows. Cada ColumnProfile usa: name, inferred_type
("numeric"|"categorical"|"datetime"|"text"|"boolean"), distinct_count,
unique_pct (0..1), null_pct (0..1), flags (list[str], reconoce
"possible_id"/"high_cardinality"/"constant"), numeric ({std, cv, ...}|None)
y categorical ({imbalance, mode_pct, ...}|None).
- name: max_keys
desc: "Numero maximo de claves de grupo (group_keys) a devolver. Default 3."
- name: max_card
desc: >
Cardinalidad maxima (distinct_count) que una columna categorica puede
tener para seguir siendo candidata a clave de grupo. Default 20.
- name: max_measures
desc: "Numero maximo de columnas medida (nombres) a devolver. Default 4."
output: >
dict con group_keys (list de {col, cardinality, score} ordenada por score
desc), measures (list[str] de nombres de columnas numericas ordenadas por
dispersion), pivots (list de {index, columns, value}, hasta 2 pares
categorica x categorica con la primera measure como valor) y note (str,
resumen corto en espanol de lo elegido). Ante profile vacio/None devuelve
todas las listas vacias y una note descriptiva; nunca lanza.
---
## Ejemplo
```python
from datascience import select_groupby_keys
# TableProfile estilo titanic: 2 categoricas buenas, 1 numerica medida,
# 1 id secuencial (descartado) y un key_candidate declarado.
profile = {
"n_rows": 891,
"key_candidates": ["passenger_id"],
"columns": [
{"name": "sex", "inferred_type": "categorical", "distinct_count": 2,
"unique_pct": 0.002, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 1.8}, "numeric": None},
{"name": "pclass", "inferred_type": "categorical", "distinct_count": 3,
"unique_pct": 0.003, "null_pct": 0.0, "flags": [],
"categorical": {"imbalance": 2.5}, "numeric": None},
{"name": "fare", "inferred_type": "numeric", "distinct_count": 200,
"unique_pct": 0.2, "null_pct": 0.0, "flags": [],
"numeric": {"std": 49.7, "cv": 1.54}, "categorical": None},
{"name": "passenger_id", "inferred_type": "numeric", "distinct_count": 891,
"unique_pct": 1.0, "null_pct": 0.0, "flags": ["possible_id"],
"numeric": {"std": 257.4, "cv": 0.58}, "categorical": None},
],
}
select_groupby_keys(profile)
# {
# "group_keys": [
# {"col": "sex", "cardinality": 2, "score": 0.5556},
# {"col": "pclass", "cardinality": 3, "score": 0.4},
# ],
# "measures": ["fare"], # passenger_id excluido (id secuencial)
# "pivots": [{"index": "sex", "columns": "pclass", "value": "fare"}],
# "note": "2 clave(s) de grupo: sex, pclass; 1 medida(s): fare; 1 pivot(s).",
# }
```
## Cuando usarla
Cuando hayas perfilado una tabla con el grupo `eda` (p.ej.
`summarize_table_duckdb`) y necesites decidir, sin mirar los datos, por qué
columnas merece la pena agrupar (GROUP BY) y qué métricas numéricas agregar:
el respaldo cuantitativo del capítulo de agregación/OLAP de un AutomaticEDA, o
para proponer pivotes en un dashboard. Es la capa de selección sobre el
TableProfile crudo: lee el perfil, ordena candidatos de forma determinista y
no toca los datos.
## Notas
Función pura, sin I/O ni dependencias externas (solo stdlib), no muta
`profile`. Lectura defensiva total (`.get`, `or []`, `isinstance`): un `{}` o
`None` produce `{"group_keys": [], "measures": [], "pivots": [], "note": ...}`
y nunca lanza.
Criterios de selección (deterministas):
- **group_keys** — candidatas con `inferred_type` en `("categorical","boolean")`.
Se descartan las que estén en `key_candidates`, con flag
`possible_id`/`high_cardinality`/`constant`, con `distinct_count` fuera de
`[2, max_card]`, o all-null (`null_pct >= 0.999`). `score = card_score *
balance_score`: `card_score` mantiene un plateau para cardinalidad moderada
(2..12) y decae hacia `max_card`; `balance_score = 1/imbalance` usando
`categorical.imbalance` si está, aproximando con `mode_pct` si no, o un valor
neutro (0.5) en último caso. Devuelve hasta `max_keys`, ordenadas por score
desc (empates por orden de columna).
- **measures** — candidatas con `inferred_type` en
`("numeric","integer","float")`. Se descartan id-like (flag `possible_id` y
`unique_pct >= 0.99`) y constantes (`numeric.std` == 0 o None). Se rankean por
dispersión informativa: `abs(cv)` si está, si no `abs(std)`. Devuelve hasta
`max_measures` **nombres** (strings).
- **pivots** — hasta 2 pares `(group_keys[i].col, group_keys[j].col)` con i<j y
la primera measure como valor. Vacío si hay menos de 2 group_keys.
Caveat de ranking de measures: mezclar `cv` (adimensional) con `std` (en
unidades de la columna) cuando una columna carece de `cv` puede dar órdenes
poco comparables entre columnas; se prefiere `cv` siempre que esté disponible.
@@ -1,310 +0,0 @@
"""Pure EDA helper: pick GROUP BY keys and measures from a TableProfile.
Given a ``TableProfile`` of the ``eda`` group (the dict produced by, e.g.,
``summarize_table_duckdb``), this function deterministically selects the most
interesting categorical columns to group by (GROUP BY), the numeric measure
columns to aggregate, and a couple of categorical x categorical pivot pairs.
It is the quantitative backbone for the aggregation / OLAP chapter of an
AutomaticEDA: a pure, deterministic ranking over the profile, with no I/O, no
mutation of the input and no external dependencies (stdlib only). It never
raises a missing or malformed profile yields an empty, well-formed result.
"""
def select_groupby_keys(
profile: dict,
max_keys: int = 3,
max_card: int = 20,
max_measures: int = 4,
) -> dict:
"""Select GROUP BY keys, measures and pivot pairs from a TableProfile.
Reads everything defensively (``.get(...)``, ``or []``, ``isinstance``) and
never raises. With an empty/None profile it returns every list empty.
Selection rules (deterministic):
- **group_keys** (categorical columns to group by): candidates have
``inferred_type`` in ``("categorical", "boolean")``. Discarded if they are
in ``profile['key_candidates']``, carry a ``possible_id`` /
``high_cardinality`` / ``constant`` flag, have ``distinct_count`` outside
``[2, max_card]``, or are all-null (``null_pct >= 0.999``). Each survivor
gets ``score = card_score * balance_score`` where ``card_score`` keeps a
plateau for moderate cardinality (2..12) and decays towards ``max_card``,
and ``balance_score = 1 / imbalance`` (``categorical.imbalance`` if
present, else approximated from ``mode_pct``, else a neutral default).
The top ``max_keys`` by score (desc, ties by column order) are returned.
- **measures** (numeric columns to aggregate): candidates have
``inferred_type`` in ``("numeric", "integer", "float")``. Discarded if
id-like (``possible_id`` flag *and* ``unique_pct >= 0.99``) or constant
(``numeric.std`` is ``0`` or ``None``). Ranked by informative dispersion:
``abs(cv)`` when available, else ``abs(std)``. The top ``max_measures``
**names** are returned.
- **pivots**: up to 2 ``(group_keys[i].col, group_keys[j].col)`` pairs with
``i < j``, using the first measure as the aggregated value. Empty when
fewer than 2 group keys were selected.
Args:
profile: TableProfile dict of the ``eda`` group. Relevant keys:
``columns`` (list[ColumnProfile]), ``key_candidates`` (list of
column names or ``{name}`` dicts), ``n_rows``. Each ColumnProfile
uses: ``name``, ``inferred_type``, ``distinct_count``,
``unique_pct`` (0..1), ``null_pct`` (0..1), ``flags`` (list[str]),
``numeric`` ({std, cv, ...}|None), ``categorical``
({imbalance, mode_pct, ...}|None).
max_keys: Maximum number of group-by keys to return. Default 3.
max_card: Maximum cardinality (``distinct_count``) a categorical column
may have to still qualify as a group key. Default 20.
max_measures: Maximum number of measure names to return. Default 4.
Returns:
dict with:
group_keys (list[{col, cardinality, score}], ordered by score desc),
measures (list[str], numeric column names ordered by dispersion),
pivots (list[{index, columns, value}], up to 2 pairs),
note (str, short summary of what was chosen).
"""
if not isinstance(profile, dict):
profile = {}
try:
max_keys = int(max_keys)
except (TypeError, ValueError):
max_keys = 3
try:
max_card = int(max_card)
except (TypeError, ValueError):
max_card = 20
try:
max_measures = int(max_measures)
except (TypeError, ValueError):
max_measures = 4
max_keys = max(max_keys, 0)
max_card = max(max_card, 2)
max_measures = max(max_measures, 0)
columns = profile.get("columns") or []
if not isinstance(columns, (list, tuple)):
columns = []
key_names = _key_candidate_names(profile.get("key_candidates"))
group_keys = _select_group_keys(columns, key_names, max_keys, max_card)
measures = _select_measures(columns, max_measures)
pivots = _select_pivots(group_keys, measures)
return {
"group_keys": group_keys,
"measures": measures,
"pivots": pivots,
"note": _build_note(group_keys, measures, pivots),
}
# ---------------------------------------------------------------------------
# group_keys
# ---------------------------------------------------------------------------
_GROUP_TYPES = ("categorical", "boolean")
_DISQUALIFYING_FLAGS = frozenset({"possible_id", "high_cardinality", "constant"})
_CARD_PLATEAU_HI = 12 # cardinalities 2..12 are all "moderate" (best).
def _select_group_keys(columns, key_names, max_keys, max_card) -> list:
"""Rank categorical/boolean columns suitable for GROUP BY."""
scored = []
for idx, col in enumerate(columns):
if not isinstance(col, dict):
continue
if (col.get("inferred_type") or "") not in _GROUP_TYPES:
continue
name = col.get("name")
if name is None:
continue
if name in key_names:
continue
flags = _as_set(col.get("flags"))
if flags & _DISQUALIFYING_FLAGS:
continue
if _num(col.get("null_pct"), 0.0) >= 0.999:
continue
card = _num(col.get("distinct_count"), 0.0)
if card < 2 or card > max_card:
continue
card_i = int(card)
score = _card_score(card_i, max_card) * _balance_score(col.get("categorical"))
scored.append((round(score, 6), idx, name, card_i))
# Deterministic: higher score first, ties broken by original column order.
scored.sort(key=lambda t: (-t[0], t[1]))
out = []
for score, _idx, name, card_i in scored[:max_keys]:
out.append({"col": name, "cardinality": card_i, "score": score})
return out
def _card_score(card: int, max_card: int) -> float:
"""Prefer moderate cardinality; plateau at 2..12, decay towards max_card."""
if card <= 1:
return 0.0
if card <= _CARD_PLATEAU_HI:
return 1.0
denom = max(max_card - _CARD_PLATEAU_HI, 1)
over = card - _CARD_PLATEAU_HI
return max(0.1, 1.0 - over / denom)
def _balance_score(categorical) -> float:
"""1.0 for a perfectly balanced category, decaying as imbalance grows.
Uses ``categorical.imbalance`` (max_count/min_count, >= 1) when available;
otherwise approximates from ``mode_pct`` (top-class dominance); otherwise a
neutral default so the column is still selectable.
"""
if isinstance(categorical, dict):
imbalance = categorical.get("imbalance")
if isinstance(imbalance, (int, float)) and imbalance >= 1.0:
return 1.0 / float(imbalance)
mode_pct = categorical.get("mode_pct")
if isinstance(mode_pct, (int, float)):
return _clamp(1.0 - float(mode_pct), 0.0, 1.0)
return 0.5
# ---------------------------------------------------------------------------
# measures
# ---------------------------------------------------------------------------
_NUMERIC_TYPES = ("numeric", "integer", "float")
def _select_measures(columns, max_measures) -> list:
"""Rank numeric columns by informative dispersion (cv, else std)."""
scored = []
for idx, col in enumerate(columns):
if not isinstance(col, dict):
continue
if (col.get("inferred_type") or "") not in _NUMERIC_TYPES:
continue
name = col.get("name")
if name is None:
continue
flags = _as_set(col.get("flags"))
unique_pct = _num(col.get("unique_pct"), 0.0)
if "possible_id" in flags and unique_pct >= 0.99:
continue # sequential id, not a measure.
numeric = col.get("numeric")
std = numeric.get("std") if isinstance(numeric, dict) else None
if not isinstance(std, (int, float)) or std == 0:
continue # constant or unknown spread -> not informative.
cv = numeric.get("cv") if isinstance(numeric, dict) else None
if isinstance(cv, (int, float)):
dispersion = abs(float(cv))
else:
dispersion = abs(float(std))
scored.append((dispersion, idx, name))
# Higher dispersion first, ties broken by original column order.
scored.sort(key=lambda t: (-t[0], t[1]))
return [name for _disp, _idx, name in scored[:max_measures]]
# ---------------------------------------------------------------------------
# pivots
# ---------------------------------------------------------------------------
def _select_pivots(group_keys, measures) -> list:
"""Up to 2 (cat_a, cat_b) pairs from the chosen group keys."""
if not isinstance(group_keys, list) or len(group_keys) < 2:
return []
value = measures[0] if measures else None
pairs = []
n = len(group_keys)
for i in range(n):
for j in range(i + 1, n):
pairs.append({
"index": group_keys[i].get("col"),
"columns": group_keys[j].get("col"),
"value": value,
})
if len(pairs) >= 2:
return pairs
return pairs
# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------
def _build_note(group_keys, measures, pivots) -> str:
"""One-line Spanish summary of the selection."""
parts = []
if group_keys:
cols = ", ".join(str(g.get("col")) for g in group_keys)
parts.append(f"{len(group_keys)} clave(s) de grupo: {cols}")
else:
parts.append("sin categóricas agrupables")
if measures:
parts.append(f"{len(measures)} medida(s): " + ", ".join(str(m) for m in measures))
else:
parts.append("sin medidas numéricas")
if pivots:
parts.append(f"{len(pivots)} pivot(s)")
return "; ".join(parts) + "."
def _key_candidate_names(key_candidates) -> set:
"""Normalize ``key_candidates`` (strings or ``{name}`` dicts) to a name set."""
names = set()
if not isinstance(key_candidates, (list, tuple)):
return names
for entry in key_candidates:
if isinstance(entry, str):
names.add(entry)
elif isinstance(entry, dict):
nm = entry.get("name") or entry.get("col")
if nm is not None:
names.add(nm)
return names
def _as_set(flags) -> set:
"""Coerce a flags value into a set, tolerating None / non-iterables."""
if isinstance(flags, (list, tuple, set)):
return set(flags)
return set()
def _num(value, default: float) -> float:
"""Best-effort float conversion with a fallback default."""
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def _clamp(x: float, lo: float, hi: float) -> float:
"""Recorta x al rango [lo, hi]."""
if x < lo:
return lo
if x > hi:
return hi
return x
@@ -1,213 +0,0 @@
"""Tests para select_groupby_keys (grupo eda, dominio datascience)."""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from select_groupby_keys import select_groupby_keys
def _cat_col(name, card, *, imbalance=2.0, flags=None, null_pct=0.0):
"""ColumnProfile categorico minimo con bloque categorical."""
return {
"name": name,
"inferred_type": "categorical",
"distinct_count": card,
"unique_pct": card / 1000.0,
"null_pct": null_pct,
"flags": flags or [],
"numeric": None,
"categorical": {"imbalance": imbalance, "mode_pct": 0.5, "n_distinct": card},
}
def _num_col(name, *, std, cv, flags=None, unique_pct=0.1):
"""ColumnProfile numerico minimo con bloque numeric."""
return {
"name": name,
"inferred_type": "numeric",
"distinct_count": 200,
"unique_pct": unique_pct,
"null_pct": 0.0,
"flags": flags or [],
"numeric": {"std": std, "cv": cv},
"categorical": None,
}
def _titanic_like_profile() -> dict:
"""Perfil estilo titanic: 2 categoricas buenas, 2 numericas, 1 id, 1 constante."""
return {
"n_rows": 891,
"key_candidates": ["passenger_id"],
"columns": [
_cat_col("sex", 2, imbalance=1.8),
_cat_col("pclass", 3, imbalance=2.5),
_num_col("age", std=14.5, cv=0.49),
_num_col("fare", std=49.7, cv=1.54),
# id secuencial: flag possible_id + unique_pct alto.
{
"name": "passenger_id",
"inferred_type": "numeric",
"distinct_count": 891,
"unique_pct": 1.0,
"null_pct": 0.0,
"flags": ["possible_id"],
"numeric": {"std": 257.4, "cv": 0.58},
"categorical": None,
},
# columna constante: flag constant + std 0.
{
"name": "embarked_const",
"inferred_type": "categorical",
"distinct_count": 1,
"unique_pct": 0.001,
"null_pct": 0.0,
"flags": ["constant"],
"numeric": None,
"categorical": {"imbalance": 1.0},
},
],
}
def test_titanic_picks_good_cats_excludes_id_and_constant():
out = select_groupby_keys(_titanic_like_profile())
# Elige las dos categoricas buenas.
chosen_cols = {g["col"] for g in out["group_keys"]}
assert chosen_cols == {"sex", "pclass"}
# Excluye la constante y el key_candidate.
assert "embarked_const" not in chosen_cols
assert "passenger_id" not in chosen_cols
# Cada group key trae col, cardinality y score.
for g in out["group_keys"]:
assert set(g.keys()) == {"col", "cardinality", "score"}
assert isinstance(g["score"], float)
by_col = {g["col"]: g for g in out["group_keys"]}
assert by_col["sex"]["cardinality"] == 2
assert by_col["pclass"]["cardinality"] == 3
# Ordenadas por score descendente.
scores = [g["score"] for g in out["group_keys"]]
assert scores == sorted(scores, reverse=True)
def test_titanic_measures_exclude_id_constant_and_keep_numerics():
out = select_groupby_keys(_titanic_like_profile())
# Solo nombres (strings) de numericas informativas, sin el id secuencial.
assert all(isinstance(m, str) for m in out["measures"])
assert "passenger_id" not in out["measures"]
assert set(out["measures"]) == {"age", "fare"}
# fare tiene mayor cv (1.54 > 0.49) -> primero.
assert out["measures"][0] == "fare"
def test_titanic_generates_one_pivot():
out = select_groupby_keys(_titanic_like_profile())
# Con 2 group keys -> exactamente 1 pivot.
assert len(out["pivots"]) == 1
pivot = out["pivots"][0]
assert set(pivot.keys()) == {"index", "columns", "value"}
assert {pivot["index"], pivot["columns"]} == {"sex", "pclass"}
# El valor es la primera measure (fare).
assert pivot["value"] == "fare"
def test_empty_profile_returns_all_empty_and_does_not_crash():
out = select_groupby_keys({})
assert out["group_keys"] == []
assert out["measures"] == []
assert out["pivots"] == []
assert isinstance(out["note"], str)
def test_none_profile_does_not_crash():
out = select_groupby_keys(None)
assert out == {
"group_keys": [],
"measures": [],
"pivots": [],
"note": out["note"],
}
assert isinstance(out["note"], str)
def test_only_numerics_yields_empty_group_keys_and_no_pivots():
profile = {
"n_rows": 500,
"key_candidates": [],
"columns": [
_num_col("price", std=12.0, cv=0.6),
_num_col("weight", std=3.0, cv=0.2),
],
}
out = select_groupby_keys(profile)
assert out["group_keys"] == []
assert out["pivots"] == []
# Las numericas si se eligen como measures.
assert set(out["measures"]) == {"price", "weight"}
assert out["measures"][0] == "price" # mayor cv.
def test_high_cardinality_and_max_card_are_excluded():
profile = {
"n_rows": 1000,
"key_candidates": [],
"columns": [
_cat_col("city", 50, flags=["high_cardinality"]), # flag -> fuera.
_cat_col("zone", 35), # card 35 > max_card 20 -> fuera.
_cat_col("region", 5), # valida.
],
}
out = select_groupby_keys(profile, max_card=20)
assert {g["col"] for g in out["group_keys"]} == {"region"}
def test_max_keys_limits_group_keys():
profile = {
"n_rows": 1000,
"key_candidates": [],
"columns": [
_cat_col("a", 4, imbalance=1.0),
_cat_col("b", 5, imbalance=1.2),
_cat_col("c", 6, imbalance=1.5),
_cat_col("d", 7, imbalance=2.0),
],
}
out = select_groupby_keys(profile, max_keys=2)
assert len(out["group_keys"]) == 2
# Hasta 2 pivots con >=2 keys (aqui exactamente 1 par posible entre 2 keys).
assert len(out["pivots"]) == 1
def test_three_keys_cap_pivots_to_two():
profile = {
"n_rows": 1000,
"key_candidates": [],
"columns": [
_cat_col("a", 4, imbalance=1.0),
_cat_col("b", 5, imbalance=1.1),
_cat_col("c", 6, imbalance=1.2),
_num_col("m", std=10.0, cv=0.5),
],
}
out = select_groupby_keys(profile, max_keys=3)
assert len(out["group_keys"]) == 3
# 3 keys -> 3 pares posibles, capado a 2.
assert len(out["pivots"]) == 2
for p in out["pivots"]:
assert p["value"] == "m"
def test_does_not_mutate_input():
profile = _titanic_like_profile()
before = repr(profile)
select_groupby_keys(profile)
assert repr(profile) == before
@@ -1,96 +0,0 @@
---
name: suggest_aggregations_llm
kind: function
lang: py
domain: datascience
version: "1.0.0"
purity: impure
signature: "def suggest_aggregations_llm(profile: dict, candidates: dict, max_aggs: int = 4, model: str = \"claude-haiku-4-5-20251001\") -> dict"
description: "MUST-11.1 del capitulo AGREGACION del AutomaticEDA (grupo eda). Dado el TableProfile de una tabla y los candidatos cuantitativos de select_groupby_keys ({group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}), con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una razon corta cada uno, evitando la explosion combinatoria (no todo contra todo). Privacidad/coste: NO envia filas crudas, solo el resumen AGREGADO de los candidatos (tabla, columnas categoricas con cardinalidad/score, medidas, pivots). Reusa ask_llm del grupo claude-direct (API directa con token OAuth de Claude). Impura, dict-no-throw: NUNCA lanza y SIEMPRE devuelve algo usable; si el LLM falla, el JSON no parsea o no hay seleccion valida, cae a un fallback determinista construido desde los candidatos (source='fallback'). Toda columna que el LLM invente se descarta."
tags: [eda, claude-direct, llm, aggregation, groupby, pivot, datascience, automatic-eda]
params:
- name: profile
desc: "TableProfile del grupo eda. Solo se usa profile['table'] para nombrar la tabla en el prompt; puede ir vacio o sin esa clave (se usa '(tabla sin nombre)')."
- name: candidates
desc: "Salida de select_groupby_keys: {group_keys:[{col, cardinality, score}], measures:[str], pivots:[{index, columns, value}]}. group_keys = columnas categoricas candidatas para GROUP BY; measures = columnas numericas a agregar (sum/avg); pivots = cruces index x columns -> value sugeridos. Cualquier columna que el LLM elija debe existir aqui o se descarta. None o no-dict se trata como vacio."
- name: max_aggs
desc: "Tope de agregaciones a devolver. Default 4. Valores <1 o no-int se normalizan a 4. Limita tanto la seleccion del LLM como el fallback determinista, para evitar la explosion combinatoria."
- name: model
desc: "id del modelo Anthropic a usar en la unica llamada. Default 'claude-haiku-4-5-20251001' (haiku, coste bajo, ~2-3s). Para razones mas finas, pasar p.ej. 'claude-opus-4-8'."
output: "dict dict-no-throw: {status:'ok', source:'llm'|'fallback', aggregations:[{group_by:str, measures:[str], why:str}], pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}. source=='llm' si el LLM produjo al menos una agregacion valida (columnas existentes en candidates); en cualquier otro caso (LLM caido, JSON invalido, seleccion vacia, sin candidatos) source=='fallback' y aggregations/pivots se derivan de candidates con why='selección cuantitativa (sin LLM)'. NUNCA lanza."
uses_functions: [ask_llm_py_core, select_groupby_keys_py_datascience]
uses_types: []
returns: []
returns_optional: false
error_type: "error_go_core"
imports: []
tested: true
tests: ["test_extract_json_object", "test_extract_json_wrapped_in_fences_and_junk", "test_extract_json_non_json_returns_none", "test_validate_aggregations_drops_invalid_columns", "test_llm_path_uses_selection", "test_llm_path_respects_max_aggs", "test_llm_invented_column_is_discarded", "test_fallback_on_empty_llm_response", "test_fallback_on_unparseable_response", "test_fallback_respects_max_aggs", "test_fallback_when_llm_raises", "test_no_candidates_returns_empty_fallback", "test_non_dict_candidates_does_not_raise"]
test_file_path: "python/functions/datascience/suggest_aggregations_llm_test.py"
file_path: "python/functions/datascience/suggest_aggregations_llm.py"
---
## Ejemplo
```python
import sys, os
sys.path.insert(0, os.path.join("python", "functions"))
from datascience.suggest_aggregations_llm import suggest_aggregations_llm
profile = {"table": "ventas"}
# candidates = salida de select_groupby_keys (aqui literal de ejemplo).
candidates = {
"group_keys": [
{"col": "categoria", "cardinality": 8, "score": 0.91},
{"col": "region", "cardinality": 5, "score": 0.74},
{"col": "canal", "cardinality": 3, "score": 0.60},
],
"measures": ["importe", "unidades"],
"pivots": [
{"index": "categoria", "columns": "region", "value": "importe"},
],
}
out = suggest_aggregations_llm(profile, candidates, max_aggs=4) # haiku por defecto
print("fuente:", out["source"]) # "llm" o "fallback" si no hay red
for agg in out["aggregations"]:
print(f"GROUP BY {agg['group_by']} -> {agg['measures']} ({agg['why']})")
for piv in out["pivots"]:
print(f"pivot {piv['index']} x {piv['columns']} = {piv['value']} ({piv['why']})")
```
## Cuando usarla
Justo despues de `select_groupby_keys` en el capitulo AGREGACION del AutomaticEDA:
cuando ya tienes los candidatos cuantitativos (columnas categoricas con cardinalidad,
medidas numericas y pivots posibles) y quieres que un LLM se quede con las K
agregaciones y pivots MAS INFORMATIVOS en vez de generar "todo contra todo". Usala para
priorizar el plan de analisis de grupos antes de materializar las tablas con
`aggregate_by_group` / pivots, manteniendo el coste y el ruido bajos. Si no hay red o
credenciales, sigue funcionando con un fallback determinista, asi que es seguro
ponerla en un pipeline.
## Gotchas
- **Impura: hace 1 llamada de red al LLM.** No es determinista ni gratis. Latencia
tipica ~2-3s con haiku. Una sola llamada cubre toda la seleccion.
- **Requiere token OAuth de Claude** en `~/.claude/.credentials.json` (via `ask_llm` /
grupo `claude-direct`). Sin token / sin red NO lanza: cae al **fallback
determinista** (`source="fallback"`) construido desde `candidates`
(group_keys x measures hasta `max_aggs`, pivots tal cual) con
`why="selección cuantitativa (sin LLM)"`. Comprueba `out["source"]` para saber si la
seleccion vino del LLM o del fallback.
- **NO envia filas crudas al LLM**, solo el resumen AGREGADO de los candidatos. Esto
exige que `candidates` venga ya calculado por `select_groupby_keys` (cardinalidades,
scores, medidas, pivots).
- **Valida columnas inventadas**: si el LLM propone un `group_by`/`measure`/`index`/
`columns` que no esta en `candidates`, esa entrada se descarta (las medidas se
recortan a las validas). Si tras validar no queda ninguna agregacion, cae al
fallback completo.
- **`max_aggs` acota la explosion combinatoria** tanto en el camino LLM como en el
fallback. Subirlo demasiado reintroduce el ruido que esta funcion evita.
- **Modelo `haiku` por defecto** para coste bajo; sube a `claude-opus-4-8` si necesitas
razones (`why`) mas finas (mas caro y lento).
@@ -1,405 +0,0 @@
"""suggest_aggregations_llm — el LLM elige las agregaciones mas informativas (grupo `eda`).
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Dado el `TableProfile` de una
tabla y los CANDIDATOS cuantitativos que produce `select_groupby_keys`
(`{group_keys:[{col,cardinality,score}], measures:[str], pivots:[{index,columns,value}]}`),
con UNA sola llamada al LLM elige y ordena las K agregaciones (GROUP BY categorica x
medidas numericas) y los pivots MAS INFORMATIVOS para un analisis de grupos, con una
razon corta cada uno. El objetivo es evitar la explosion combinatoria: en vez de
"todo contra todo", el LLM se queda con lo que mas informa.
Privacidad y coste: NO se envian filas crudas al LLM. El prompt solo lleva el resumen
AGREGADO de los candidatos (nombre de la tabla, columnas categoricas con su
cardinalidad/score, medidas y pivots posibles). Una sola llamada barata.
Reusa `ask_llm` del registry (grupo claude-direct, API directa con el token OAuth de
Claude en ~/.claude/.credentials.json, arranque 0). Impura: una llamada de red.
Estilo dict-no-throw con FALLBACK DETERMINISTA: la funcion NUNCA lanza y SIEMPRE
devuelve algo usable. Si `ask_llm` falla (devuelve ""), el JSON no parsea, o el LLM no
produce ninguna seleccion valida, se construye la respuesta directamente desde los
candidatos (group_keys x measures hasta max_aggs, pivots tal cual) con
`source="fallback"`. Ademas, toda columna que el LLM invente (no presente en los
candidatos) se descarta.
"""
import json
from core.ask_llm import ask_llm
_SYSTEM = (
"Eres un analista de datos conciso. Te dan los CANDIDATOS AGREGADOS de una tabla "
"(columnas categoricas para GROUP BY con su cardinalidad, medidas numericas y "
"pivots posibles) y eliges las agregaciones y pivots MAS INFORMATIVOS para "
"entender los grupos, evitando la explosion combinatoria (no todo contra todo). "
"No recibes filas crudas. Responde en espanol. Responde SIEMPRE y SOLO con un "
"unico objeto JSON valido, sin texto alrededor ni fences de markdown, con la forma "
'{"aggregations": [{"group_by": "<col categorica>", "measures": ["<medida>", ...], '
'"why": "<razon corta>"}], "pivots": [{"index": "<col>", "columns": "<col>", '
'"value": "<medida o null>", "why": "<razon corta>"}]}. Usa SOLO nombres de columna '
"que aparezcan en los candidatos; no inventes nombres."
)
def _fmt_num(value) -> str:
"""Formatea un numero de forma compacta para el prompt (None -> '?')."""
if value is None:
return "?"
if isinstance(value, bool):
return str(value)
if isinstance(value, float):
if value == int(value):
return str(int(value))
return f"{value:.4g}"
return str(value)
def _candidate_view(candidates: dict):
"""Extrae las vistas utiles de los candidatos. Funcion interna PURA.
Devuelve la tupla (group_cols, measures, measure_set, pivots, group_keys):
- group_cols: set de nombres de columna categorica validas (de group_keys[].col).
- measures: lista de medidas numericas (str) preservando orden.
- measure_set: set de las medidas para validar pertenencia rapido.
- pivots: lista de pivots candidatos (dicts) tal cual vienen.
- group_keys: lista de dicts {col, cardinality, score} ya filtrada a entradas validas.
Tolera estructuras incompletas o de tipo incorrecto sin lanzar.
"""
candidates = candidates if isinstance(candidates, dict) else {}
gk_raw = candidates.get("group_keys")
group_keys = []
if isinstance(gk_raw, list):
for gk in gk_raw:
if isinstance(gk, dict) and isinstance(gk.get("col"), str):
group_keys.append(gk)
group_cols = {gk["col"] for gk in group_keys}
m_raw = candidates.get("measures")
measures = [m for m in m_raw if isinstance(m, str)] if isinstance(m_raw, list) else []
measure_set = set(measures)
p_raw = candidates.get("pivots")
pivots = p_raw if isinstance(p_raw, list) else []
return group_cols, measures, measure_set, pivots, group_keys
def _sorted_group_cols(group_keys: list) -> list:
"""Nombres de columna categorica ordenados por score descendente. PURA."""
def _score(gk):
s = gk.get("score")
if isinstance(s, (int, float)) and not isinstance(s, bool):
return s
return 0.0
return [gk["col"] for gk in sorted(group_keys, key=_score, reverse=True)]
def _build_prompt(profile: dict, candidates: dict, max_aggs: int) -> str:
"""Construye el prompt compacto SOLO con agregados. Funcion interna PURA.
No toca red ni disco: testeable sin credenciales. Incluye el nombre de la tabla,
las columnas categoricas candidatas (con cardinalidad y score), las medidas
numericas y los pivots candidatos. Nunca filas crudas.
Args:
profile: TableProfile (se usa solo profile['table'] para nombrar la tabla).
candidates: salida de select_groupby_keys.
max_aggs: tope de agregaciones a pedir.
Returns:
El texto del prompt.
"""
profile = profile if isinstance(profile, dict) else {}
candidates = candidates if isinstance(candidates, dict) else {}
table = profile.get("table")
table = str(table) if table is not None else "(tabla sin nombre)"
lines = [
f"Tabla: {table}",
(
"Tarea: elegir las agregaciones (GROUP BY categorica x medidas numericas) y "
"los pivots MAS INFORMATIVOS para un analisis de grupos. Evita la explosion "
"combinatoria: NO combines todo contra todo, prioriza lo que mas informa."
),
f"Devuelve a lo sumo {max_aggs} agregaciones.",
"",
"Columnas categoricas candidatas para GROUP BY (col: cardinalidad, score):",
]
group_keys = candidates.get("group_keys") or []
for gk in group_keys:
if not isinstance(gk, dict) or not isinstance(gk.get("col"), str):
continue
lines.append(
f" - {gk['col']}: cardinalidad={_fmt_num(gk.get('cardinality'))}, "
f"score={_fmt_num(gk.get('score'))}"
)
measures = candidates.get("measures") or []
lines.append("")
lines.append("Medidas numericas disponibles (para sum/avg por grupo):")
lines.append(" " + ", ".join(str(m) for m in measures if isinstance(m, str)))
pivots = candidates.get("pivots") or []
if pivots:
lines.append("")
lines.append("Pivots candidatos (index x columns -> value):")
for p in pivots:
if not isinstance(p, dict):
continue
lines.append(
f" - index={p.get('index')}, columns={p.get('columns')}, "
f"value={p.get('value')}"
)
lines.append("")
lines.append(
"Usa SOLO columnas de las listas anteriores; no inventes nombres. Responde "
"SOLO con el JSON descrito en las instrucciones del sistema."
)
return "\n".join(lines)
def _extract_json(text: str):
"""Extrae el primer bloque JSON (objeto o array) de la respuesta. PURA.
Localiza el bloque que empieza antes (el primer '{' o el primer '[') y, para ese
delimitador, hace json.loads del rango hasta su ultimo cierre. Tolera texto basura
alrededor y fences ```json. NUNCA lanza: ante cualquier fallo devuelve None.
Args:
text: respuesta cruda del LLM.
Returns:
El objeto/lista deserializado, o None si no se pudo parsear.
"""
if not text or not isinstance(text, str):
return None
opens = []
i_obj = text.find("{")
if i_obj != -1:
opens.append((i_obj, "{", "}"))
i_arr = text.find("[")
if i_arr != -1:
opens.append((i_arr, "[", "]"))
opens.sort()
for _, open_c, close_c in opens:
start = text.find(open_c)
end = text.rfind(close_c)
if start != -1 and end != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except (ValueError, TypeError):
continue
return None
def _validate_aggregations(raw_aggs, group_cols: set, measure_set: set, max_aggs: int) -> list:
"""Filtra las agregaciones del LLM a las que usan SOLO columnas candidatas. PURA.
Descarta cualquier agregacion cuyo group_by no este en group_cols o que no tenga
al menos una medida valida. Recorta las medidas a las presentes en measure_set.
Limita el resultado a max_aggs entradas.
"""
out = []
if not isinstance(raw_aggs, list):
return out
for item in raw_aggs:
if not isinstance(item, dict):
continue
gb = item.get("group_by")
if not isinstance(gb, str) or gb not in group_cols:
continue # columna inventada -> se descarta
raw_measures = item.get("measures")
if isinstance(raw_measures, str):
raw_measures = [raw_measures]
if not isinstance(raw_measures, list):
continue
measures = [m for m in raw_measures if isinstance(m, str) and m in measure_set]
if not measures:
continue # sin medidas validas -> agregacion inutil
why = item.get("why")
why = str(why) if why is not None else ""
out.append({"group_by": gb, "measures": measures, "why": why})
if len(out) >= max_aggs:
break
return out
def _validate_pivots(raw_pivots, group_cols: set, measure_set: set) -> list:
"""Filtra los pivots del LLM a los que usan SOLO columnas candidatas. PURA.
Descarta el pivot si index o columns no son columnas categoricas validas. Si el
value no es una medida valida, lo deja en None (un pivot de conteo sigue siendo util).
"""
out = []
if not isinstance(raw_pivots, list):
return out
for item in raw_pivots:
if not isinstance(item, dict):
continue
idx = item.get("index")
cols = item.get("columns")
if not (isinstance(idx, str) and idx in group_cols):
continue
if not (isinstance(cols, str) and cols in group_cols):
continue
val = item.get("value")
if not (isinstance(val, str) and val in measure_set):
val = None
why = item.get("why")
why = str(why) if why is not None else ""
out.append({"index": idx, "columns": cols, "value": val, "why": why})
return out
def _fallback_aggregations(group_cols_sorted: list, measures: list, max_aggs: int) -> list:
"""Agregaciones deterministas: cada columna categorica x todas las medidas. PURA."""
out = []
for col in group_cols_sorted:
out.append(
{
"group_by": col,
"measures": list(measures),
"why": "selección cuantitativa (sin LLM)",
}
)
if len(out) >= max_aggs:
break
return out
def _fallback_pivots(cand_pivots: list) -> list:
"""Normaliza los pivots candidatos a la forma de salida (tal cual + why). PURA."""
out = []
if not isinstance(cand_pivots, list):
return out
for p in cand_pivots:
if not isinstance(p, dict):
continue
idx = p.get("index")
cols = p.get("columns")
if not (isinstance(idx, str) and isinstance(cols, str)):
continue
val = p.get("value")
if not isinstance(val, str):
val = None
out.append(
{
"index": idx,
"columns": cols,
"value": val,
"why": "selección cuantitativa (sin LLM)",
}
)
return out
def suggest_aggregations_llm(
profile: dict,
candidates: dict,
max_aggs: int = 4,
model: str = "claude-haiku-4-5-20251001",
) -> dict:
"""Elige las agregaciones y pivots mas informativos con UNA llamada al LLM.
MUST-11.1 del capitulo AGREGACION del AutomaticEDA. Toma el perfil de la tabla y
los candidatos cuantitativos (salida de select_groupby_keys) y deja que el LLM
seleccione/ordene las K agregaciones (GROUP BY categorica x medidas) y los pivots
mas utiles, con una razon corta cada uno, evitando la explosion combinatoria.
Privacidad/coste: solo viaja al LLM el resumen AGREGADO de los candidatos, nunca
filas crudas. Una sola llamada barata.
dict-no-throw con fallback determinista: NUNCA lanza. Si el LLM falla, el JSON no
parsea, o no produce seleccion valida -> construye la respuesta desde los candidatos
(group_keys x measures hasta max_aggs, pivots tal cual) con source="fallback". Las
columnas que el LLM invente (no presentes en los candidatos) se descartan.
Args:
profile: TableProfile del grupo eda. Solo se usa profile['table'] para nombrar
la tabla en el prompt; puede ir vacio.
candidates: salida de select_groupby_keys, con la forma
{group_keys:[{col,cardinality,score}], measures:[str],
pivots:[{index,columns,value}]}.
max_aggs: tope de agregaciones a devolver. Default 4. Valores <1 o no-int se
normalizan a 4.
model: id del modelo Anthropic. Default 'claude-haiku-4-5-20251001' (haiku,
coste bajo, ~2-3s).
Returns:
dict {status:"ok", source:"llm"|"fallback",
aggregations:[{group_by:str, measures:[str], why:str}],
pivots:[{index:str, columns:str, value:str|None, why:str}], note:str}.
source=="llm" si el LLM produjo al menos una agregacion valida; en cualquier
otro caso "fallback". NUNCA lanza.
"""
if not isinstance(candidates, dict):
candidates = {}
if isinstance(max_aggs, bool) or not isinstance(max_aggs, int) or max_aggs < 1:
max_aggs = 4
group_cols, measures, measure_set, cand_pivots, group_keys = _candidate_view(candidates)
group_cols_sorted = _sorted_group_cols(group_keys)
# Sin material suficiente para agregar: no merece la pena llamar al LLM.
if not group_cols or not measures:
return {
"status": "ok",
"source": "fallback",
"aggregations": [],
"pivots": _fallback_pivots(cand_pivots),
"note": "sin candidatos suficientes para agregar",
}
prompt = _build_prompt(profile, candidates, max_aggs)
try:
text = ask_llm(prompt, model=model, system=_SYSTEM, echo=False)
except Exception: # noqa: BLE001 — degradacion: cualquier fallo de red/LLM.
text = ""
parsed = _extract_json(text)
if parsed is not None:
if isinstance(parsed, dict):
raw_aggs = parsed.get("aggregations")
raw_pivots = parsed.get("pivots")
elif isinstance(parsed, list):
raw_aggs = parsed
raw_pivots = None
else:
raw_aggs = None
raw_pivots = None
aggs = _validate_aggregations(raw_aggs, group_cols, measure_set, max_aggs)
if aggs:
pivots = _validate_pivots(raw_pivots, group_cols, measure_set)
if not pivots:
pivots = _fallback_pivots(cand_pivots)
return {
"status": "ok",
"source": "llm",
"aggregations": aggs,
"pivots": pivots,
"note": f"{len(aggs)} agregaciones y {len(pivots)} pivots seleccionados por el LLM",
}
# Fallback determinista.
note = (
"LLM no disponible; selección cuantitativa determinista"
if not text
else "LLM sin selección válida; selección cuantitativa determinista"
)
return {
"status": "ok",
"source": "fallback",
"aggregations": _fallback_aggregations(group_cols_sorted, measures, max_aggs),
"pivots": _fallback_pivots(cand_pivots),
"note": note,
}
@@ -1,198 +0,0 @@
"""Tests para suggest_aggregations_llm.
NO acceden a red ni a credenciales: las funciones internas (_build_prompt,
_extract_json, _validate_*, _fallback_*) son puras y testeables aisladas; la unica
via que llamaria al LLM (suggest_aggregations_llm) se prueba reemplazando el simbolo
`ask_llm` del modulo bajo prueba con una funcion simulada. Los candidatos van
literales en el test: NO se importa select_groupby_keys.
Cubre golden (LLM ok con columnas validas), edge (max_aggs respetado, sin candidatos)
y error (LLM caido -> fallback, JSON invalido -> fallback, columna inventada -> se
descarta). Todos sin tocar la red.
"""
import json
import datascience.suggest_aggregations_llm as M
from datascience.suggest_aggregations_llm import (
_extract_json,
_validate_aggregations,
suggest_aggregations_llm,
)
# Candidatos de ejemplo con la forma que produce select_groupby_keys (literales).
_CANDIDATES = {
"group_keys": [
{"col": "categoria", "cardinality": 8, "score": 0.91},
{"col": "region", "cardinality": 5, "score": 0.74},
{"col": "canal", "cardinality": 3, "score": 0.60},
],
"measures": ["importe", "unidades"],
"pivots": [
{"index": "categoria", "columns": "region", "value": "importe"},
],
}
_PROFILE = {"table": "ventas"}
def _fake_returner(text):
"""Devuelve un ask_llm simulado que ignora args y retorna `text`."""
def _fake(prompt, model="x", system="", echo=True, **kwargs):
return text
return _fake
# --- _extract_json (parser puro, sin red) ---
def test_extract_json_object():
obj = {"aggregations": [{"group_by": "categoria", "measures": ["importe"], "why": "x"}]}
assert _extract_json(json.dumps(obj)) == obj
def test_extract_json_wrapped_in_fences_and_junk():
obj = {"aggregations": [], "pivots": []}
text = "Claro, aqui tienes:\n```json\n" + json.dumps(obj) + "\n```\nFin."
assert _extract_json(text) == obj
def test_extract_json_non_json_returns_none():
assert _extract_json("no hay json aqui") is None
assert _extract_json("") is None
assert _extract_json(None) is None
# --- _validate_aggregations (puro) ---
def test_validate_aggregations_drops_invalid_columns():
group_cols = {"categoria", "region"}
measure_set = {"importe", "unidades"}
raw = [
{"group_by": "categoria", "measures": ["importe", "inventada"], "why": "ok"},
{"group_by": "no_existe", "measures": ["importe"], "why": "mala"},
{"group_by": "region", "measures": ["solo_inventada"], "why": "sin medidas"},
]
out = _validate_aggregations(raw, group_cols, measure_set, max_aggs=4)
# Solo sobrevive la primera, con las medidas recortadas a las validas.
assert out == [{"group_by": "categoria", "measures": ["importe"], "why": "ok"}]
# --- suggest_aggregations_llm: camino LLM (golden) ---
def test_llm_path_uses_selection(monkeypatch):
llm_obj = {
"aggregations": [
{"group_by": "categoria", "measures": ["importe"], "why": "ventas por familia"},
{"group_by": "region", "measures": ["importe", "unidades"], "why": "reparto geografico"},
],
"pivots": [
{"index": "categoria", "columns": "region", "value": "importe", "why": "cruce clave"},
],
}
monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj)))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["status"] == "ok"
assert out["source"] == "llm"
assert out["aggregations"] == llm_obj["aggregations"]
assert out["pivots"][0]["index"] == "categoria"
assert out["pivots"][0]["why"] == "cruce clave"
def test_llm_path_respects_max_aggs(monkeypatch):
llm_obj = {
"aggregations": [
{"group_by": "categoria", "measures": ["importe"], "why": "a"},
{"group_by": "region", "measures": ["importe"], "why": "b"},
{"group_by": "canal", "measures": ["unidades"], "why": "c"},
],
"pivots": [],
}
monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj)))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2)
assert out["source"] == "llm"
assert len(out["aggregations"]) == 2
def test_llm_invented_column_is_discarded(monkeypatch):
# El LLM mezcla una agregacion valida con otra de columna inexistente.
llm_obj = {
"aggregations": [
{"group_by": "categoria", "measures": ["importe"], "why": "valida"},
{"group_by": "columna_fantasma", "measures": ["importe"], "why": "inventada"},
],
"pivots": [
{"index": "fantasma", "columns": "region", "value": "importe", "why": "mala"},
],
}
monkeypatch.setattr(M, "ask_llm", _fake_returner(json.dumps(llm_obj)))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["source"] == "llm"
# La agregacion inventada se descarta; queda solo la valida.
assert [a["group_by"] for a in out["aggregations"]] == ["categoria"]
# El pivot con index fantasma se descarta -> cae a los pivots de candidates.
assert all(p["index"] in {"categoria", "region", "canal"} for p in out["pivots"])
# --- suggest_aggregations_llm: fallback determinista (error paths) ---
def test_fallback_on_empty_llm_response(monkeypatch):
monkeypatch.setattr(M, "ask_llm", _fake_returner(""))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=4)
assert out["status"] == "ok"
assert out["source"] == "fallback"
# Las agregaciones se derivan de candidates (una por group_key, con todas las medidas).
assert out["aggregations"][0]["group_by"] in {"categoria", "region", "canal"}
assert out["aggregations"][0]["measures"] == ["importe", "unidades"]
assert out["aggregations"][0]["why"] == "selección cuantitativa (sin LLM)"
# Pivots tal cual de candidates.
assert out["pivots"][0]["index"] == "categoria"
def test_fallback_on_unparseable_response(monkeypatch):
monkeypatch.setattr(M, "ask_llm", _fake_returner("esto no es JSON {roto"))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["source"] == "fallback"
assert len(out["aggregations"]) >= 1
def test_fallback_respects_max_aggs(monkeypatch):
monkeypatch.setattr(M, "ask_llm", _fake_returner(""))
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES, max_aggs=2)
assert out["source"] == "fallback"
assert len(out["aggregations"]) == 2
def test_fallback_when_llm_raises(monkeypatch):
def _boom(*args, **kwargs):
raise RuntimeError("sin red")
monkeypatch.setattr(M, "ask_llm", _boom)
out = suggest_aggregations_llm(_PROFILE, _CANDIDATES)
assert out["source"] == "fallback"
assert out["aggregations"] # no vacio, no lanza
def test_no_candidates_returns_empty_fallback():
# Sin red porque ni siquiera se llama al LLM (no hay material).
out = suggest_aggregations_llm(_PROFILE, {"group_keys": [], "measures": [], "pivots": []})
assert out["status"] == "ok"
assert out["source"] == "fallback"
assert out["aggregations"] == []
def test_non_dict_candidates_does_not_raise():
out = suggest_aggregations_llm(_PROFILE, None)
assert out["status"] == "ok"
assert out["aggregations"] == []