StochTree 0.4.1
Loading...
Searching...
No Matches
distributions.h
1#ifndef STOCHTREE_DISTRIBUTIONS_H
2#define STOCHTREE_DISTRIBUTIONS_H
3#include <numeric>
4#include <random>
13namespace StochTree {
14
/// Draw a double uniformly from [0, 1) carrying 53 random bits.
///
/// Two consecutive outputs of `gen` are consumed. The top 27 bits of the first
/// output and the top 26 bits of the second are combined into a 53-bit integer,
/// which is then scaled by 2^-53. The result can be exactly 0.0 but is always
/// strictly less than 1.0.
inline double standard_uniform_draw_53bit(std::mt19937& gen) {
  constexpr double kTwoPow26 = 67108864.0;
  constexpr double kTwoPow53 = 9007199254740992.0;
  // Draw order matters: the high-order bits must come from the first call.
  const double high_bits = static_cast<double>(gen() >> 5);  // 27 bits
  const double low_bits = static_cast<double>(gen() >> 6);   // 26 bits
  return (high_bits * kTwoPow26 + low_bits) / kTwoPow53;
}
24
/// Draw a double from the closed interval [0, 1] using one output of `gen`.
///
/// The raw 32-bit integer is multiplied by 1 / mt19937::max(), so 0 maps to 0.0
/// and max() maps to (numerically) 1.0. Both ends are therefore reachable.
/// Prefer standard_uniform_draw_53bit when finer resolution or an open upper
/// bound is required.
inline double standard_uniform_draw_32bit(std::mt19937& gen) {
  constexpr double kScale = 1.0 / static_cast<double>(std::mt19937::max());
  const double raw = static_cast<double>(gen());
  return raw * kScale;
}
32
39 public:
41 has_cached_value_ = false;
42 cached_value_ = 0.0;
43 }
44
45 inline double operator()(std::mt19937& gen) {
46 if (has_cached_value_) {
47 has_cached_value_ = false;
48 return cached_value_;
49 } else {
50 double u, v, r, s;
51 do {
52 u = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
53 v = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
54 s = u * u + v * v;
55 } while (s >= 1.0 || s == 0.0);
56 r = std::sqrt(-2.0 * std::log(s) / s);
57 has_cached_value_ = true;
58 cached_value_ = v * r;
59 return u * r;
60 }
61 }
62
63 private:
64 bool has_cached_value_;
65 double cached_value_;
66};
67
76inline double sample_standard_normal(double mean, double sd, std::mt19937& gen) {
77 double u, v, r, s;
78 do {
79 u = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
80 v = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
81 s = u * u + v * v;
82 } while (s >= 1.0 || s == 0.0);
83 r = std::sqrt(-2.0 * std::log(s) / s);
84 return u * r * sd + mean;
85};
86
91 public:
93 has_cached_normal_value_ = false;
94 cached_normal_value_ = 0.0;
95 }
96
97 inline double operator()(std::mt19937& gen, double shape, double scale) {
98 if (shape == 1.0) {
99 return -std::log(standard_uniform_draw_53bit(gen)) * scale;
100 } else if (shape < 1.0) {
101 // Modified Ahrens-Dieter used by numpy:
102 // https://github.com/numpy/numpy/blob/main/numpy/random/src/distributions/distributions.c
103 while (true) {
104 double u = standard_uniform_draw_53bit(gen);
105 double v0 = standard_uniform_draw_53bit(gen);
106 double v = -std::log(v0);
107 if (u <= 1.0 - shape) {
108 double x = std::pow(u, 1.0 / shape);
109 if (x <= v) {
110 return x * scale;
111 }
112 } else {
113 double y = -std::log((1 - u) / shape);
114 double x = std::pow(1.0 - shape + shape * y, 1.0 / shape);
115 if (x <= v + y) {
116 return x * scale;
117 }
118 }
119 }
120 } else if (shape > 1.0) {
121 // Marsaglia-Tsang from numpy
122 double b = shape - 1.0 / 3.0;
123 double c = 1.0 / std::sqrt(9.0 * b);
124 while (true) {
125 double x, v;
126 do {
127 x = normal_draw(gen);
128 v = 1.0 + c * x;
129 } while (v <= 0.0);
130 v = v * v * v;
131 double u = standard_uniform_draw_53bit(gen);
132 if (u < 1.0 - 0.0331 * (x * x) * (x * x)) {
133 return b * v * scale;
134 }
135 if (std::log(u) < 0.5 * x * x + b * (1.0 - v + std::log(v))) {
136 return b * v * scale;
137 }
138 }
139 } else {
140 return 0.0;
141 }
142 }
143
144 inline double normal_draw(std::mt19937& gen) {
145 if (has_cached_normal_value_) {
146 has_cached_normal_value_ = false;
147 return cached_normal_value_;
148 } else {
149 double u, v, r, s;
150 do {
151 u = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
152 v = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
153 s = u * u + v * v;
154 } while (s >= 1.0 || s == 0.0);
155 r = std::sqrt(-2.0 * std::log(s) / s);
156 has_cached_normal_value_ = true;
157 cached_normal_value_ = v * r;
158 return u * r;
159 }
160 }
161
162 private:
163 bool has_cached_normal_value_;
164 double cached_normal_value_;
165};
166
175inline double sample_gamma(std::mt19937& gen, double shape, double scale) {
176 if (shape == 1.0) {
177 return -std::log(standard_uniform_draw_53bit(gen)) * scale;
178 } else if (shape < 1.0) {
179 // Modified Ahrens-Dieter used by numpy:
180 // https://github.com/numpy/numpy/blob/main/numpy/random/src/distributions/distributions.c
181 while (true) {
182 double u = standard_uniform_draw_53bit(gen);
183 double v0 = standard_uniform_draw_53bit(gen);
184 double v = -std::log(v0);
185 if (u <= 1.0 - shape) {
186 double x = std::pow(u, 1.0 / shape);
187 if (x <= v) {
188 return x * scale;
189 }
190 } else {
191 double y = -std::log((1 - u) / shape);
192 double x = std::pow(1.0 - shape + shape * y, 1.0 / shape);
193 if (x <= v + y) {
194 return x * scale;
195 }
196 }
197 }
198 } else if (shape > 1.0) {
199 // Marsaglia-Tsang from numpy
200 double b = shape - 1.0 / 3.0;
201 double c = 1.0 / std::sqrt(9.0 * b);
202 while (true) {
203 double x, v;
204 do {
205 // Marsaglia's polar method for standard normal
206 double u1, u2, s;
207 do {
208 u1 = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
209 u2 = standard_uniform_draw_53bit(gen) * 2.0 - 1.0;
210 s = u1 * u1 + u2 * u2;
211 } while (s >= 1.0 || s == 0.0);
212 x = u1 * std::sqrt(-2.0 * std::log(s) / s);
213 v = 1.0 + c * x;
214 } while (v <= 0.0);
215 v = v * v * v;
216 double u = standard_uniform_draw_53bit(gen);
217 if (u < 1.0 - 0.0331 * (x * x) * (x * x)) {
218 return b * v * scale;
219 }
220 if (std::log(u) < 0.5 * x * x + b * (1.0 - v + std::log(v))) {
221 return b * v * scale;
222 }
223 }
224 } else {
225 return 0.0;
226 }
227}
228
236 public:
237 template<typename Iterator>
238 walker_vose(Iterator first, Iterator last) {
239 n_ = std::distance(first, last);
240 probability_.resize(n_);
241 alias_.resize(n_);
242
243 // Compute probability normalizing factor
244 double sum = 0.0;
245 for (auto it = first; it != last; ++it) {
246 sum += *it;
247 }
248
249 // Build alias table using Walker's algorithm
250 std::vector<double> p(n_);
251 std::vector<int> below_average, above_average;
252
253 for (int i = 0; i < n_; ++i) {
254 p[i] = (*(first + i)) * n_ / sum;
255 if (p[i] < 1.0) {
256 below_average.push_back(i);
257 } else {
258 above_average.push_back(i);
259 }
260 }
261
262 while (!below_average.empty() && !above_average.empty()) {
263 int j = below_average.back(); below_average.pop_back();
264 int i = above_average.back(); above_average.pop_back();
265
266 probability_[j] = p[j];
267 alias_[j] = i;
268 p[i] = (p[i] + p[j]) - 1.0;
269
270 if (p[i] < 1.0) {
271 below_average.push_back(i);
272 } else {
273 above_average.push_back(i);
274 }
275 }
276
277 while (!above_average.empty()) {
278 probability_[above_average.back()] = 1.0;
279 above_average.pop_back();
280 }
281
282 while (!below_average.empty()) {
283 probability_[below_average.back()] = 1.0;
284 below_average.pop_back();
285 }
286 }
287
288 int operator()(std::mt19937& gen) {
289 double u = standard_uniform_draw_53bit(gen);
290 int i = static_cast<int>(u * n_);
291 double y = u * n_ - i;
292 return (y < probability_[i]) ? i : alias_[i];
293 }
294
295 private:
296 std::vector<double> probability_;
297 std::vector<int> alias_;
298 int n_;
299};
300
301inline int sample_discrete_stateless(std::mt19937& gen, std::vector<double>& weights) {
302 double sum_weight = std::accumulate(weights.begin(), weights.end(), 0.0);
303 double u = standard_uniform_draw_53bit(gen) * sum_weight;
304 double running_total_weight = 0.0;
305 for (int i = 0; i < weights.size(); ++i) {
306 running_total_weight += weights[i];
307 if (running_total_weight > u) {
308 return i;
309 }
310 }
311 return weights.size() - 1;
312}
313
314inline int sample_discrete_stateless(std::mt19937& gen, std::vector<double>& weights, double sum_weights) {
315 double u = standard_uniform_draw_53bit(gen) * sum_weights;
316 double running_total_weight = 0.0;
317 for (int i = 0; i < weights.size(); ++i) {
318 running_total_weight += weights[i];
319 if (running_total_weight > u) {
320 return i;
321 }
322 }
323 return weights.size() - 1;
324}
325
326}
327
328#endif // STOCHTREE_DISTRIBUTIONS_H
Definition distributions.h:90
Definition distributions.h:38
Definition distributions.h:235
A collection of random number generation utilities.
Definition category_tracker.h:36
double sample_standard_normal(double mean, double sd, std::mt19937 &gen)
Definition distributions.h:76
double sample_gamma(std::mt19937 &gen, double shape, double scale)
Definition distributions.h:175
double standard_uniform_draw_53bit(std::mt19937 &gen)
Definition distributions.h:19
double standard_uniform_draw_32bit(std::mt19937 &gen)
Definition distributions.h:28