StochTree 0.4.1.9000
Loading...
Searching...
No Matches
normal_sampler.h
1
2#ifndef STOCHTREE_NORMAL_SAMPLER_H_
3#define STOCHTREE_NORMAL_SAMPLER_H_
4
5#include <Eigen/Dense>
6#include <stochtree/distributions.h>
7#include <stochtree/log.h>
8#include <random>
9#include <vector>
10
11namespace StochTree {
12
14 public:
15 UnivariateNormalSampler() {std_normal_dist_ = standard_normal();}
17 double Sample(double mean, double variance, std::mt19937& gen) {
18 return mean + std::sqrt(variance) * std_normal_dist_(gen);
19 }
20 private:
  // Standard normal functor whose draws Sample() shifts by `mean` and scales by sqrt(variance).
  // NOTE(review): type presumably declared in stochtree/distributions.h — confirm.
22 standard_normal std_normal_dist_;
23};
24
26 public:
27 MultivariateNormalSampler() {std_normal_dist_ = standard_normal();}
29 std::vector<double> Sample(Eigen::VectorXd& mean, Eigen::MatrixXd& covariance, std::mt19937& gen) {
30 // Dimension extraction and checks
31 int mean_cols = mean.size();
32 int cov_rows = covariance.rows();
33 int cov_cols = covariance.cols();
34 CHECK_EQ(mean_cols, cov_cols);
35
36 // Variance cholesky decomposition
37 Eigen::LLT<Eigen::MatrixXd> decomposition(covariance);
38 Eigen::MatrixXd covariance_chol = decomposition.matrixL();
39
40 // Sample a vector of standard normal random variables
41 Eigen::VectorXd std_norm_vec(cov_rows);
42 for (int i = 0; i < cov_rows; i++) {
43 std_norm_vec(i) = std_normal_dist_(gen);
44 }
45
46 // Compute and return the sampled value
47 Eigen::VectorXd sampled_values_raw = mean + covariance_chol * std_norm_vec;
48 std::vector<double> result(cov_rows);
49 for (int i = 0; i < cov_rows; i++) {
50 result[i] = sampled_values_raw(i, 0);
51 }
52 return result;
53 }
54 Eigen::VectorXd SampleEigen(Eigen::VectorXd& mean, Eigen::MatrixXd& covariance, std::mt19937& gen) {
55 // Dimension extraction and checks
56 int mean_cols = mean.size();
57 int cov_rows = covariance.rows();
58 int cov_cols = covariance.cols();
59 CHECK_EQ(mean_cols, cov_cols);
60
61 // Variance cholesky decomposition
62 Eigen::LLT<Eigen::MatrixXd> decomposition(covariance);
63 Eigen::MatrixXd covariance_chol = decomposition.matrixL();
64
65 // Sample a vector of standard normal random variables
66 Eigen::VectorXd std_norm_vec(cov_rows);
67 for (int i = 0; i < cov_rows; i++) {
68 std_norm_vec(i) = std_normal_dist_(gen);
69 }
70
71 // Compute and return the sampled value
72 return mean + covariance_chol * std_norm_vec;
73 }
74 private:
  // Standard normal functor used to fill the vector z in mean + L * z for Sample()/SampleEigen().
  // NOTE(review): type presumably declared in stochtree/distributions.h — confirm.
76 standard_normal std_normal_dist_;
77};
78
79} // namespace StochTree
80
81#endif // STOCHTREE_NORMAL_SAMPLER_H_
Definition normal_sampler.h:25
Definition normal_sampler.h:13
Definition distributions.h:38
A collection of random number generation utilities.
Definition category_tracker.h:36