#pragma once #include "gfx/gpu_ssbo.h" #include namespace fn::ds { // Estado del MH GPU. Cada walker (cadena) tiene su propio thread; el // shader hace n_steps iteraciones por dispatch. Subsiguientes dispatches // continuan donde se quedaron. Output: // chains[m_chains * n_samples_per_run] floats — cadena completa // accept_counts[m_chains] uints — # de aceptaciones (pace) // curr_x[m_chains] floats — x actual por chain struct McMetropolisHastingsGpu { unsigned int program = 0; fn::gfx::Ssbo chains; fn::gfx::Ssbo accept_counts; fn::gfx::Ssbo curr_x; fn::gfx::Ssbo rng_seeds; int m_chains = 0; int n_samples_per_run = 0; unsigned int loc_n_chains = 0; unsigned int loc_n_steps = 0; unsigned int loc_proposal_sigma = 0; }; // Crea el sampler para m_chains cadenas, con n_samples_per_run steps por // cada llamada a run. La log-pdf se inyecta como GLSL: el snippet debe // definir // // float target_log_pdf(float x) { ... } // // y puede usar uniforms float u_user[16] (predeclarados en el preamble) // para parametros que cambien sin recompilar. Se usan en el snippet con // "u_user[0]", "u_user[1]", etc. // // Si la compilacion falla, m_chains=0 y program=0 — comprobar antes del // run. McMetropolisHastingsGpu mc_mh_gpu_create(int m_chains, int n_samples_per_run, const std::string& target_log_pdf_glsl); // Re-siembra los seeds y resetea curr_x a x0 (mismo valor para todas las // cadenas, o array de m_chains valores si initial_xs != nullptr). void mc_mh_gpu_reset(McMetropolisHastingsGpu& s, unsigned long long master_seed, float x0, const float* initial_xs = nullptr); // Ejecuta n_samples_per_run steps por chain. proposal_sigma controla el // random-walk Gaussian. user_params (si != nullptr) se sube a u_user[0..15]. void mc_mh_gpu_run(McMetropolisHastingsGpu& s, float proposal_sigma, const float* user_params = nullptr, int n_user_params = 0); // Lee chains a CPU. out debe tener m_chains * n_samples_per_run floats, // layout out[j * n + i] = sample i de la cadena j (compatible con // rhat_split / ess_basic). void mc_mh_gpu_readback_chains(const McMetropolisHastingsGpu& s, float* out); // Lee accept counts a CPU. out debe tener m_chains uints. accept_rate de // la cadena j: accept_counts[j] / n_samples_per_run. void mc_mh_gpu_readback_accepts(const McMetropolisHastingsGpu& s, unsigned int* out); void mc_mh_gpu_destroy(McMetropolisHastingsGpu& s); } // namespace fn::ds