#include "datascience/rhat_ess.h"
#include "datascience/autocorr.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

namespace fn::ds {

// Classic (non-split) Gelman-Rubin potential scale reduction factor (R-hat).
//
// `chains` holds m chains of n draws each, stored row-major: draw i of chain j
// is chains[j * n + i].
//
// Returns 1.0 when there is too little data to estimate anything (null
// pointer, fewer than two chains, or fewer than two draws per chain).
// When the within-chain variance is exactly zero (every chain is constant),
// returns 1.0 if all chains hold the same value and +infinity if they
// disagree, since the between-chain variance is then positive while the
// within-chain variance is zero.
static double rhat_core(const double* chains, std::size_t m, std::size_t n) {
    if (chains == nullptr || m < 2 || n < 2) return 1.0;

    // Per-chain means.
    std::vector<double> means(m, 0.0);
    for (std::size_t j = 0; j < m; ++j) {
        const double* c = chains + j * n;
        double s = 0.0;
        for (std::size_t i = 0; i < n; ++i) s += c[i];
        means[j] = s / static_cast<double>(n);
    }

    // Grand mean across chains.
    double grand = 0.0;
    for (std::size_t j = 0; j < m; ++j) grand += means[j];
    grand /= static_cast<double>(m);

    // Between-chain variance B.
    double B = 0.0;
    for (std::size_t j = 0; j < m; ++j) {
        const double d = means[j] - grand;
        B += d * d;
    }
    B *= static_cast<double>(n) / static_cast<double>(m - 1);

    // Within-chain variance W: average of the per-chain sample variances
    // (n - 1 denominator).
    double W = 0.0;
    for (std::size_t j = 0; j < m; ++j) {
        const double* c = chains + j * n;
        double s = 0.0;
        for (std::size_t i = 0; i < n; ++i) {
            const double d = c[i] - means[j];
            s += d * d;
        }
        W += s / static_cast<double>(n - 1);
    }
    W /= static_cast<double>(m);

    if (W <= 0.0) {
        // W == 0 means every draw equals its chain mean exactly, so means[j]
        // *is* chain j's constant value and exact comparison is intended.
        // Identical chains have trivially converged. Chains stuck at
        // different values have not, and must not be reported as R-hat == 1.
        // Comparing the means directly (rather than testing B > 0) avoids
        // spurious divergence from rounding in the grand mean.
        const double first = means.front();
        const bool chains_agree = std::all_of(
            means.begin(), means.end(), [first](double mu) { return mu == first; });
        return chains_agree ? 1.0 : std::numeric_limits<double>::infinity();
    }

    const double n_d = static_cast<double>(n);
    const double var_hat = ((n_d - 1.0) * W + B) / n_d;
    return std::sqrt(var_hat / W);
}

// Non-split R-hat over m row-major chains of n draws each.
// See rhat_core for the layout and degenerate-input return values.
double rhat(const double* chains, std::size_t m, std::size_t n) {
    return rhat_core(chains, m, n);
}

// Split R-hat: each chain is cut into a first and second half, and the
// resulting 2m half-chains are fed to the classic R-hat.
//
// Returns 1.0 for a null pointer, zero chains, or fewer than 4 draws per
// chain. For odd n, the final draw of each chain is dropped so both halves
// have n / 2 draws.
double rhat_split(const double* chains, std::size_t m, std::size_t n) {
    if (chains == nullptr || m < 1 || n < 4) return 1.0;

    const std::size_t half = n / 2;
    const std::size_t m2 = m * 2;

    // For even n, the (m, n) row-major buffer already *is* the (2m, half)
    // row-major layout: the two halves of chain j are rows 2j and 2j+1.
    // No copy is needed.
    if (n % 2 == 0) return rhat_core(chains, m2, half);

    // Odd n: dropping each chain's last draw breaks the stride, so repack
    // the first 2 * half draws of every chain into a contiguous (2m, half)
    // buffer.
    std::vector<double> split(m2 * half);
    for (std::size_t j = 0; j < m; ++j) {
        const double* c = chains + j * n;
        std::copy(c, c + 2 * half, split.data() + 2 * j * half);
    }
    return rhat_core(split.data(), m2, half);
}

// Basic effective sample size: sum over chains of n / tau_j, where tau_j is
// the value returned by autocorr_tau for chain j (presumably the integrated
// autocorrelation time; confirm against autocorr.h). `max_lag` and `cutoff`
// are forwarded to autocorr_tau unchanged.
//
// tau is clamped to at least 1, so each chain contributes at most n.
// Returns 0.0 for a null pointer, zero chains, or fewer than two draws per
// chain.
double ess_basic(const double* chains, std::size_t m, std::size_t n,
                 std::size_t max_lag, double cutoff) {
    if (chains == nullptr || m == 0 || n < 2) return 0.0;

    double total = 0.0;
    for (std::size_t j = 0; j < m; ++j) {
        const double* c = chains + j * n;
        double tau = autocorr_tau(c, n, max_lag, cutoff);
        if (tau < 1.0) tau = 1.0;
        total += static_cast<double>(n) / tau;
    }
    return total;
}

}  // namespace fn::ds