#include "datascience/metropolis_hastings.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

namespace fn::ds {

namespace {

// Metropolis accept/reject decision given log_alpha = log p(prop) - log p(curr).
// Moves with log_alpha >= 0 are accepted without touching the RNG; otherwise one
// uniform is drawn and the move is accepted with probability exp(log_alpha).
// A NaN log_alpha (e.g. target returned NaN, or -inf - -inf) fails both
// comparisons and is therefore rejected (a uniform is still drawn in that case).
// No proposal-density correction is applied, so this is only the correct MH rule
// for a symmetric proposal — assumes rng_normal draws from a symmetric
// (presumably standard normal) distribution; TODO confirm in the Rng header.
bool mh_accept(double log_alpha, Rng& r) {
    if (log_alpha >= 0.0) {
        return true;
    }
    const double u = rng_uniform(r);
    return u < std::exp(log_alpha);
}

// Builds the summary shared by both samplers. The first chain entry is the
// fixed starting point, so only n_samples - 1 proposals were evaluated and the
// acceptance rate is taken over that count (0.0 when no proposal was made).
MHResult make_result(std::size_t n_samples, std::size_t accepted) {
    MHResult res{};
    res.n_samples = n_samples;
    res.n_accepted = accepted;
    res.accept_rate = (n_samples > 1)
                          ? static_cast<double>(accepted) /
                                static_cast<double>(n_samples - 1)
                          : 0.0;
    return res;
}

}  // namespace

/// @brief Random-walk Metropolis sampler for a scalar target.
///
/// Proposes prop = curr + proposal_sigma * rng_normal(r) and accepts it via
/// the Metropolis rule on log densities. Rejected steps repeat the current
/// state in the chain.
///
/// NOTE(review): the std::function signature must match the declaration in
/// metropolis_hastings.h — confirm.
///
/// @param target_log_pdf  Unnormalised log density of the target.
/// @param x0              Starting state; written to out_chain[0] as-is.
/// @param proposal_sigma  Scale applied to each rng_normal draw.
/// @param n_samples       Number of chain entries to write (including x0).
/// @param out_chain       Caller-owned buffer of at least n_samples doubles.
/// @param r               Random source, advanced by the sampler.
/// @return Summary with n_samples, n_accepted and accept_rate
///         (= n_accepted / (n_samples - 1)). A value-initialised MHResult is
///         returned, and nothing is written, when n_samples == 0 or
///         out_chain is null.
///
/// If target_log_pdf(x0) is NaN, every log_alpha is NaN and the chain never
/// leaves x0.
MHResult mh_run_1d(const std::function<double(double)>& target_log_pdf,
                   double x0, double proposal_sigma, std::size_t n_samples,
                   double* out_chain, Rng& r) {
    if (n_samples == 0 || out_chain == nullptr) {
        return MHResult{};
    }

    out_chain[0] = x0;
    double curr = x0;
    double log_p_curr = target_log_pdf(curr);
    std::size_t accepted = 0;

    for (std::size_t i = 1; i < n_samples; ++i) {
        const double prop = curr + proposal_sigma * rng_normal(r);
        const double log_p_prop = target_log_pdf(prop);
        if (mh_accept(log_p_prop - log_p_curr, r)) {
            curr = prop;
            log_p_curr = log_p_prop;
            ++accepted;
        }
        out_chain[i] = curr;
    }

    return make_result(n_samples, accepted);
}

/// @brief Random-walk Metropolis sampler for a d-dimensional target.
///
/// Each step perturbs every coordinate independently:
/// prop[k] = curr[k] + proposal_sigma[k] * rng_normal(r), for k = 0..d-1 in
/// order. The whole vector is then accepted or rejected with the Metropolis
/// rule on log densities.
///
/// NOTE(review): the std::function signature must match the declaration in
/// metropolis_hastings.h — confirm.
///
/// @param target_log_pdf  Unnormalised log density; receives a pointer to d doubles.
/// @param x0              Starting state, d doubles.
/// @param proposal_sigma  Per-coordinate proposal scales, d doubles.
/// @param d               Dimension; must be positive.
/// @param n_samples       Number of chain rows to write (including x0).
/// @param out_chain       Caller-owned row-major buffer of at least
///                        n_samples * d doubles; row i starts at out_chain + i*d.
/// @param r               Random source, advanced by the sampler.
/// @return Summary as for mh_run_1d. A value-initialised MHResult is returned,
///         and nothing is written, when d <= 0, n_samples == 0, or any pointer
///         argument is null.
MHResult mh_run_nd(const std::function<double(const double*)>& target_log_pdf,
                   const double* x0, const double* proposal_sigma, int d,
                   std::size_t n_samples, double* out_chain, Rng& r) {
    if (d <= 0 || n_samples == 0 || out_chain == nullptr || x0 == nullptr ||
        proposal_sigma == nullptr) {
        return MHResult{};
    }

    const auto dim = static_cast<std::size_t>(d);
    std::vector<double> curr(x0, x0 + dim);
    std::vector<double> prop(dim);

    std::copy(curr.begin(), curr.end(), out_chain);
    double log_p_curr = target_log_pdf(curr.data());
    std::size_t accepted = 0;

    for (std::size_t i = 1; i < n_samples; ++i) {
        for (std::size_t k = 0; k < dim; ++k) {
            prop[k] = curr[k] + proposal_sigma[k] * rng_normal(r);
        }
        const double log_p_prop = target_log_pdf(prop.data());
        if (mh_accept(log_p_prop - log_p_curr, r)) {
            // O(1) buffer exchange; prop is fully overwritten next iteration.
            curr.swap(prop);
            log_p_curr = log_p_prop;
            ++accepted;
        }
        std::copy(curr.begin(), curr.end(), out_chain + i * dim);
    }

    return make_result(n_samples, accepted);
}

}  // namespace fn::ds