StochTree 0.0.1
Loading...
Searching...
No Matches
variance_model.h
1
5#ifndef STOCHTREE_VARIANCE_MODEL_H_
6#define STOCHTREE_VARIANCE_MODEL_H_
7
#include <Eigen/Dense>
#include <stochtree/data.h>
#include <stochtree/ensemble.h>
#include <stochtree/gamma_sampler.h>
#include <stochtree/ig_sampler.h>
#include <stochtree/meta.h>

#include <cmath>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
20
21namespace StochTree {
22
25 public:
28 double PosteriorShape(Eigen::VectorXd& residuals, double a, double b) {
29 data_size_t n = residuals.rows();
30 return a + (0.5 * n);
31 }
32 double PosteriorScale(Eigen::VectorXd& residuals, double a, double b) {
33 data_size_t n = residuals.rows();
34 double sum_sq_resid = 0.;
35 for (data_size_t i = 0; i < n; i++) {
36 sum_sq_resid += (residuals(i) * residuals(i));
37 }
38 return b + (0.5 * sum_sq_resid);
39 }
40 double PosteriorShape(Eigen::VectorXd& residuals, Eigen::VectorXd& weights, double a, double b) {
41 data_size_t n = residuals.rows();
42 return a + (0.5 * n);
43 }
44 double PosteriorScale(Eigen::VectorXd& residuals, Eigen::VectorXd& weights, double a, double b) {
45 data_size_t n = residuals.rows();
46 double sum_sq_resid = 0.;
47 for (data_size_t i = 0; i < n; i++) {
48 sum_sq_resid += (residuals(i) * residuals(i)) * weights(i);
49 }
50 return b + (0.5 * sum_sq_resid);
51 }
52 double SampleVarianceParameter(Eigen::VectorXd& residuals, double a, double b, std::mt19937& gen) {
53 double ig_shape = PosteriorShape(residuals, a, b);
54 double ig_scale = PosteriorScale(residuals, a, b);
55 return ig_sampler_.Sample(ig_shape, ig_scale, gen);
56 }
57 double SampleVarianceParameter(Eigen::VectorXd& residuals, Eigen::VectorXd& weights, double a, double b, std::mt19937& gen) {
58 double ig_shape = PosteriorShape(residuals, weights, a, b);
59 double ig_scale = PosteriorScale(residuals, weights, a, b);
60 return ig_sampler_.Sample(ig_shape, ig_scale, gen);
61 }
62 private:
63 InverseGammaSampler ig_sampler_;
64};
65
68 public:
71 double PosteriorShape(TreeEnsemble* ensemble, double a, double b) {
72 data_size_t num_leaves = ensemble->NumLeaves();
73 return (a/2.0) + (num_leaves/2.0);
74 }
75 double PosteriorScale(TreeEnsemble* ensemble, double a, double b) {
76 double mu_sq = ensemble->SumLeafSquared();
77 return (b/2.0) + (mu_sq/2.0);
78 }
79 double SampleVarianceParameter(TreeEnsemble* ensemble, double a, double b, std::mt19937& gen) {
80 double ig_shape = PosteriorShape(ensemble, a, b);
81 double ig_scale = PosteriorScale(ensemble, a, b);
82 return ig_sampler_.Sample(ig_shape, ig_scale, gen);
83 }
84 private:
85 InverseGammaSampler ig_sampler_;
86};
87
88} // namespace StochTree
89
90#endif // STOCHTREE_VARIANCE_MODEL_H_
Posterior computation and sampling for a global homoskedastic error variance parameter with an inverse-gamma prior.
Definition variance_model.h:24
Definition ig_sampler.h:9
Posterior computation and sampling for the variance of a tree ensemble's leaf parameters with an inverse-gamma prior.
Definition variance_model.h:67
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Definition category_tracker.h:40