StochTree 0.1.1
Loading...
Searching...
No Matches
variance_model.h
1
5#ifndef STOCHTREE_VARIANCE_MODEL_H_
6#define STOCHTREE_VARIANCE_MODEL_H_
7
8#include <Eigen/Dense>
9#include <stochtree/data.h>
10#include <stochtree/ensemble.h>
11#include <stochtree/gamma_sampler.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/meta.h>
14
15#include <random>
16
17namespace StochTree {
18
21 public:
24 double PosteriorShape(Eigen::VectorXd& residuals, double a, double b) {
25 data_size_t n = residuals.rows();
26 return a + (0.5 * n);
27 }
28 double PosteriorScale(Eigen::VectorXd& residuals, double a, double b) {
29 data_size_t n = residuals.rows();
30 double sum_sq_resid = 0.;
31 for (data_size_t i = 0; i < n; i++) {
32 sum_sq_resid += (residuals(i) * residuals(i));
33 }
34 return b + (0.5 * sum_sq_resid);
35 }
36 double PosteriorShape(Eigen::VectorXd& residuals, Eigen::VectorXd& weights, double a, double b) {
37 data_size_t n = residuals.rows();
38 return a + (0.5 * n);
39 }
40 double PosteriorScale(Eigen::VectorXd& residuals, Eigen::VectorXd& weights, double a, double b) {
41 data_size_t n = residuals.rows();
42 double sum_sq_resid = 0.;
43 for (data_size_t i = 0; i < n; i++) {
44 sum_sq_resid += (residuals(i) * residuals(i)) * weights(i);
45 }
46 return b + (0.5 * sum_sq_resid);
47 }
48 double SampleVarianceParameter(Eigen::VectorXd& residuals, double a, double b, std::mt19937& gen) {
49 double ig_shape = PosteriorShape(residuals, a, b);
50 double ig_scale = PosteriorScale(residuals, a, b);
51 return ig_sampler_.Sample(ig_shape, ig_scale, gen);
52 }
53 double SampleVarianceParameter(Eigen::VectorXd& residuals, Eigen::VectorXd& weights, double a, double b, std::mt19937& gen) {
54 double ig_shape = PosteriorShape(residuals, weights, a, b);
55 double ig_scale = PosteriorScale(residuals, weights, a, b);
56 return ig_sampler_.Sample(ig_shape, ig_scale, gen);
57 }
58 private:
59 InverseGammaSampler ig_sampler_;
60};
61
64 public:
67 double PosteriorShape(TreeEnsemble* ensemble, double a, double b) {
68 data_size_t num_leaves = ensemble->NumLeaves();
69 return (a/2.0) + (num_leaves/2.0);
70 }
71 double PosteriorScale(TreeEnsemble* ensemble, double a, double b) {
72 double mu_sq = ensemble->SumLeafSquared();
73 return (b/2.0) + (mu_sq/2.0);
74 }
75 double SampleVarianceParameter(TreeEnsemble* ensemble, double a, double b, std::mt19937& gen) {
76 double ig_shape = PosteriorShape(ensemble, a, b);
77 double ig_scale = PosteriorScale(ensemble, a, b);
78 return ig_sampler_.Sample(ig_shape, ig_scale, gen);
79 }
80 private:
81 InverseGammaSampler ig_sampler_;
82};
83
84} // namespace StochTree
85
86#endif // STOCHTREE_VARIANCE_MODEL_H_
Marginal likelihood and posterior computation for Gaussian homoskedastic constant leaf outcome model.
Definition variance_model.h:20
Definition ig_sampler.h:9
Marginal likelihood and posterior computation for Gaussian homoskedastic constant leaf outcome model.
Definition variance_model.h:63
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Definition category_tracker.h:36