1#ifndef STOCHTREE_DISTRIBUTIONS_H
2#define STOCHTREE_DISTRIBUTIONS_H
20 int32_t a = gen() >> 5;
21 int32_t b = gen() >> 6;
22 return (a * 67108864.0 + b) / 9007199254740992.0;
29 constexpr double inv_divisor = 1.0 /
static_cast<double>(std::mt19937::max());
30 return (gen() * inv_divisor);
41 has_cached_value_ =
false;
45 inline double operator()(std::mt19937& gen) {
46 if (has_cached_value_) {
47 has_cached_value_ =
false;
55 }
while (s >= 1.0 || s == 0.0);
56 r = std::sqrt(-2.0 * std::log(s) / s);
57 has_cached_value_ =
true;
58 cached_value_ = v * r;
64 bool has_cached_value_;
82 }
while (s >= 1.0 || s == 0.0);
83 r = std::sqrt(-2.0 * std::log(s) / s);
84 return u * r * sd + mean;
93 has_cached_normal_value_ =
false;
94 cached_normal_value_ = 0.0;
97 inline double operator()(std::mt19937& gen,
double shape,
double scale) {
100 }
else if (shape < 1.0) {
106 double v = -std::log(v0);
107 if (u <= 1.0 - shape) {
108 double x = std::pow(u, 1.0 / shape);
113 double y = -std::log((1 - u) / shape);
114 double x = std::pow(1.0 - shape + shape * y, 1.0 / shape);
120 }
else if (shape > 1.0) {
122 double b = shape - 1.0 / 3.0;
123 double c = 1.0 / std::sqrt(9.0 * b);
127 x = normal_draw(gen);
132 if (u < 1.0 - 0.0331 * (x * x) * (x * x)) {
133 return b * v * scale;
135 if (std::log(u) < 0.5 * x * x + b * (1.0 - v + std::log(v))) {
136 return b * v * scale;
144 inline double normal_draw(std::mt19937& gen) {
145 if (has_cached_normal_value_) {
146 has_cached_normal_value_ =
false;
147 return cached_normal_value_;
154 }
while (s >= 1.0 || s == 0.0);
155 r = std::sqrt(-2.0 * std::log(s) / s);
156 has_cached_normal_value_ =
true;
157 cached_normal_value_ = v * r;
163 bool has_cached_normal_value_;
164 double cached_normal_value_;
175inline double sample_gamma(std::mt19937& gen,
double shape,
double scale) {
178 }
else if (shape < 1.0) {
184 double v = -std::log(v0);
185 if (u <= 1.0 - shape) {
186 double x = std::pow(u, 1.0 / shape);
191 double y = -std::log((1 - u) / shape);
192 double x = std::pow(1.0 - shape + shape * y, 1.0 / shape);
198 }
else if (shape > 1.0) {
200 double b = shape - 1.0 / 3.0;
201 double c = 1.0 / std::sqrt(9.0 * b);
210 s = u1 * u1 + u2 * u2;
211 }
while (s >= 1.0 || s == 0.0);
212 x = u1 * std::sqrt(-2.0 * std::log(s) / s);
217 if (u < 1.0 - 0.0331 * (x * x) * (x * x)) {
218 return b * v * scale;
220 if (std::log(u) < 0.5 * x * x + b * (1.0 - v + std::log(v))) {
221 return b * v * scale;
237 template<
typename Iterator>
239 n_ = std::distance(first, last);
240 probability_.resize(n_);
245 for (
auto it = first; it != last; ++it) {
250 std::vector<double> p(n_);
251 std::vector<int> below_average, above_average;
253 for (
int i = 0; i < n_; ++i) {
254 p[i] = (*(first + i)) * n_ / sum;
256 below_average.push_back(i);
258 above_average.push_back(i);
262 while (!below_average.empty() && !above_average.empty()) {
263 int j = below_average.back(); below_average.pop_back();
264 int i = above_average.back(); above_average.pop_back();
266 probability_[j] = p[j];
268 p[i] = (p[i] + p[j]) - 1.0;
271 below_average.push_back(i);
273 above_average.push_back(i);
277 while (!above_average.empty()) {
278 probability_[above_average.back()] = 1.0;
279 above_average.pop_back();
282 while (!below_average.empty()) {
283 probability_[below_average.back()] = 1.0;
284 below_average.pop_back();
288 int operator()(std::mt19937& gen) {
290 int i =
static_cast<int>(u * n_);
291 double y = u * n_ - i;
292 return (y < probability_[i]) ? i : alias_[i];
296 std::vector<double> probability_;
297 std::vector<int> alias_;
301inline int sample_discrete_stateless(std::mt19937& gen, std::vector<double>& weights) {
302 double sum_weight = std::accumulate(weights.begin(), weights.end(), 0.0);
304 double running_total_weight = 0.0;
305 for (
int i = 0; i < weights.size(); ++i) {
306 running_total_weight += weights[i];
307 if (running_total_weight > u) {
311 return weights.size() - 1;
314inline int sample_discrete_stateless(std::mt19937& gen, std::vector<double>& weights,
double sum_weights) {
316 double running_total_weight = 0.0;
317 for (
int i = 0; i < weights.size(); ++i) {
318 running_total_weight += weights[i];
319 if (running_total_weight > u) {
323 return weights.size() - 1;
Definition distributions.h:90
Definition distributions.h:38
Definition distributions.h:235
A collection of random number generation utilities.
Definition category_tracker.h:36
double sample_standard_normal(double mean, double sd, std::mt19937 &gen)
Definition distributions.h:76
double sample_gamma(std::mt19937 &gen, double shape, double scale)
Definition distributions.h:175
double standard_uniform_draw_53bit(std::mt19937 &gen)
Definition distributions.h:19
double standard_uniform_draw_32bit(std::mt19937 &gen)
Definition distributions.h:28