StochTree 0.1.1
Loading...
Searching...
No Matches
leaf_model.h
1
5#ifndef STOCHTREE_LEAF_MODEL_H_
6#define STOCHTREE_LEAF_MODEL_H_
7
8#include <Eigen/Dense>
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/gamma_sampler.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/meta.h>
15#include <stochtree/normal_sampler.h>
16#include <stochtree/openmp_utils.h>
17#include <stochtree/partition_tracker.h>
18#include <stochtree/prior.h>
19#include <stochtree/tree.h>
20
21#include <random>
22#include <variant>
23
24namespace StochTree {
25
352 kConstantLeafGaussian,
353 kUnivariateRegressionLeafGaussian,
354 kMultivariateRegressionLeafGaussian,
355 kLogLinearVariance
356};
357
360 public:
361 data_size_t n;
362 double sum_w;
363 double sum_yw;
368 n = 0;
369 sum_w = 0.0;
370 sum_yw = 0.0;
371 }
381 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
382 n += 1;
383 if (dataset.HasVarWeights()) {
384 sum_w += 1/dataset.VarWeightValue(row_idx);
385 sum_yw += outcome(row_idx, 0)/dataset.VarWeightValue(row_idx);
386 } else {
387 sum_w += 1.0;
388 sum_yw += outcome(row_idx, 0);
389 }
390 }
395 n = 0;
396 sum_w = 0.0;
397 sum_yw = 0.0;
398 }
405 n += suff_stat.n;
406 sum_w += suff_stat.sum_w;
407 sum_yw += suff_stat.sum_yw;
408 }
416 n = lhs.n + rhs.n;
417 sum_w = lhs.sum_w + rhs.sum_w;
418 sum_yw = lhs.sum_yw + rhs.sum_yw;
419 }
427 n = lhs.n - rhs.n;
428 sum_w = lhs.sum_w - rhs.sum_w;
429 sum_yw = lhs.sum_yw - rhs.sum_yw;
430 }
436 bool SampleGreaterThan(data_size_t threshold) {
437 return n > threshold;
438 }
444 bool SampleGreaterThanEqual(data_size_t threshold) {
445 return n >= threshold;
446 }
450 data_size_t SampleSize() {
451 return n;
452 }
453};
454
457 public:
463 GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
472 double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
479 double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
486 double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
493 double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance);
505 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
506 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
512 void SetScale(double tau) {tau_ = tau;}
/*!
 * \brief Whether this model requires a basis vector for posterior inference and prediction.
 *
 * \return Always `false` for this leaf model.
 */
inline bool RequiresBasis() {
  return false;
}
517 private:
518 double tau_;
519 UnivariateNormalSampler normal_sampler_;
520};
521
524 public:
525 data_size_t n;
526 double sum_xxw;
527 double sum_yxw;
532 n = 0;
533 sum_xxw = 0.0;
534 sum_yxw = 0.0;
535 }
545 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
546 n += 1;
547 if (dataset.HasVarWeights()) {
548 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
549 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
550 } else {
551 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0);
552 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0);
553 }
554 }
559 n = 0;
560 sum_xxw = 0.0;
561 sum_yxw = 0.0;
562 }
569 n += suff_stat.n;
570 sum_xxw += suff_stat.sum_xxw;
571 sum_yxw += suff_stat.sum_yxw;
572 }
580 n = lhs.n + rhs.n;
581 sum_xxw = lhs.sum_xxw + rhs.sum_xxw;
582 sum_yxw = lhs.sum_yxw + rhs.sum_yxw;
583 }
591 n = lhs.n - rhs.n;
592 sum_xxw = lhs.sum_xxw - rhs.sum_xxw;
593 sum_yxw = lhs.sum_yxw - rhs.sum_yxw;
594 }
600 bool SampleGreaterThan(data_size_t threshold) {
601 return n > threshold;
602 }
608 bool SampleGreaterThanEqual(data_size_t threshold) {
609 return n >= threshold;
610 }
614 data_size_t SampleSize() {
615 return n;
616 }
617};
618
621 public:
622 GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
638 double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
645 double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
652 double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
664 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
665 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
666 void SetScale(double tau) {tau_ = tau;}
/*!
 * \brief Whether this model requires a basis vector for posterior inference and prediction.
 *
 * \return Always `true` for this leaf model.
 */
inline bool RequiresBasis() {
  return true;
}
668 private:
669 double tau_;
670 UnivariateNormalSampler normal_sampler_;
671};
672
675 public:
676 data_size_t n;
677 int p;
678 Eigen::MatrixXd XtWX;
679 Eigen::MatrixXd ytWX;
686 n = 0;
687 XtWX = Eigen::MatrixXd::Zero(basis_dim, basis_dim);
688 ytWX = Eigen::MatrixXd::Zero(1, basis_dim);
689 p = basis_dim;
690 }
700 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
701 n += 1;
702 if (dataset.HasVarWeights()) {
703 XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all)/dataset.VarWeightValue(row_idx);
704 ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all)))/dataset.VarWeightValue(row_idx);
705 } else {
706 XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all);
707 ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all)));
708 }
709 }
714 n = 0;
715 XtWX = Eigen::MatrixXd::Zero(p, p);
716 ytWX = Eigen::MatrixXd::Zero(1, p);
717 }
724 n += suff_stat.n;
725 XtWX += suff_stat.XtWX;
726 ytWX += suff_stat.ytWX;
727 }
735 n = lhs.n + rhs.n;
736 XtWX = lhs.XtWX + rhs.XtWX;
737 ytWX = lhs.ytWX + rhs.ytWX;
738 }
746 n = lhs.n - rhs.n;
747 XtWX = lhs.XtWX - rhs.XtWX;
748 ytWX = lhs.ytWX - rhs.ytWX;
749 }
755 bool SampleGreaterThan(data_size_t threshold) {
756 return n > threshold;
757 }
763 bool SampleGreaterThanEqual(data_size_t threshold) {
764 return n >= threshold;
765 }
769 data_size_t SampleSize() {
770 return n;
771 }
772};
773
776 public:
782 GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();}
805 Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
812 Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
824 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
825 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
826 void SetScale(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0;}
/*!
 * \brief Whether this model requires a basis vector for posterior inference and prediction.
 *
 * \return Always `true` for this leaf model.
 */
inline bool RequiresBasis() {
  return true;
}
828 private:
829 Eigen::MatrixXd Sigma_0_;
830 MultivariateNormalSampler multivariate_normal_sampler_;
831};
832
835 public:
836 data_size_t n;
837 double weighted_sum_ei;
839 n = 0;
840 weighted_sum_ei = 0.0;
841 }
851 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
852 n += 1;
853 weighted_sum_ei += std::exp(std::log(outcome(row_idx)*outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx));
854 }
859 n = 0;
860 weighted_sum_ei = 0.0;
861 }
868 n += suff_stat.n;
869 weighted_sum_ei += suff_stat.weighted_sum_ei;
870 }
878 n = lhs.n + rhs.n;
879 weighted_sum_ei = lhs.weighted_sum_ei + rhs.weighted_sum_ei;
880 }
888 n = lhs.n - rhs.n;
889 weighted_sum_ei = lhs.weighted_sum_ei - rhs.weighted_sum_ei;
890 }
896 bool SampleGreaterThan(data_size_t threshold) {
897 return n > threshold;
898 }
904 bool SampleGreaterThanEqual(data_size_t threshold) {
905 return n >= threshold;
906 }
910 data_size_t SampleSize() {
911 return n;
912 }
913};
914
917 public:
918 LogLinearVarianceLeafModel(double a, double b) {a_ = a; b_ = b; gamma_sampler_ = GammaSampler();}
927 double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance);
934 double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
935 double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
942 double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance);
949 double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance);
961 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
962 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
963 void SetPriorShape(double a) {a_ = a;}
964 void SetPriorRate(double b) {b_ = b;}
/*!
 * \brief Whether this model requires a basis vector for posterior inference and prediction.
 *
 * \return Always `false` for this leaf model.
 */
inline bool RequiresBasis() {
  return false;
}
966 private:
967 double a_;
968 double b_;
969 GammaSampler gamma_sampler_;
970};
971
984
997
998template<typename SuffStatType, typename... SuffStatConstructorArgs>
999static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) {
1000 return SuffStatType(leaf_suff_stat_args...);
1001}
1002
1003template<typename LeafModelType, typename... LeafModelConstructorArgs>
1004static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_model_args) {
1005 return LeafModelType(leaf_model_args...);
1006}
1007
1014static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
1015 if (model_type == kConstantLeafGaussian) {
1016 return createSuffStat<GaussianConstantSuffStat>();
1017 } else if (model_type == kUnivariateRegressionLeafGaussian) {
1018 return createSuffStat<GaussianUnivariateRegressionSuffStat>();
1019 } else if (model_type == kMultivariateRegressionLeafGaussian) {
1020 return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
1021 } else {
1022 return createSuffStat<LogLinearVarianceSuffStat>();
1023 }
1024}
1025
1035static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
1036 if (model_type == kConstantLeafGaussian) {
1037 return createLeafModel<GaussianConstantLeafModel, double>(tau);
1038 } else if (model_type == kUnivariateRegressionLeafGaussian) {
1039 return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
1040 } else if (model_type == kMultivariateRegressionLeafGaussian) {
1041 return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
1042 } else {
1043 return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
1044 }
1045}
1046
1047template<typename SuffStatType, typename... SuffStatConstructorArgs>
1048static inline void AccumulateSuffStatProposed(
1049 SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1050 ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads,
1051 SuffStatConstructorArgs&... suff_stat_args
1052) {
1053 // Determine the position of the node's indices in the forest tracking data structure
1054 int node_begin_index = tracker.UnsortedNodeBegin(tree_num, leaf_num);
1055 int node_end_index = tracker.UnsortedNodeEnd(tree_num, leaf_num);
1056
1057 // Extract pointer to the feature partition for tree_num
1058 UnsortedNodeSampleTracker* unsorted_node_sample_tracker = tracker.GetUnsortedNodeSampleTracker();
1059 FeatureUnsortedPartition* feature_partition = unsorted_node_sample_tracker->GetFeaturePartition(tree_num);
1060
1061 // Determine the number of threads to use
1062 int chunk_size = (node_end_index - node_begin_index) / num_threads;
1063 if (chunk_size < 100) {
1064 num_threads = 1;
1065 chunk_size = node_end_index - node_begin_index;
1066 }
1067
1068 if (num_threads > 1) {
1069 // Split the work into num_threads chunks
1070 std::vector<std::pair<int, int>> thread_ranges(num_threads);
1071 std::vector<SuffStatType> thread_suff_stats_node;
1072 std::vector<SuffStatType> thread_suff_stats_left;
1073 std::vector<SuffStatType> thread_suff_stats_right;
1074 for (int i = 0; i < num_threads; i++) {
1075 thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size,
1076 node_begin_index + (i + 1) * chunk_size);
1077 thread_suff_stats_node.emplace_back(suff_stat_args...);
1078 thread_suff_stats_left.emplace_back(suff_stat_args...);
1079 thread_suff_stats_right.emplace_back(suff_stat_args...);
1080 }
1081
1082 // Accumulate sufficient statistics
1083 StochTree::ParallelFor(0, num_threads, num_threads, [&](int i) {
1084 int start_idx = thread_ranges[i].first;
1085 int end_idx = thread_ranges[i].second;
1086 for (int idx = start_idx; idx < end_idx; idx++) {
1087 int obs_num = feature_partition->indices_[idx];
1088 double feature_value = dataset.CovariateValue(obs_num, split_feature);
1089 thread_suff_stats_node[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1090 if (split.SplitTrue(feature_value)) {
1091 thread_suff_stats_left[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1092 } else {
1093 thread_suff_stats_right[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1094 }
1095 }
1096 });
1097
1098 // Combine the thread-local sufficient statistics
1099 for (int i = 0; i < num_threads; i++) {
1100 node_suff_stat.AddSuffStatInplace(thread_suff_stats_node[i]);
1101 left_suff_stat.AddSuffStatInplace(thread_suff_stats_left[i]);
1102 right_suff_stat.AddSuffStatInplace(thread_suff_stats_right[i]);
1103 }
1104 } else {
1105 for (int idx = node_begin_index; idx < node_end_index; idx++) {
1106 int obs_num = feature_partition->indices_[idx];
1107 double feature_value = dataset.CovariateValue(obs_num, split_feature);
1108 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1109 if (split.SplitTrue(feature_value)) {
1110 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1111 } else {
1112 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1113 }
1114 }
1115 }
1116}
1117
1118template<typename SuffStatType>
1119static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1120 ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) {
1121 // Acquire iterators
1122 auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
1123 auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id);
1124 auto right_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, right_node_id);
1125 auto right_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, right_node_id);
1126
1127 // Accumulate sufficient statistics for the left and split nodes
1128 for (auto i = left_node_begin_iter; i != left_node_end_iter; i++) {
1129 auto idx = *i;
1130 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1131 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1132 }
1133
1134 // Accumulate sufficient statistics for the right and split nodes
1135 for (auto i = right_node_begin_iter; i != right_node_end_iter; i++) {
1136 auto idx = *i;
1137 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1138 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1139 }
1140}
1141
1142template<typename SuffStatType, bool sorted>
1143static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, int tree_num, int node_id) {
1144 // Acquire iterators
1145 std::vector<data_size_t>::iterator node_begin_iter;
1146 std::vector<data_size_t>::iterator node_end_iter;
1147 if (sorted) {
1148 // Default to the first feature if we're using the presort tracker
1149 node_begin_iter = tracker.SortedNodeBeginIterator(node_id, 0);
1150 node_end_iter = tracker.SortedNodeEndIterator(node_id, 0);
1151 } else {
1152 node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
1153 node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
1154 }
1155
1156 // Accumulate sufficient statistics
1157 for (auto i = node_begin_iter; i != node_end_iter; i++) {
1158 auto idx = *i;
1159 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1160 }
1161}
1162
1163template<typename SuffStatType>
1164static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
1165 ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id,
1166 int feature_num, int cutpoint_num) {
1167 // Acquire iterators
1168 auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
1169 auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
1170
1171 // Determine node start point
1172 data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
1173
1174 // Determine cutpoint bin start and end points
1175 data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num, feature_num);
1176 data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_num, feature_num);
1177 data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num + 1, feature_num);
1178
1179 // Cutpoint specific iterators
1180 // TODO: fix the hack of having to subtract off node_begin, probably by cleaning up the CutpointGridContainer interface
1181 auto cutpoint_begin_iter = node_begin_iter + (current_bin_begin - node_begin);
1182 auto cutpoint_end_iter = node_begin_iter + (next_bin_begin - node_begin);
1183
1184 // Accumulate sufficient statistics
1185 for (auto i = cutpoint_begin_iter; i != cutpoint_end_iter; i++) {
1186 auto idx = *i;
1187 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1188 }
1189}
1190
// end of leaf_model_group
1192
1193} // namespace StochTree
1194
1195#endif // STOCHTREE_LEAF_MODEL_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:193
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:271
double BasisValue(data_size_t row, int col)
Returns a dataset's basis value stored at (row, col)
Definition data.h:371
Eigen::MatrixXd & GetBasis()
Return a reference to the raw Eigen::MatrixXd storing the basis data.
Definition data.h:389
double VarWeightValue(data_size_t row)
Returns a dataset's variance weight stored at element row
Definition data.h:377
bool HasVarWeights()
Whether or not a ForestDataset has (yet) loaded variance weights.
Definition data.h:351
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:47
Definition gamma_sampler.h:9
Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model.
Definition leaf_model.h:456
double SplitLogMarginalLikelihood(GaussianConstantSuffStat &left_stat, GaussianConstantSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterVariance(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
bool RequiresBasis()
Whether this model requires a basis vector for posterior inference and prediction.
Definition leaf_model.h:516
GaussianConstantLeafModel(double tau)
Construct a new GaussianConstantLeafModel object.
Definition leaf_model.h:463
double PosteriorParameterMean(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
void SetScale(double tau)
Set a new value for the leaf node scale parameter.
Definition leaf_model.h:512
Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model...
Definition leaf_model.h:359
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:450
GaussianConstantSuffStat()
Construct a new GaussianConstantSuffStat object, setting all sufficient statistics to zero.
Definition leaf_model.h:367
void SubtractSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:426
void AddSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:415
void AddSuffStatInplace(GaussianConstantSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:404
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:444
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:436
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:381
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:394
Marginal likelihood and posterior computation for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:775
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &left_stat, GaussianMultivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd &Sigma_0)
Construct a new GaussianMultivariateRegressionLeafModel object.
Definition leaf_model.h:782
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:674
void AddSuffStatInplace(GaussianMultivariateRegressionSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:723
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:755
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:769
void AddSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:734
void SubtractSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:745
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:763
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:713
GaussianMultivariateRegressionSuffStat(int basis_dim)
Construct a new GaussianMultivariateRegressionSuffStat object.
Definition leaf_model.h:685
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:700
Marginal likelihood and posterior computation for gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:620
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &left_stat, GaussianUnivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:523
GaussianUnivariateRegressionSuffStat()
Construct a new GaussianUnivariateRegressionSuffStat object, setting all sufficient statistics to zer...
Definition leaf_model.h:531
void SubtractSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:590
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:600
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:608
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:614
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:558
void AddSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:579
void AddSuffStatInplace(GaussianUnivariateRegressionSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:568
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:545
Marginal likelihood and posterior computation for heteroskedastic log-linear variance model.
Definition leaf_model.h:916
double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
double PosteriorParameterShape(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior shape parameter.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat &left_stat, LogLinearVarianceSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterScale(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior scale parameter.
Sufficient statistic and associated operations for heteroskedastic log-linear variance model.
Definition leaf_model.h:834
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:904
void SubtractSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:887
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:858
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:910
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:851
void AddSuffStatInplace(LogLinearVarianceSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:867
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:896
void AddSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:877
Definition normal_sampler.h:24
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Decision tree data structure.
Definition tree.h:66
Definition normal_sampler.h:12
static SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim=0)
Factory function that creates a new SuffStat object for the specified model type.
Definition leaf_model.h:1014
std::variant< GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat > SuffStatVariant
Unifying layer for disparate sufficient statistic class types.
Definition leaf_model.h:983
ModelType
Leaf models for the forest sampler:
Definition leaf_model.h:351
static LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd &Sigma0, double a, double b)
Factory function that creates a new LeafModel object for the specified model type.
Definition leaf_model.h:1035
std::variant< GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel > LeafModelVariant
Unifying layer for disparate leaf model class types.
Definition leaf_model.h:996
Definition category_tracker.h:36