StochTree 0.1.1
Loading...
Searching...
No Matches
discrete_sampler.h
1
5#ifndef STOCHTREE_DISCRETE_SAMPLER_H_
6#define STOCHTREE_DISCRETE_SAMPLER_H_
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <stdexcept>
#include <vector>
11
12namespace StochTree {
13
/*!
 * \brief Sample without replacement according to a set of probability weights.
 *
 * Draws `sample_size` distinct positions from `[0, population_size)` with
 * probability proportional to the weights in `p`, then writes the
 * corresponding entries of `a` into `output`.
 *
 * The sampler works in rounds. Each round draws one uniform variate per
 * still-missing sample, inverts the normalized CDF of the remaining weights,
 * and keeps the distinct positions that were hit (in ascending order within
 * the round). Accepted positions get weight zero so they cannot be drawn
 * again, and the process repeats until `sample_size` positions are collected.
 *
 * \tparam container_type Element type of the population and of the output.
 * \tparam prob_type Floating-point type of the weights.
 * \param output Destination array with room for at least `sample_size` elements.
 * \param p Weights, one per population element (`population_size` entries).
 *          They need not sum to one but must be non-negative. The caller's
 *          array is never modified; the function works on a private copy.
 * \param a Population to sample from (`population_size` entries).
 * \param population_size Number of entries in `p` and `a`.
 * \param sample_size Number of distinct elements to draw.
 * \param gen Random number generator used for the uniform draws.
 * \throws std::invalid_argument if a size is negative, a weight is negative or
 *         NaN, the weights do not sum to a finite value, or `sample_size`
 *         exceeds the number of strictly positive weights (in which case no
 *         valid sample exists).
 */
template <typename container_type, typename prob_type>
void sample_without_replacement(container_type* output, prob_type* p, container_type* a, int population_size, int sample_size, std::mt19937& gen) {
  if (population_size < 0 || sample_size < 0) {
    throw std::invalid_argument("sample_without_replacement: population_size and sample_size must be non-negative");
  }
  if (sample_size == 0) return;

  // Work on a private copy so that the caller's weights are left untouched.
  std::vector<prob_type> weights(p, p + population_size);

  // Validate once up front. Every later round only zeroes out weights, so a
  // request that is feasible here stays feasible and the loop below is
  // guaranteed to make progress on every iteration.
  int positive_weight_count = 0;
  prob_type weight_total = 0;
  for (const prob_type& w : weights) {
    if (!(w >= 0)) {  // also rejects NaN
      throw std::invalid_argument("sample_without_replacement: weights must be non-negative");
    }
    if (w > 0) ++positive_weight_count;
    weight_total += w;
  }
  if (!std::isfinite(weight_total)) {
    throw std::invalid_argument("sample_without_replacement: weights must sum to a finite value");
  }
  if (sample_size > positive_weight_count) {
    throw std::invalid_argument("sample_without_replacement: sample_size exceeds the number of elements with positive weight");
  }

  std::uniform_real_distribution<> unif(0.0, 1.0);
  std::vector<prob_type> unif_samples(sample_size);
  std::vector<prob_type> cdf(population_size);
  std::vector<int> indices;
  indices.reserve(sample_size);
  std::vector<int> matches;
  matches.reserve(sample_size);

  int fulfilled_sample_count = 0;
  while (fulfilled_sample_count < sample_size) {
    const int remaining_sample_count = sample_size - fulfilled_sample_count;

    // One uniform draw per sample that is still missing.
    std::generate(unif_samples.begin(), unif_samples.begin() + remaining_sample_count,
                  [&gen, &unif]() { return unif(gen); });

    // Normalized CDF of the weights that are still available.
    std::partial_sum(weights.cbegin(), weights.cend(), cdf.begin());
    const prob_type total = cdf.back();
    for (prob_type& c : cdf) c = c / total;

    // First position at which the CDF reaches its final value. That position
    // always carries positive weight, so it is the correct target for a draw
    // that lands at or beyond the top of the CDF (e.g. a double in [0, 1)
    // that rounds up to 1 when narrowed to prob_type).
    const auto top_of_cdf = std::lower_bound(cdf.cbegin(), cdf.cend(), cdf.back());

    matches.clear();
    for (int i = 0; i < remaining_sample_count; ++i) {
      auto match = std::upper_bound(cdf.cbegin(), cdf.cend(), unif_samples[i]);
      if (match == cdf.cend()) match = top_of_cdf;
      matches.push_back(static_cast<int>(match - cdf.cbegin()));
    }

    // Several draws may hit the same position; keep each position once.
    std::sort(matches.begin(), matches.end());
    matches.erase(std::unique(matches.begin(), matches.end()), matches.end());

    for (const int index : matches) {
      indices.push_back(index);
      weights[index] = 0;  // exclude from all subsequent rounds
    }
    fulfilled_sample_count += static_cast<int>(matches.size());
  }

  for (int i = 0; i < sample_size; ++i) {
    output[i] = a[indices[i]];
  }
}
62
63}
64
65#endif // STOCHTREE_DISCRETE_SAMPLER_H_
Definition category_tracker.h:36
void sample_without_replacement(container_type *output, prob_type *p, container_type *a, int population_size, int sample_size, std::mt19937 &gen)
Sample without replacement according to a set of probability weights. This template function is a C++...
Definition discrete_sampler.h:19