#include "datascience/rng.h"

#include <cmath>
#include <cstdint>

namespace fn::ds {

// Rotate `x` left by `k` bits. Only called below with the constants 23 and 45;
// k == 0 would make `x >> 64` undefined behavior, so keep k in (0, 64).
static inline std::uint64_t rotl(std::uint64_t x, int k) {
    return (x << k) | (x >> (64 - k));
}

// One SplitMix64 step: advances `state` by the 0x9E3779B97F4A7C15 increment and
// returns a bit-mixed 64-bit value. Used only to seed the four xoshiro256++ lanes.
static inline std::uint64_t splitmix64(std::uint64_t& state) {
    state += 0x9E3779B97F4A7C15ULL;
    std::uint64_t z = state;
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
    z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
    return z ^ (z >> 31);
}

// Deterministically seed `r` by expanding `seed` through SplitMix64 into the
// four state lanes r.s[0..3].
// NOTE: seed 0 is remapped to 0x9E3779B97F4A7C15, so seeds 0 and
// 0x9E3779B97F4A7C15 yield identical streams.
void rng_seed(Rng& r, std::uint64_t seed) {
    if (seed == 0ULL) seed = 0x9E3779B97F4A7C15ULL;
    std::uint64_t s = seed;
    r.s[0] = splitmix64(s);
    r.s[1] = splitmix64(s);
    r.s[2] = splitmix64(s);
    r.s[3] = splitmix64(s);
}

// Return the next 64-bit output of xoshiro256++ (Vigna) and advance the state.
// Original author's note: ~1.24 ns per u64 on x86; passes PractRand to 32 TB.
std::uint64_t rng_u64(Rng& r) {
    const std::uint64_t result = rotl(r.s[0] + r.s[3], 23) + r.s[0];
    const std::uint64_t t = r.s[1] << 17;
    r.s[2] ^= r.s[0];
    r.s[3] ^= r.s[1];
    r.s[1] ^= r.s[2];
    r.s[0] ^= r.s[3];
    r.s[2] ^= t;
    r.s[3] = rotl(r.s[3], 45);
    return result;
}

// Return a double uniformly distributed in [0, 1), built from the top 53 bits
// of one rng_u64 draw. Every result is a multiple of 2^-53.
double rng_uniform(Rng& r) {
    return (rng_u64(r) >> 11) * (1.0 / 9007199254740992.0);
}

// Draw one standard-normal sample via the Box-Muller transform. Consumes two
// rng_uniform draws. The second normal that Box-Muller produces is discarded,
// which is fast enough for most uses. Caching it would require adding a flag
// to Rng.
double rng_normal(Rng& r) {
    double u1 = rng_uniform(r);
    // rng_uniform only returns multiples of 2^-53, so the only value that
    // log() cannot handle is exactly 0. Clamp it to the smallest nonzero draw
    // (2^-53), which caps |z| at ~8.57 like every other draw. The previous
    // clamp to 1e-300 turned this case into a ~37-sigma outlier. Nonzero draws
    // are unaffected, so seeded streams are unchanged.
    if (u1 <= 0.0) u1 = 1.0 / 9007199254740992.0;
    double u2 = rng_uniform(r);
    return std::sqrt(-2.0 * std::log(u1)) *
           std::cos(6.28318530717958647692 * u2);  // 2*pi * u2
}

// Return an integer uniformly distributed in [0, n), or 0 when n == 0.
// Uses Lemire's multiply-and-reject method (2019). It is unbiased and avoids
// integer division except on the rare path where rejection may be needed.
std::uint64_t rng_below(Rng& r, std::uint64_t n) {
    if (n == 0ULL) return 0ULL;
    std::uint64_t x = rng_u64(r);
    __uint128_t m = static_cast<__uint128_t>(x) * static_cast<__uint128_t>(n);
    std::uint64_t l = static_cast<std::uint64_t>(m);  // low 64 bits
    if (l < n) {
        // t = 2^64 mod n, computed as (-n) mod n in 64-bit arithmetic.
        // Low halves below t belong to an incomplete bucket and must be redrawn.
        const std::uint64_t t = (~n + 1ULL) % n;
        while (l < t) {
            x = rng_u64(r);
            m = static_cast<__uint128_t>(x) * static_cast<__uint128_t>(n);
            l = static_cast<std::uint64_t>(m);
        }
    }
    return static_cast<std::uint64_t>(m >> 64);  // high 64 bits, in [0, n)
}

// Sample an index in [0, n) with probability proportional to weights[i].
// Weights that are not > 0 (negative, zero, NaN) are treated as zero.
// Degenerate inputs:
//   - n <= 0 or weights == nullptr: returns 0 (no draw consumed).
//   - no positive weight: returns n - 1 (no draw consumed).
// Otherwise exactly one rng_uniform draw is consumed and the result always has
// positive weight.
int rng_categorical(Rng& r, const double* weights, int n) {
    if (n <= 0 || weights == nullptr) return 0;

    double total = 0.0;
    int last_positive = -1;
    for (int i = 0; i < n; ++i) {
        if (weights[i] > 0.0) {
            total += weights[i];
            last_positive = i;
        }
    }
    if (total <= 0.0) return n - 1;

    const double u = rng_uniform(r) * total;
    double acc = 0.0;
    for (int i = 0; i < n; ++i) {
        if (weights[i] > 0.0) {
            acc += weights[i];
            if (u < acc) return i;
        }
    }
    // Reached only when rounding makes u >= the final acc (== total), e.g.
    // when rng_uniform(r) * total rounds up to total. Fall back to the last
    // category with positive weight. The old fallback of n - 1 could select
    // a zero-weight category.
    return last_positive;
}

}  // namespace fn::ds