StochTree 0.0.1
Loading...
Searching...
No Matches
normal_sampler.h
1
2#ifndef STOCHTREE_NORMAL_SAMPLER_H_
3#define STOCHTREE_NORMAL_SAMPLER_H_
4
#include <Eigen/Dense>
#include <stochtree/log.h>
#include <cmath>
#include <random>
#include <vector>
9
10namespace StochTree {
11
/*!
 * \brief Draws samples from a univariate normal distribution N(mean, variance).
 *
 * Each draw is formed as mean + sqrt(variance) * z, where z comes from an
 * internal standard normal N(0, 1) distribution.
 */
class UnivariateNormalSampler {
 public:
  /*! \brief Constructs a sampler backed by a standard normal N(0, 1) distribution. */
  UnivariateNormalSampler() : std_normal_dist_(0., 1.) {}

  /*!
   * \brief Draw a single value from N(mean, variance).
   * \param mean Mean of the target distribution.
   * \param variance Variance (not standard deviation) of the target distribution.
   *        No validation is performed: a negative variance makes std::sqrt
   *        return NaN, and the result is then NaN.
   * \param gen Random number generator used to produce the standard normal draw.
   * \return One sample from N(mean, variance).
   */
  double Sample(double mean, double variance, std::mt19937& gen) {
    // Location-scale transform of a standard normal draw
    return mean + std::sqrt(variance) * std_normal_dist_(gen);
  }

 private:
  /*! \brief Standard normal N(0, 1) distribution that is shifted and scaled on each draw. */
  std::normal_distribution<double> std_normal_dist_;
};
23
25 public:
26 MultivariateNormalSampler() {std_normal_dist_ = std::normal_distribution<double>(0.,1.);}
28 std::vector<double> Sample(Eigen::VectorXd& mean, Eigen::MatrixXd& covariance, std::mt19937& gen) {
29 // Dimension extraction and checks
30 int mean_cols = mean.size();
31 int cov_rows = covariance.rows();
32 int cov_cols = covariance.cols();
33 CHECK_EQ(mean_cols, cov_cols);
34
35 // Variance cholesky decomposition
36 Eigen::LLT<Eigen::MatrixXd> decomposition(covariance);
37 Eigen::MatrixXd covariance_chol = decomposition.matrixL();
38
39 // Sample a vector of standard normal random variables
40 Eigen::VectorXd std_norm_vec(cov_rows);
41 for (int i = 0; i < cov_rows; i++) {
42 std_norm_vec(i) = std_normal_dist_(gen);
43 }
44
45 // Compute and return the sampled value
46 Eigen::VectorXd sampled_values_raw = mean + covariance_chol * std_norm_vec;
47 std::vector<double> result(cov_rows);
48 for (int i = 0; i < cov_rows; i++) {
49 result[i] = sampled_values_raw(i, 0);
50 }
51 return result;
52 }
53 Eigen::VectorXd SampleEigen(Eigen::VectorXd& mean, Eigen::MatrixXd& covariance, std::mt19937& gen) {
54 // Dimension extraction and checks
55 int mean_cols = mean.size();
56 int cov_rows = covariance.rows();
57 int cov_cols = covariance.cols();
58 CHECK_EQ(mean_cols, cov_cols);
59
60 // Variance cholesky decomposition
61 Eigen::LLT<Eigen::MatrixXd> decomposition(covariance);
62 Eigen::MatrixXd covariance_chol = decomposition.matrixL();
63
64 // Sample a vector of standard normal random variables
65 Eigen::VectorXd std_norm_vec(cov_rows);
66 for (int i = 0; i < cov_rows; i++) {
67 std_norm_vec(i) = std_normal_dist_(gen);
68 }
69
70 // Compute and return the sampled value
71 return mean + covariance_chol * std_norm_vec;
72 }
73 private:
75 std::normal_distribution<double> std_normal_dist_;
76};
77
78} // namespace StochTree
79
80#endif // STOCHTREE_NORMAL_SAMPLER_H_
Definition normal_sampler.h:24
Definition normal_sampler.h:12
Definition category_tracker.h:40