StochTree 0.1.1
Loading...
Searching...
No Matches
leaf_model.h
1
5#ifndef STOCHTREE_LEAF_MODEL_H_
6#define STOCHTREE_LEAF_MODEL_H_
7
8#include <Eigen/Dense>
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/gamma_sampler.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/meta.h>
15#include <stochtree/normal_sampler.h>
16#include <stochtree/partition_tracker.h>
17#include <stochtree/prior.h>
18#include <stochtree/tree.h>
19
20#include <random>
21#include <tuple>
22#include <variant>
23
24namespace StochTree {
25
352 kConstantLeafGaussian,
353 kUnivariateRegressionLeafGaussian,
354 kMultivariateRegressionLeafGaussian,
355 kLogLinearVariance
356};
357
360 public:
361 data_size_t n;
362 double sum_w;
363 double sum_yw;
368 n = 0;
369 sum_w = 0.0;
370 sum_yw = 0.0;
371 }
381 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
382 n += 1;
383 if (dataset.HasVarWeights()) {
384 sum_w += 1/dataset.VarWeightValue(row_idx);
385 sum_yw += outcome(row_idx, 0)/dataset.VarWeightValue(row_idx);
386 } else {
387 sum_w += 1.0;
388 sum_yw += outcome(row_idx, 0);
389 }
390 }
395 n = 0;
396 sum_w = 0.0;
397 sum_yw = 0.0;
398 }
406 n = lhs.n + rhs.n;
407 sum_w = lhs.sum_w + rhs.sum_w;
408 sum_yw = lhs.sum_yw + rhs.sum_yw;
409 }
417 n = lhs.n - rhs.n;
418 sum_w = lhs.sum_w - rhs.sum_w;
419 sum_yw = lhs.sum_yw - rhs.sum_yw;
420 }
426 bool SampleGreaterThan(data_size_t threshold) {
427 return n > threshold;
428 }
434 bool SampleGreaterThanEqual(data_size_t threshold) {
435 return n >= threshold;
436 }
440 data_size_t SampleSize() {
441 return n;
442 }
443};
444
447 public:
453 GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
462 double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
469 double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
476 double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
483 double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance);
495 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
496 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
502 void SetScale(double tau) {tau_ = tau;}
506 inline bool RequiresBasis() {return false;}
507 private:
508 double tau_;
509 UnivariateNormalSampler normal_sampler_;
510};
511
514 public:
515 data_size_t n;
516 double sum_xxw;
517 double sum_yxw;
522 n = 0;
523 sum_xxw = 0.0;
524 sum_yxw = 0.0;
525 }
535 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
536 n += 1;
537 if (dataset.HasVarWeights()) {
538 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
539 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
540 } else {
541 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0);
542 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0);
543 }
544 }
549 n = 0;
550 sum_xxw = 0.0;
551 sum_yxw = 0.0;
552 }
560 n = lhs.n + rhs.n;
561 sum_xxw = lhs.sum_xxw + rhs.sum_xxw;
562 sum_yxw = lhs.sum_yxw + rhs.sum_yxw;
563 }
571 n = lhs.n - rhs.n;
572 sum_xxw = lhs.sum_xxw - rhs.sum_xxw;
573 sum_yxw = lhs.sum_yxw - rhs.sum_yxw;
574 }
580 bool SampleGreaterThan(data_size_t threshold) {
581 return n > threshold;
582 }
588 bool SampleGreaterThanEqual(data_size_t threshold) {
589 return n >= threshold;
590 }
594 data_size_t SampleSize() {
595 return n;
596 }
597};
598
601 public:
602 GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
618 double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
625 double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
632 double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
644 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
645 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
646 void SetScale(double tau) {tau_ = tau;}
647 inline bool RequiresBasis() {return true;}
648 private:
649 double tau_;
650 UnivariateNormalSampler normal_sampler_;
651};
652
655 public:
656 data_size_t n;
657 int p;
658 Eigen::MatrixXd XtWX;
659 Eigen::MatrixXd ytWX;
666 n = 0;
667 XtWX = Eigen::MatrixXd::Zero(basis_dim, basis_dim);
668 ytWX = Eigen::MatrixXd::Zero(1, basis_dim);
669 p = basis_dim;
670 }
680 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
681 n += 1;
682 if (dataset.HasVarWeights()) {
683 XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all)/dataset.VarWeightValue(row_idx);
684 ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all)))/dataset.VarWeightValue(row_idx);
685 } else {
686 XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all);
687 ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all)));
688 }
689 }
694 n = 0;
695 XtWX = Eigen::MatrixXd::Zero(p, p);
696 ytWX = Eigen::MatrixXd::Zero(1, p);
697 }
705 n = lhs.n + rhs.n;
706 XtWX = lhs.XtWX + rhs.XtWX;
707 ytWX = lhs.ytWX + rhs.ytWX;
708 }
716 n = lhs.n - rhs.n;
717 XtWX = lhs.XtWX - rhs.XtWX;
718 ytWX = lhs.ytWX - rhs.ytWX;
719 }
725 bool SampleGreaterThan(data_size_t threshold) {
726 return n > threshold;
727 }
733 bool SampleGreaterThanEqual(data_size_t threshold) {
734 return n >= threshold;
735 }
739 data_size_t SampleSize() {
740 return n;
741 }
742};
743
746 public:
752 GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();}
775 Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
782 Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
794 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
795 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
796 void SetScale(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0;}
797 inline bool RequiresBasis() {return true;}
798 private:
799 Eigen::MatrixXd Sigma_0_;
800 MultivariateNormalSampler multivariate_normal_sampler_;
801};
802
805 public:
806 data_size_t n;
807 double weighted_sum_ei;
809 n = 0;
810 weighted_sum_ei = 0.0;
811 }
821 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
822 n += 1;
823 weighted_sum_ei += std::exp(std::log(outcome(row_idx)*outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx));
824 }
829 n = 0;
830 weighted_sum_ei = 0.0;
831 }
839 n = lhs.n + rhs.n;
840 weighted_sum_ei = lhs.weighted_sum_ei + rhs.weighted_sum_ei;
841 }
849 n = lhs.n - rhs.n;
850 weighted_sum_ei = lhs.weighted_sum_ei - rhs.weighted_sum_ei;
851 }
857 bool SampleGreaterThan(data_size_t threshold) {
858 return n > threshold;
859 }
865 bool SampleGreaterThanEqual(data_size_t threshold) {
866 return n >= threshold;
867 }
871 data_size_t SampleSize() {
872 return n;
873 }
874};
875
878 public:
879 LogLinearVarianceLeafModel(double a, double b) {a_ = a; b_ = b; gamma_sampler_ = GammaSampler();}
888 double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance);
895 double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
896 double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
903 double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance);
910 double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance);
922 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
923 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
924 void SetPriorShape(double a) {a_ = a;}
925 void SetPriorRate(double b) {b_ = b;}
926 inline bool RequiresBasis() {return false;}
927 private:
928 double a_;
929 double b_;
930 GammaSampler gamma_sampler_;
931};
932
945
958
959template<typename SuffStatType, typename... SuffStatConstructorArgs>
960static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) {
961 return SuffStatType(leaf_suff_stat_args...);
962}
963
964template<typename LeafModelType, typename... LeafModelConstructorArgs>
965static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_model_args) {
966 return LeafModelType(leaf_model_args...);
967}
968
975static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
976 if (model_type == kConstantLeafGaussian) {
977 return createSuffStat<GaussianConstantSuffStat>();
978 } else if (model_type == kUnivariateRegressionLeafGaussian) {
979 return createSuffStat<GaussianUnivariateRegressionSuffStat>();
980 } else if (model_type == kMultivariateRegressionLeafGaussian) {
981 return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
982 } else {
983 return createSuffStat<LogLinearVarianceSuffStat>();
984 }
985}
986
996static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
997 if (model_type == kConstantLeafGaussian) {
998 return createLeafModel<GaussianConstantLeafModel, double>(tau);
999 } else if (model_type == kUnivariateRegressionLeafGaussian) {
1000 return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
1001 } else if (model_type == kMultivariateRegressionLeafGaussian) {
1002 return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
1003 } else {
1004 return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
1005 }
1006}
1007
1008template<typename SuffStatType>
1009static inline void AccumulateSuffStatProposed(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1010 ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature) {
1011 // Acquire iterators
1012 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_num);
1013 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_num);
1014
1015 // Accumulate sufficient statistics
1016 for (auto i = node_begin_iter; i != node_end_iter; i++) {
1017 auto idx = *i;
1018 double feature_value = dataset.CovariateValue(idx, split_feature);
1019 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1020 if (split.SplitTrue(feature_value)) {
1021 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1022 } else {
1023 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1024 }
1025 }
1026}
1027
1028template<typename SuffStatType>
1029static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1030 ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) {
1031 // Acquire iterators
1032 auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
1033 auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id);
1034 auto right_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, right_node_id);
1035 auto right_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, right_node_id);
1036
1037 // Accumulate sufficient statistics for the left and split nodes
1038 for (auto i = left_node_begin_iter; i != left_node_end_iter; i++) {
1039 auto idx = *i;
1040 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1041 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1042 }
1043
1044 // Accumulate sufficient statistics for the right and split nodes
1045 for (auto i = right_node_begin_iter; i != right_node_end_iter; i++) {
1046 auto idx = *i;
1047 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1048 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1049 }
1050}
1051
1052template<typename SuffStatType, bool sorted>
1053static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, int tree_num, int node_id) {
1054 // Acquire iterators
1055 std::vector<data_size_t>::iterator node_begin_iter;
1056 std::vector<data_size_t>::iterator node_end_iter;
1057 if (sorted) {
1058 // Default to the first feature if we're using the presort tracker
1059 node_begin_iter = tracker.SortedNodeBeginIterator(node_id, 0);
1060 node_end_iter = tracker.SortedNodeEndIterator(node_id, 0);
1061 } else {
1062 node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
1063 node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
1064 }
1065
1066 // Accumulate sufficient statistics
1067 for (auto i = node_begin_iter; i != node_end_iter; i++) {
1068 auto idx = *i;
1069 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1070 }
1071}
1072
1073template<typename SuffStatType>
1074static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
1075 ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id,
1076 int feature_num, int cutpoint_num) {
1077 // Acquire iterators
1078 auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
1079 auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
1080
1081 // Determine node start point
1082 data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
1083
1084 // Determine cutpoint bin start and end points
1085 data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num, feature_num);
1086 data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_num, feature_num);
1087 data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num + 1, feature_num);
1088
1089 // Cutpoint specific iterators
1090 // TODO: fix the hack of having to subtract off node_begin, probably by cleaning up the CutpointGridContainer interface
1091 auto cutpoint_begin_iter = node_begin_iter + (current_bin_begin - node_begin);
1092 auto cutpoint_end_iter = node_begin_iter + (next_bin_begin - node_begin);
1093
1094 // Accumulate sufficient statistics
1095 for (auto i = cutpoint_begin_iter; i != cutpoint_end_iter; i++) {
1096 auto idx = *i;
1097 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1098 }
1099}
1100
// end of leaf_model_group
1102
1103} // namespace StochTree
1104
1105#endif // STOCHTREE_LEAF_MODEL_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:194
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:272
double BasisValue(data_size_t row, int col)
Returns a dataset's basis value stored at (row, col)
Definition data.h:372
Eigen::MatrixXd & GetBasis()
Return a reference to the raw Eigen::MatrixXd storing the basis data.
Definition data.h:390
double VarWeightValue(data_size_t row)
Returns a dataset's variance weight stored at element row
Definition data.h:378
bool HasVarWeights()
Whether or not a ForestDataset has (yet) loaded variance weights.
Definition data.h:352
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:50
Definition gamma_sampler.h:9
Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model.
Definition leaf_model.h:446
double SplitLogMarginalLikelihood(GaussianConstantSuffStat &left_stat, GaussianConstantSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterVariance(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
bool RequiresBasis()
Whether this model requires a basis vector for posterior inference and prediction.
Definition leaf_model.h:506
GaussianConstantLeafModel(double tau)
Construct a new GaussianConstantLeafModel object.
Definition leaf_model.h:453
double PosteriorParameterMean(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
void SetScale(double tau)
Set a new value for the leaf node scale parameter.
Definition leaf_model.h:502
Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model...
Definition leaf_model.h:359
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:440
GaussianConstantSuffStat()
Construct a new GaussianConstantSuffStat object, setting all sufficient statistics to zero.
Definition leaf_model.h:367
void SubtractSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:416
void AddSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:405
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:434
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:426
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:381
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:394
Marginal likelihood and posterior computation for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:745
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &left_stat, GaussianMultivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd &Sigma_0)
Construct a new GaussianMultivariateRegressionLeafModel object.
Definition leaf_model.h:752
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic multivariate regression leaf outcome model...
Definition leaf_model.h:654
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:725
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:739
void AddSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:704
void SubtractSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:715
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:733
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:693
GaussianMultivariateRegressionSuffStat(int basis_dim)
Construct a new GaussianMultivariateRegressionSuffStat object.
Definition leaf_model.h:665
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:680
Marginal likelihood and posterior computation for gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:600
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &left_stat, GaussianUnivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic univariate regression leaf outcome model...
Definition leaf_model.h:513
GaussianUnivariateRegressionSuffStat()
Construct a new GaussianUnivariateRegressionSuffStat object, setting all sufficient statistics to zer...
Definition leaf_model.h:521
void SubtractSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:570
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:580
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:588
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:594
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:548
void AddSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:559
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:535
Marginal likelihood and posterior computation for heteroskedastic log-linear variance model.
Definition leaf_model.h:877
double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
double PosteriorParameterShape(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior shape parameter.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat &left_stat, LogLinearVarianceSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterScale(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior scale parameter.
Sufficient statistic and associated operations for heteroskedastic log-linear variance model.
Definition leaf_model.h:804
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:865
void SubtractSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:848
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:828
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:871
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:821
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:857
void AddSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:838
Definition normal_sampler.h:24
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Decision tree data structure.
Definition tree.h:69
Definition normal_sampler.h:12
static SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim=0)
Factory function that creates a new SuffStat object for the specified model type.
Definition leaf_model.h:975
std::variant< GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat > SuffStatVariant
Unifying layer for disparate sufficient statistic class types.
Definition leaf_model.h:944
ModelType
Leaf models for the forest sampler:
Definition leaf_model.h:351
static LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd &Sigma0, double a, double b)
Factory function that creates a new LeafModel object for the specified model type.
Definition leaf_model.h:996
std::variant< GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel > LeafModelVariant
Unifying layer for disparate leaf model class types.
Definition leaf_model.h:957
Definition category_tracker.h:40