// GPU implementation of a 1-D random-walk Metropolis-Hastings sampler.
// Each chain is advanced by one compute-shader invocation; the host side
// below creates the program and buffers, (re)initialises per-chain state,
// dispatches runs and reads results back.
#include "datascience/mc_metropolis_hastings_gpu.h"
#include "gfx/gl_loader.h"
#include "gfx/gpu_compute_program.h"
#include "gfx/gpu_dispatch.h"
#include "gfx/gpu_rng_glsl.h"
// NOTE(review): the targets of the next two includes are not visible in this
// view of the file; confirm against the checked-in source.
#include
#include

namespace fn::ds {

// SSBO binding index of the per-chain RNG seed buffer. It is handed to the
// RNG preamble generator and reused when binding `rng_seeds` for a run.
constexpr int kSeedBinding = 9;
// Length of the `u_user[]` uniform array declared in the shader preamble, and
// the upper bound applied to `n_user_params` in mc_mh_gpu_run().
constexpr int kUserParams = 16;

// Parameterized body: the preamble injects target_log_pdf and the declaration
// of u_user[]. The body performs n_steps iterations per chain and persists the
// state (curr_x, accept_count, rng_seed) between runs.
//
// Shader outline (one invocation per chain, cid = gl_GlobalInvocationID.x;
// invocations with cid >= u_n_chains exit immediately):
//   - loads the chain's seed and current position, evaluates the log-density;
//   - for each step proposes x + u_proposal_sigma * rng_normal(s) and computes
//     la = log p(proposal) - log p(current);
//   - accepts outright when la >= 0, otherwise draws rng_uniform(s) and
//     accepts when u < exp(la);
//   - stores the post-step position at chains[cid * u_n_steps + i];
//   - finally writes back the seed and position and ADDS the number of
//     accepted proposals to accept_counts[cid].
// `rng_seeds`, `rng_normal` and `rng_uniform` are not declared in this body;
// they are expected to be provided by the preamble (presumably the RNG part).
static const char* k_body_template = R"glsl( layout(std430, binding = 0) buffer Chains { float chains[]; }; layout(std430, binding = 1) buffer AcceptCounts { uint accept_counts[]; }; layout(std430, binding = 2) buffer CurrX { float curr_x[]; }; uniform uint u_n_chains; uniform uint u_n_steps; uniform float u_proposal_sigma; void main() { uint cid = gl_GlobalInvocationID.x; if (cid >= u_n_chains) return; uint s = rng_seeds[cid]; float x = curr_x[cid]; float lp = target_log_pdf(x); uint acc = 0u; uint base = cid * u_n_steps; for (uint i = 0u; i < u_n_steps; ++i) { float prop = x + u_proposal_sigma * rng_normal(s); float lp_p = target_log_pdf(prop); float la = lp_p - lp; if (la >= 0.0) { x = prop; lp = lp_p; acc += 1u; } else { float u = rng_uniform(s); if (u < exp(la)) { x = prop; lp = lp_p; acc += 1u; } } chains[base + i] = x; } rng_seeds[cid] = s; curr_x[cid] = x; accept_counts[cid] += acc; } )glsl";

// Builds the sampler: compiles the compute program and allocates its buffers.
//
// m_chains             number of independent chains (must be > 0).
// n_samples_per_run    steps each chain takes per mc_mh_gpu_run() (must be > 0).
// target_log_pdf_glsl  GLSL source appended to the preamble; the body calls
//                      `target_log_pdf(float)`, so it must define that function.
//
// Returns a value-initialised McMetropolisHastingsGpu when either size is
// non-positive or when compilation fails (the compile error is printed to
// stderr). The buffers are created with a null data pointer, i.e. nothing is
// uploaded here; mc_mh_gpu_reset() is what fills seeds, positions and counts.
McMetropolisHastingsGpu mc_mh_gpu_create(int m_chains,
                                         int n_samples_per_run,
                                         const std::string& target_log_pdf_glsl) {
    McMetropolisHastingsGpu s{};
    if (m_chains <= 0 || n_samples_per_run <= 0) return s;

    // Preamble = rng + array u_user[16] + user log-pdf.
    auto rng = fn::gfx::glsl_rng_preamble(kSeedBinding);
    std::string preamble;
    preamble.reserve(rng.size() + target_log_pdf_glsl.size() + 256);
    preamble += rng;
    char buf[64];
    std::snprintf(buf, sizeof(buf), "uniform float u_user[%d];\n", kUserParams);
    preamble += buf;
    preamble += target_log_pdf_glsl;
    // Make sure the user snippet ends with a newline before the body follows.
    if (!preamble.empty() && preamble.back() != '\n') preamble += '\n';

    // The literal 64 is also passed to dispatch_1d() in mc_mh_gpu_run();
    // presumably the workgroup size -- keep the two in sync.
    auto r = fn::gfx::compile_compute(k_body_template, 64, preamble);
    if (!r.ok) {
        std::fprintf(stderr, "[mc_mh_gpu] compile error: %s\n", r.err_msg.c_str());
        return s;
    }

    s.program = r.program;
    s.m_chains = m_chains;
    s.n_samples_per_run = n_samples_per_run;

    // Cache the uniform locations once; they are reused on every run.
    s.loc_n_chains = static_cast(glGetUniformLocation(s.program, "u_n_chains"));
    s.loc_n_steps = static_cast(glGetUniformLocation(s.program, "u_n_steps"));
    s.loc_proposal_sigma = static_cast(glGetUniformLocation(s.program, "u_proposal_sigma"));

    // chains: m_chains * n_samples_per_run floats (all samples of one run).
    s.chains = fn::gfx::ssbo_create(
        static_cast(m_chains) * static_cast(n_samples_per_run) * sizeof(float),
        nullptr, GL_DYNAMIC_COPY);
    // One entry per chain for the remaining three buffers.
    s.accept_counts = fn::gfx::ssbo_create(
        static_cast(m_chains) * sizeof(unsigned int), nullptr, GL_DYNAMIC_COPY);
    s.curr_x = fn::gfx::ssbo_create(
        static_cast(m_chains) * sizeof(float), nullptr, GL_DYNAMIC_COPY);
    s.rng_seeds = fn::gfx::ssbo_create(
        static_cast(m_chains) * sizeof(unsigned int), nullptr, GL_DYNAMIC_COPY);
    return s;
}

// (Re)initialises the per-chain state on the GPU.
//
// master_seed  fed to fn::gfx::seed_walkers_init() to fill one seed per chain.
// x0           starting position used for every chain when `initial_xs` is null.
// initial_xs   optional; when non-null, its first m_chains entries are used as
//              the per-chain starting positions (the caller must supply at
//              least that many).
//
// Also zeroes the acceptance counters. The `chains` sample buffer is not
// touched. Does nothing when s.m_chains <= 0.
void mc_mh_gpu_reset(McMetropolisHastingsGpu& s,
                     unsigned long long master_seed,
                     float x0,
                     const float* initial_xs) {
    if (s.m_chains <= 0) return;

    std::vector seeds(s.m_chains);
    fn::gfx::seed_walkers_init(master_seed, seeds.data(), s.m_chains);
    fn::gfx::ssbo_upload(s.rng_seeds, 0,
                         static_cast(s.m_chains) * sizeof(unsigned int),
                         seeds.data());

    std::vector xs(s.m_chains);
    if (initial_xs) {
        for (int j = 0; j < s.m_chains; ++j) xs[j] = initial_xs[j];
    } else {
        for (int j = 0; j < s.m_chains; ++j) xs[j] = x0;
    }
    fn::gfx::ssbo_upload(s.curr_x, 0,
                         static_cast(s.m_chains) * sizeof(float),
                         xs.data());

    // The shader accumulates into accept_counts with +=, so start from zero.
    std::vector zeros(s.m_chains, 0u);
    fn::gfx::ssbo_upload(s.accept_counts, 0,
                         static_cast(s.m_chains) * sizeof(unsigned int),
                         zeros.data());
}

// Advances every chain by s.n_samples_per_run steps.
//
// proposal_sigma  scale applied to the normal draw that forms each proposal.
// user_params     optional values uploaded to the `u_user[]` uniform array.
// n_user_params   number of entries in `user_params`; clamped to kUserParams.
//
// User parameters are uploaded only when `user_params` is non-null,
// `n_user_params` is positive and the `u_user` uniform location is found
// (loc >= 0). Does nothing when the program is 0 or s.m_chains <= 0.
void mc_mh_gpu_run(McMetropolisHastingsGpu& s,
                   float proposal_sigma,
                   const float* user_params,
                   int n_user_params) {
    if (s.program == 0 || s.m_chains <= 0) return;

    glUseProgram(s.program);
    // Binding indices must match the layout(binding = ...) values in the
    // shader body (0, 1, 2) and the index given to the RNG preamble.
    fn::gfx::ssbo_bind(s.chains, 0);
    fn::gfx::ssbo_bind(s.accept_counts, 1);
    fn::gfx::ssbo_bind(s.curr_x, 2);
    fn::gfx::ssbo_bind(s.rng_seeds, kSeedBinding);

    glUniform1ui(static_cast(s.loc_n_chains), static_cast(s.m_chains));
    glUniform1ui(static_cast(s.loc_n_steps), static_cast(s.n_samples_per_run));
    glUniform1f (static_cast(s.loc_proposal_sigma), proposal_sigma);

    if (user_params && n_user_params > 0) {
        if (n_user_params > kUserParams) n_user_params = kUserParams;
        // Looked up on every call rather than cached at creation time.
        GLint loc = glGetUniformLocation(s.program, "u_user");
        if (loc >= 0) glUniform1fv(loc, n_user_params, user_params);
    }

    fn::gfx::dispatch_1d(s.m_chains, 64);
    // NOTE(review): consecutive runs read curr_x / rng_seeds / accept_counts
    // written by the previous dispatch; confirm this barrier also covers
    // shader-storage accesses and not only buffer readback.
    fn::gfx::barrier_buffer_update();
}

// Copies the samples of the most recent run into `out`.
// `out` must have room for m_chains * n_samples_per_run floats. The shader
// writes step i of chain c at index c * n_samples_per_run + i. Does nothing
// when the chains buffer id is 0 or `out` is null.
void mc_mh_gpu_readback_chains(const McMetropolisHastingsGpu& s, float* out) {
    if (s.chains.id == 0 || out == nullptr) return;
    fn::gfx::ssbo_readback(
        s.chains, 0,
        static_cast(s.m_chains) * static_cast(s.n_samples_per_run) * sizeof(float),
        out);
}

// Copies the per-chain acceptance counters into `out` (m_chains entries).
// The counters accumulate across runs and are zeroed by mc_mh_gpu_reset().
// Does nothing when the accept_counts buffer id is 0 or `out` is null.
void mc_mh_gpu_readback_accepts(const McMetropolisHastingsGpu& s, unsigned int* out) {
    if (s.accept_counts.id == 0 || out == nullptr) return;
    fn::gfx::ssbo_readback(
        s.accept_counts, 0,
        static_cast(s.m_chains) * sizeof(unsigned int),
        out);
}

// Releases the compute program and all four buffers, then clears the program
// handle and the size fields so later run/reset calls return early. The
// cached loc_* uniform locations are left as they were.
void mc_mh_gpu_destroy(McMetropolisHastingsGpu& s) {
    fn::gfx::delete_compute_program(s.program);
    s.program = 0;
    fn::gfx::ssbo_destroy(s.chains);
    fn::gfx::ssbo_destroy(s.accept_counts);
    fn::gfx::ssbo_destroy(s.curr_x);
    fn::gfx::ssbo_destroy(s.rng_seeds);
    s.m_chains = 0;
    s.n_samples_per_run = 0;
}

}  // namespace fn::ds