"""Impure EDA helper: scatter figure of a numeric pair with its fit (`eda` group). Builds a matplotlib scatter of two numeric variables, overlays the fitted curve/line implied by the relationship classification (linear, polynomial of degree 2/3, etc.) and annotates the relationship type with its available metrics. Returns a ready-to-rasterize ``matplotlib.figure.Figure``; it never shows nor saves it. Impure because it touches matplotlib's rendering machinery. It uses the headless Agg backend and the object-oriented ``Figure`` API (no ``pyplot``) so it leaks no global state and is safe to call repeatedly from a report renderer. To keep the rendered PDF/PPTX light on phones, when the number of valid pairs exceeds ``max_points`` the *plotted* points are down-sampled DETERMINISTICALLY by a fixed step (``pairs[::step]``), never randomly, so the output is reproducible. The classification/fit always uses every clean pair; the down-sample only thins the drawn cloud. """ import math import matplotlib matplotlib.use("Agg") import numpy as np # noqa: E402 from matplotlib.figure import Figure # noqa: E402 # Sober blue for the scatter cloud and red for the fitted curve (Tufte: the # data points are the primary ink, the fit is the secondary highlight). _POINT_COLOR = "#4C72B0" _FIT_COLOR = "#C44E52" # Muted gray for the no-data fallback message. _MUTED_TEXT = "#5f6b7a" def _finite(value): """Coerce ``value`` to a finite float, or return None when not usable. bool is a subclass of int, but a real numeric measurement is never a bool, so True/False are treated as missing instead of coercing to 1.0/0.0. NaN and +/-infinity are never valid either. """ if value is None or isinstance(value, bool): return None try: f = float(value) except (TypeError, ValueError): return None if math.isnan(f) or math.isinf(f): return None return f def _clean_pairs(xs, ys): """Pair ``xs[i], ys[i]`` by index, dropping any pair with a non-finite end.""" pairs = [] if isinstance(xs, (list, tuple)) and isinstance(ys, (list, tuple)): n = min(len(xs), len(ys)) for i in range(n): x = _finite(xs[i]) y = _finite(ys[i]) if x is None or y is None: continue pairs.append((x, y)) return pairs def _ordered_trend(xs_clean, ys_clean, n_bins: int = 12): """Return (x_trend, y_trend): the ordered trend of y over x for a monotonic relationship that has no polynomial fit. When x has few distinct values (an ordinal/discrete scale) the trend is the mean of y per distinct x value. Otherwise x is split into ``n_bins`` ordered quantile bins and each point is (mean x, mean y) of the bin. Returns ``(None, None)`` when there is nothing meaningful to draw. """ x_arr = np.asarray(xs_clean, dtype=float) y_arr = np.asarray(ys_clean, dtype=float) if x_arr.size < 2: return None, None uniq = np.unique(x_arr) if uniq.size <= max(2, n_bins): # Discrete x: one trend point per distinct value (mean y). xt = uniq yt = np.array([float(np.mean(y_arr[x_arr == ux])) for ux in uniq]) return xt, yt # Continuous x: ordered quantile bins, (mean x, mean y) per bin. order = np.argsort(x_arr, kind="stable") x_sorted = x_arr[order] y_sorted = y_arr[order] chunks_x = np.array_split(x_sorted, n_bins) chunks_y = np.array_split(y_sorted, n_bins) xt = np.array([float(np.mean(cx)) for cx in chunks_x if cx.size]) yt = np.array([float(np.mean(cy)) for cy in chunks_y if cy.size]) return xt, yt def _no_data_figure(message: str) -> "matplotlib.figure.Figure": """A bare Figure carrying a centered muted message (defensive fallback).""" fig = Figure(figsize=(6.4, 4.0), dpi=150) ax = fig.add_subplot(111) ax.axis("off") ax.text( 0.5, 0.5, message, ha="center", va="center", fontsize=12, color=_MUTED_TEXT, transform=ax.transAxes, ) fig.tight_layout() return fig def _metrics_caption(classification: dict) -> str: """Format the available metrics of a classification dict into one line. Omits the metrics that are None. Keys consumed (any may be absent/None): ``pearson`` (r), ``spearman`` (rho), ``r2_linear`` (R²lin) and the best polynomial R² (``r2_poly3`` if a cubic was the best fit, else ``r2_poly2``). """ parts = [] r = _finite(classification.get("pearson")) if r is not None: parts.append(f"r={r:.2f}") rho = _finite(classification.get("spearman")) if rho is not None: parts.append(f"ρ={rho:.2f}") r2_lin = _finite(classification.get("r2_linear")) if r2_lin is not None: parts.append(f"R²lin={r2_lin:.2f}") # Prefer the R² of the best polynomial degree when it is a poly fit. best_degree = classification.get("best_degree") r2_poly = None if best_degree == 3: r2_poly = _finite(classification.get("r2_poly3")) elif best_degree == 2: r2_poly = _finite(classification.get("r2_poly2")) if r2_poly is None: # Fall back to whichever poly R² is present (cubic first). r2_poly = _finite(classification.get("r2_poly3")) if r2_poly is None: r2_poly = _finite(classification.get("r2_poly2")) if r2_poly is not None: parts.append(f"R²poly={r2_poly:.2f}") return " ".join(parts) def relationship_scatter_figure( xs: list, ys: list, x_label: str = "", y_label: str = "", classification: dict = None, max_points: int = 2000, ) -> "matplotlib.figure.Figure": """Build a scatter figure of a numeric pair with its fit and a type label. Cleans the pairs defensively (drops any pair with a None/bool/NaN/inf end), plots a semi-transparent scatter cloud (down-sampled deterministically when it exceeds ``max_points``), overlays the polynomial fit implied by ``classification`` and annotates the relationship type plus its available metrics in a corner box. The fit and classification always use every clean pair; only the drawn cloud is thinned by the down-sample. When ``classification`` is None it is computed internally by reusing ``classify_relationship_type`` over the clean pairs, so the function is self-contained. The function is fully defensive: empty input, fewer than 2 clean pairs, a missing/None ``coeffs`` or a missing sibling classifier never raise. When there is nothing valid to draw it still returns a ``Figure`` carrying a centered "Sin datos suficientes para el scatter" message. Args: xs: List (or tuple) of x values. Paired by index with ``ys``. Values that are None, bool, NaN or infinite discard that pair. Read defensively. ys: List (or tuple) of y values, parallel to ``xs``. Same defensive rules. x_label: Axis/title label for the x variable. Default "" (falls back to "x" in the title). y_label: Axis/title label for the y variable. Default "" (falls back to "y" in the title). classification: Optional dict from ``classify_relationship_type`` with keys ``tipo, pearson, r2_linear, spearman, r2_poly2, r2_poly3, best_degree, coeffs``. When None, it is computed internally by importing and calling ``classify_relationship_type`` over the clean pairs. When that sibling module is unavailable, the scatter is still drawn (no fit curve, no annotation). max_points: Cap on the number of *plotted* points. When the number of clean pairs exceeds this cap, the drawn cloud is down-sampled by a fixed step ``ceil(n/max_points)`` taking ``pairs[::step]`` — DETERMINISTIC, not random, so the figure is reproducible. A non-positive or non-int value disables down-sampling. Default 2000. Returns: A ``matplotlib.figure.Figure`` (figsize 6.4x4.0, dpi 150) with a single scatter Axes, the fitted curve (when a polynomial fit is available) and a corner annotation with the relationship type and metrics. When there are fewer than 2 clean pairs it returns a Figure with a centered "Sin datos suficientes para el scatter" message. The caller rasterizes/closes it. """ pairs = _clean_pairs(xs, ys) if len(pairs) < 2: return _no_data_figure("Sin datos suficientes para el scatter") # Full clean coordinates feed the classification/fit; the plotted cloud is # what gets thinned. xs_clean = [p[0] for p in pairs] ys_clean = [p[1] for p in pairs] # Resolve the classification. If not provided, reuse the sibling classifier # over ALL clean pairs (self-contained). Missing module => no fit/annotation. cls = classification if cls is None: try: from classify_relationship_type import classify_relationship_type cls = classify_relationship_type(xs_clean, ys_clean) except Exception: cls = None if not isinstance(cls, dict): cls = {} # --- Deterministic down-sampling of the DRAWN points only. n_total = len(pairs) if ( isinstance(max_points, int) and not isinstance(max_points, bool) and max_points > 0 and n_total > max_points ): step = math.ceil(n_total / max_points) sampled = pairs[::step] else: sampled = pairs x_plot = [p[0] for p in sampled] y_plot = [p[1] for p in sampled] fig = Figure(figsize=(6.4, 4.0), dpi=150) ax = fig.add_subplot(111) ax.scatter( x_plot, y_plot, s=12, alpha=0.5, color=_POINT_COLOR, edgecolors="none", rasterized=True, ) # --- Fitted curve/line over the full clean x range. coeffs = cls.get("coeffs") best_degree = cls.get("best_degree") tipo = cls.get("tipo") x_min, x_max = min(xs_clean), max(xs_clean) drew_fit = False if coeffs is not None and best_degree is not None and x_max > x_min: try: coeff_arr = np.asarray(coeffs, dtype=float) if coeff_arr.ndim == 1 and coeff_arr.size > 0 and np.all(np.isfinite(coeff_arr)): x_line = np.linspace(x_min, x_max, 200) y_line = np.polyval(coeff_arr, x_line) if np.all(np.isfinite(y_line)): ax.plot(x_line, y_line, color=_FIT_COLOR, linewidth=2) drew_fit = True except Exception: # Never fail the figure because of a malformed coeffs array. pass # A monotonic non-linear relationship has no fitted polynomial (coeffs is # None by design — a low-degree polynomial would mislead). Draw instead the # ordered trend of y over x so the reader still sees the shape: y averaged # within ordered x-bins (or per distinct x value when x is discrete with few # levels, e.g. an ordinal scale). Defensive: any failure leaves the cloud. if (not drew_fit and isinstance(tipo, str) and "monóton" in tipo.lower() and x_max > x_min): try: xt, yt = _ordered_trend(xs_clean, ys_clean) if xt is not None and len(xt) >= 2: ax.plot(xt, yt, color=_FIT_COLOR, linewidth=2, marker="o", markersize=3) except Exception: pass # --- Labels and title. tx = x_label if x_label else "x" ty = y_label if y_label else "y" ax.set_title(f"{tx} ↔ {ty}", fontsize=12, loc="left", pad=8) ax.set_xlabel(x_label) ax.set_ylabel(y_label) # --- Corner annotation: relationship type + available metrics. caption_lines = [] if tipo: caption_lines.append(str(tipo)) metrics_line = _metrics_caption(cls) if metrics_line: caption_lines.append(metrics_line) if caption_lines: ax.text( 0.03, 0.97, "\n".join(caption_lines), transform=ax.transAxes, ha="left", va="top", fontsize=8, bbox=dict( boxstyle="round,pad=0.35", facecolor="white", edgecolor="#cccccc", alpha=0.85, ), ) fig.tight_layout() return fig