StochTree 0.4.3.9000
Loading...
Searching...
No Matches
leaf_model.h
1
5#ifndef STOCHTREE_LEAF_MODEL_H_
6#define STOCHTREE_LEAF_MODEL_H_
7
8#include <Eigen/Dense>
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/gamma_sampler.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/meta.h>
15#include <stochtree/normal_sampler.h>
16#include <stochtree/openmp_utils.h>
17#include <stochtree/partition_tracker.h>
18#include <stochtree/prior.h>
19#include <stochtree/tree.h>
20
21#include <random>
22#include <variant>
23
24namespace StochTree {
25
353 kConstantLeafGaussian,
354 kUnivariateRegressionLeafGaussian,
355 kMultivariateRegressionLeafGaussian,
356 kLogLinearVariance,
357 kCloglogOrdinal
358};
359
362 public:
363 data_size_t n;
364 double sum_w;
365 double sum_yw;
370 n = 0;
371 sum_w = 0.0;
372 sum_yw = 0.0;
373 }
383 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
384 n += 1;
385 if (dataset.HasVarWeights()) {
386 sum_w += 1/dataset.VarWeightValue(row_idx);
387 sum_yw += outcome(row_idx, 0)/dataset.VarWeightValue(row_idx);
388 } else {
389 sum_w += 1.0;
390 sum_yw += outcome(row_idx, 0);
391 }
392 }
397 n = 0;
398 sum_w = 0.0;
399 sum_yw = 0.0;
400 }
407 n += suff_stat.n;
408 sum_w += suff_stat.sum_w;
409 sum_yw += suff_stat.sum_yw;
410 }
418 n = lhs.n + rhs.n;
419 sum_w = lhs.sum_w + rhs.sum_w;
420 sum_yw = lhs.sum_yw + rhs.sum_yw;
421 }
429 n = lhs.n - rhs.n;
430 sum_w = lhs.sum_w - rhs.sum_w;
431 sum_yw = lhs.sum_yw - rhs.sum_yw;
432 }
438 bool SampleGreaterThan(data_size_t threshold) {
439 return n > threshold;
440 }
446 bool SampleGreaterThanEqual(data_size_t threshold) {
447 return n >= threshold;
448 }
452 data_size_t SampleSize() {
453 return n;
454 }
455};
456
459 public:
465 GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
474 double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
481 double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
488 double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
495 double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance);
507 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
508 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
514 void SetScale(double tau) {tau_ = tau;}
518 inline bool RequiresBasis() {return false;}
519 private:
520 double tau_;
521 UnivariateNormalSampler normal_sampler_;
522};
523
526 public:
527 data_size_t n;
528 double sum_xxw;
529 double sum_yxw;
534 n = 0;
535 sum_xxw = 0.0;
536 sum_yxw = 0.0;
537 }
547 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
548 n += 1;
549 if (dataset.HasVarWeights()) {
550 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
551 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
552 } else {
553 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0);
554 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0);
555 }
556 }
561 n = 0;
562 sum_xxw = 0.0;
563 sum_yxw = 0.0;
564 }
571 n += suff_stat.n;
572 sum_xxw += suff_stat.sum_xxw;
573 sum_yxw += suff_stat.sum_yxw;
574 }
582 n = lhs.n + rhs.n;
583 sum_xxw = lhs.sum_xxw + rhs.sum_xxw;
584 sum_yxw = lhs.sum_yxw + rhs.sum_yxw;
585 }
593 n = lhs.n - rhs.n;
594 sum_xxw = lhs.sum_xxw - rhs.sum_xxw;
595 sum_yxw = lhs.sum_yxw - rhs.sum_yxw;
596 }
602 bool SampleGreaterThan(data_size_t threshold) {
603 return n > threshold;
604 }
610 bool SampleGreaterThanEqual(data_size_t threshold) {
611 return n >= threshold;
612 }
616 data_size_t SampleSize() {
617 return n;
618 }
619};
620
623 public:
624 GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
640 double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
647 double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
654 double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
666 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
667 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
668 void SetScale(double tau) {tau_ = tau;}
669 inline bool RequiresBasis() {return true;}
670 private:
671 double tau_;
672 UnivariateNormalSampler normal_sampler_;
673};
674
677 public:
678 data_size_t n;
679 int p;
680 Eigen::MatrixXd XtWX;
681 Eigen::MatrixXd ytWX;
688 n = 0;
689 XtWX = Eigen::MatrixXd::Zero(basis_dim, basis_dim);
690 ytWX = Eigen::MatrixXd::Zero(1, basis_dim);
691 p = basis_dim;
692 }
702 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
703 n += 1;
704 if (dataset.HasVarWeights()) {
705 for (int i = 0; i < p; i++) {
706 ytWX(0,i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i) / dataset.VarWeightValue(row_idx);
707 for (int j = 0; j < p; j++) {
708 XtWX(i,j) += dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j) / dataset.VarWeightValue(row_idx);
709 }
710 }
711 } else {
712 for (int i = 0; i < p; i++) {
713 ytWX(0,i) += outcome(row_idx, 0) * dataset.BasisValue(row_idx, i);
714 for (int j = 0; j < p; j++) {
715 XtWX(i,j) += dataset.BasisValue(row_idx, i) * dataset.BasisValue(row_idx, j);
716 }
717 }
718 }
719 }
724 n = 0;
725 for (int i = 0; i < p; i++) {
726 ytWX(0, i) = 0.0;
727 for (int j = 0; j < p; j++) {
728 XtWX(i, j) = 0.0;
729 }
730 }
731 }
738 n += suff_stat.n;
739 XtWX += suff_stat.XtWX;
740 ytWX += suff_stat.ytWX;
741 }
749 n = lhs.n + rhs.n;
750 XtWX = lhs.XtWX + rhs.XtWX;
751 ytWX = lhs.ytWX + rhs.ytWX;
752 }
760 n = lhs.n - rhs.n;
761 XtWX = lhs.XtWX - rhs.XtWX;
762 ytWX = lhs.ytWX - rhs.ytWX;
763 }
769 bool SampleGreaterThan(data_size_t threshold) {
770 return n > threshold;
771 }
777 bool SampleGreaterThanEqual(data_size_t threshold) {
778 return n >= threshold;
779 }
783 data_size_t SampleSize() {
784 return n;
785 }
786};
787
790 public:
796 GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();}
819 Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
826 Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
838 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
839 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
840 void SetScale(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0;}
841 inline bool RequiresBasis() {return true;}
842 private:
843 Eigen::MatrixXd Sigma_0_;
844 MultivariateNormalSampler multivariate_normal_sampler_;
845};
846
849 public:
850 data_size_t n;
851 double weighted_sum_ei;
853 n = 0;
854 weighted_sum_ei = 0.0;
855 }
865 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
866 n += 1;
867 weighted_sum_ei += std::exp(std::log(outcome(row_idx)*outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx));
868 }
873 n = 0;
874 weighted_sum_ei = 0.0;
875 }
882 n += suff_stat.n;
883 weighted_sum_ei += suff_stat.weighted_sum_ei;
884 }
892 n = lhs.n + rhs.n;
893 weighted_sum_ei = lhs.weighted_sum_ei + rhs.weighted_sum_ei;
894 }
902 n = lhs.n - rhs.n;
903 weighted_sum_ei = lhs.weighted_sum_ei - rhs.weighted_sum_ei;
904 }
910 bool SampleGreaterThan(data_size_t threshold) {
911 return n > threshold;
912 }
918 bool SampleGreaterThanEqual(data_size_t threshold) {
919 return n >= threshold;
920 }
924 data_size_t SampleSize() {
925 return n;
926 }
927};
928
931 public:
932 LogLinearVarianceLeafModel(double a, double b) {a_ = a; b_ = b; gamma_sampler_ = GammaSampler();}
941 double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance);
948 double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
949 double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
956 double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance);
963 double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance);
975 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
976 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
977 void SetPriorShape(double a) {a_ = a;}
978 void SetPriorRate(double b) {b_ = b;}
979 inline bool RequiresBasis() {return false;}
980 private:
981 double a_;
982 double b_;
983 GammaSampler gamma_sampler_;
984};
985
986
989 public:
990 data_size_t n;
991 double sum_Y_less_K;
992 double other_sum;
993
998 n = 0;
999 sum_Y_less_K = 0.0;
1000 other_sum = 0.0;
1001 }
1002
1012 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
1013 n += 1;
1014
1015 // Get ordinal outcome value for this observation
1016 unsigned int y = static_cast<unsigned int>(outcome(row_idx));
1017
1018 // Get auxiliary data from tracker (assuming types: 0=latents Z, 1=forest predictions, 2=cutpoints gamma, 3=cumsum exp of gamma)
1019 double Z = dataset.GetAuxiliaryDataValue(0, row_idx); // latent variables Z
1020 double lambda_minus = dataset.GetAuxiliaryDataValue(1, row_idx); // forest predictions excluding current tree
1021
1022 // Get cutpoints gamma and cumulative sum of exp(gamma)
1023 const std::vector<double>& gamma = dataset.GetAuxiliaryDataVectorConst(2); // cutpoints gamma
1024 const std::vector<double>& seg = dataset.GetAuxiliaryDataVectorConst(3); // cumsum exp of gamma
1025
1026 int K = gamma.size() + 1; // Number of ordinal categories
1027
1028 if (y == K - 1) {
1029 other_sum += std::exp(lambda_minus) * seg[y]; // checked and it's correct
1030 } else {
1031 sum_Y_less_K += 1.0;
1032 other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]); // checked and it's correct
1033 }
1034 }
1035
1040 n = 0;
1041 sum_Y_less_K = 0.0;
1042 other_sum = 0.0;
1043 }
1044
1051 n += suff_stat.n;
1052 sum_Y_less_K += suff_stat.sum_Y_less_K;
1053 other_sum += suff_stat.other_sum;
1054 }
1055
1063 n = lhs.n + rhs.n;
1064 sum_Y_less_K = lhs.sum_Y_less_K + rhs.sum_Y_less_K;
1065 other_sum = lhs.other_sum + rhs.other_sum;
1066 }
1067
1075 n = lhs.n - rhs.n;
1076 sum_Y_less_K = lhs.sum_Y_less_K - rhs.sum_Y_less_K;
1077 other_sum = lhs.other_sum - rhs.other_sum;
1078 }
1079
1085 bool SampleGreaterThan(data_size_t threshold) {
1086 return n > threshold;
1087 }
1088
1094 bool SampleGreaterThanEqual(data_size_t threshold) {
1095 return n >= threshold;
1096 }
1097
1101 data_size_t SampleSize() {
1102 return n;
1103 }
1104};
1105
1108 public:
1116 CloglogOrdinalLeafModel(double a, double b) {
1117 a_ = a;
1118 b_ = b;
1119 gamma_sampler_ = GammaSampler();
1120 }
1122
1126 double SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance);
1127
1131 double NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance);
1132
1136 double SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance);
1137
1141 double PosteriorParameterShape(CloglogOrdinalSuffStat& suff_stat, double global_variance);
1142
1146 double PosteriorParameterRate(CloglogOrdinalSuffStat& suff_stat, double global_variance);
1147
1152 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
1153 inline bool RequiresBasis() {return false;}
1154
1155 private:
1156 double a_;
1157 double b_;
1158 GammaSampler gamma_sampler_;
1159};
1160
1173
1187
1188template<typename SuffStatType, typename... SuffStatConstructorArgs>
1189static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) {
1190 return SuffStatType(leaf_suff_stat_args...);
1191}
1192
1193template<typename LeafModelType, typename... LeafModelConstructorArgs>
1194static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_model_args) {
1195 return LeafModelType(leaf_model_args...);
1196}
1197
1204static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
1205 if (model_type == kConstantLeafGaussian) {
1206 return createSuffStat<GaussianConstantSuffStat>();
1207 } else if (model_type == kUnivariateRegressionLeafGaussian) {
1208 return createSuffStat<GaussianUnivariateRegressionSuffStat>();
1209 } else if (model_type == kMultivariateRegressionLeafGaussian) {
1210 return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
1211 } else if (model_type == kLogLinearVariance) {
1212 return createSuffStat<LogLinearVarianceSuffStat>();
1213 } else {
1214 return createSuffStat<CloglogOrdinalSuffStat>();
1215 }
1216}
1217
1227static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
1228 if (model_type == kConstantLeafGaussian) {
1229 return createLeafModel<GaussianConstantLeafModel, double>(tau);
1230 } else if (model_type == kUnivariateRegressionLeafGaussian) {
1231 return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
1232 } else if (model_type == kMultivariateRegressionLeafGaussian) {
1233 return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
1234 } else if (model_type == kLogLinearVariance) {
1235 return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
1236 } else {
1237 return createLeafModel<CloglogOrdinalLeafModel, double, double>(a, b);
1238 }
1239}
1240
1241template<typename SuffStatType, typename... SuffStatConstructorArgs>
1242static inline void AccumulateSuffStatProposed(
1243 SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1244 ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads,
1245 SuffStatConstructorArgs&... suff_stat_args
1246) {
1247 // Determine the position of the node's indices in the forest tracking data structure
1248 int node_begin_index = tracker.UnsortedNodeBegin(tree_num, leaf_num);
1249 int node_end_index = tracker.UnsortedNodeEnd(tree_num, leaf_num);
1250
1251 // Extract pointer to the feature partition for tree_num
1252 UnsortedNodeSampleTracker* unsorted_node_sample_tracker = tracker.GetUnsortedNodeSampleTracker();
1253 FeatureUnsortedPartition* feature_partition = unsorted_node_sample_tracker->GetFeaturePartition(tree_num);
1254
1255 // Determine the number of threads to use
1256 int chunk_size = (node_end_index - node_begin_index) / num_threads;
1257 if (chunk_size < 100) {
1258 num_threads = 1;
1259 chunk_size = node_end_index - node_begin_index;
1260 }
1261
1262 if (num_threads > 1) {
1263 // Split the work into num_threads chunks
1264 std::vector<std::pair<int, int>> thread_ranges(num_threads);
1265 std::vector<SuffStatType> thread_suff_stats_node;
1266 std::vector<SuffStatType> thread_suff_stats_left;
1267 std::vector<SuffStatType> thread_suff_stats_right;
1268 for (int i = 0; i < num_threads; i++) {
1269 thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size,
1270 node_begin_index + (i + 1) * chunk_size);
1271 thread_suff_stats_node.emplace_back(suff_stat_args...);
1272 thread_suff_stats_left.emplace_back(suff_stat_args...);
1273 thread_suff_stats_right.emplace_back(suff_stat_args...);
1274 }
1275
1276 // Accumulate sufficient statistics
1277 StochTree::ParallelFor(0, num_threads, num_threads, [&](int i) {
1278 int start_idx = thread_ranges[i].first;
1279 int end_idx = thread_ranges[i].second;
1280 for (int idx = start_idx; idx < end_idx; idx++) {
1281 int obs_num = feature_partition->indices_[idx];
1282 double feature_value = dataset.CovariateValue(obs_num, split_feature);
1283 thread_suff_stats_node[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1284 if (split.SplitTrue(feature_value)) {
1285 thread_suff_stats_left[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1286 } else {
1287 thread_suff_stats_right[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1288 }
1289 }
1290 });
1291
1292 // Combine the thread-local sufficient statistics
1293 for (int i = 0; i < num_threads; i++) {
1294 node_suff_stat.AddSuffStatInplace(thread_suff_stats_node[i]);
1295 left_suff_stat.AddSuffStatInplace(thread_suff_stats_left[i]);
1296 right_suff_stat.AddSuffStatInplace(thread_suff_stats_right[i]);
1297 }
1298 } else {
1299 for (int idx = node_begin_index; idx < node_end_index; idx++) {
1300 int obs_num = feature_partition->indices_[idx];
1301 double feature_value = dataset.CovariateValue(obs_num, split_feature);
1302 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1303 if (split.SplitTrue(feature_value)) {
1304 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1305 } else {
1306 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1307 }
1308 }
1309 }
1310}
1311
1312template<typename SuffStatType>
1313static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1314 ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) {
1315 // Acquire iterators
1316 auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
1317 auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id);
1318 auto right_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, right_node_id);
1319 auto right_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, right_node_id);
1320
1321 // Accumulate sufficient statistics for the left and split nodes
1322 for (auto i = left_node_begin_iter; i != left_node_end_iter; i++) {
1323 auto idx = *i;
1324 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1325 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1326 }
1327
1328 // Accumulate sufficient statistics for the right and split nodes
1329 for (auto i = right_node_begin_iter; i != right_node_end_iter; i++) {
1330 auto idx = *i;
1331 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1332 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1333 }
1334}
1335
1336template<typename SuffStatType, bool sorted>
1337static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, int tree_num, int node_id) {
1338 // Acquire iterators
1339 std::vector<data_size_t>::iterator node_begin_iter;
1340 std::vector<data_size_t>::iterator node_end_iter;
1341 if (sorted) {
1342 // Default to the first feature if we're using the presort tracker
1343 node_begin_iter = tracker.SortedNodeBeginIterator(node_id, 0);
1344 node_end_iter = tracker.SortedNodeEndIterator(node_id, 0);
1345 } else {
1346 node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
1347 node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
1348 }
1349
1350 // Accumulate sufficient statistics
1351 for (auto i = node_begin_iter; i != node_end_iter; i++) {
1352 auto idx = *i;
1353 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1354 }
1355}
1356
1357template<typename SuffStatType>
1358static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
1359 ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id,
1360 int feature_num, int cutpoint_num) {
1361 // Acquire iterators
1362 auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
1363 auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
1364
1365 // Determine node start point
1366 data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
1367
1368 // Determine cutpoint bin start and end points
1369 data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num, feature_num);
1370 data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_num, feature_num);
1371 data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num + 1, feature_num);
1372
1373 // Cutpoint specific iterators
1374 // TODO: fix the hack of having to subtract off node_begin, probably by cleaning up the CutpointGridContainer interface
1375 auto cutpoint_begin_iter = node_begin_iter + (current_bin_begin - node_begin);
1376 auto cutpoint_end_iter = node_begin_iter + (next_bin_begin - node_begin);
1377
1378 // Accumulate sufficient statistics
1379 for (auto i = cutpoint_begin_iter; i != cutpoint_end_iter; i++) {
1380 auto idx = *i;
1381 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1382 }
1383}
1384
// end of leaf_model_group
1386
1387} // namespace StochTree
1388
1389#endif // STOCHTREE_LEAF_MODEL_H_
Marginal likelihood and posterior computation for complementary log-log ordinal BART model.
Definition leaf_model.h:1107
CloglogOrdinalLeafModel(double a, double b)
Construct a new CloglogOrdinalLeafModel object.
Definition leaf_model.h:1116
double SplitLogMarginalLikelihood(CloglogOrdinalSuffStat &left_stat, CloglogOrdinalSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Helper function to compute log marginal likelihood from sufficient statistics.
double PosteriorParameterShape(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Posterior shape parameter for leaf node log-gamma distribution.
double PosteriorParameterRate(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Posterior rate parameter for leaf node log-gamma distribution.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for complementary log-log ordinal BART model.
Definition leaf_model.h:988
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:1085
void AddSuffStat(CloglogOrdinalSuffStat &lhs, CloglogOrdinalSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:1062
CloglogOrdinalSuffStat()
Construct a new CloglogOrdinalSuffStat object, setting all sufficient statistics to zero.
Definition leaf_model.h:997
void SubtractSuffStat(CloglogOrdinalSuffStat &lhs, CloglogOrdinalSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:1074
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:1039
void AddSuffStatInplace(CloglogOrdinalSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:1050
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:1101
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:1094
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:1012
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:193
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:271
double BasisValue(data_size_t row, int col)
Returns a dataset's basis value stored at (row, col)
Definition data.h:371
double VarWeightValue(data_size_t row)
Returns a dataset's variance weight stored at element row
Definition data.h:377
bool HasVarWeights()
Whether or not a ForestDataset has (yet) loaded variance weights.
Definition data.h:351
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:46
Definition gamma_sampler.h:10
Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model.
Definition leaf_model.h:458
double SplitLogMarginalLikelihood(GaussianConstantSuffStat &left_stat, GaussianConstantSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterVariance(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
bool RequiresBasis()
Whether this model requires a basis vector for posterior inference and prediction.
Definition leaf_model.h:518
GaussianConstantLeafModel(double tau)
Construct a new GaussianConstantLeafModel object.
Definition leaf_model.h:465
double PosteriorParameterMean(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
void SetScale(double tau)
Set a new value for the leaf node scale parameter.
Definition leaf_model.h:514
Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model...
Definition leaf_model.h:361
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:452
GaussianConstantSuffStat()
Construct a new GaussianConstantSuffStat object, setting all sufficient statistics to zero.
Definition leaf_model.h:369
void SubtractSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:428
void AddSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:417
void AddSuffStatInplace(GaussianConstantSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:406
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:446
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:438
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:383
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:396
Marginal likelihood and posterior computation for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:789
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &left_stat, GaussianMultivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd &Sigma_0)
Construct a new GaussianMultivariateRegressionLeafModel object.
Definition leaf_model.h:796
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:676
void AddSuffStatInplace(GaussianMultivariateRegressionSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:737
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:769
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:783
void AddSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:748
void SubtractSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:759
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:777
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:723
GaussianMultivariateRegressionSuffStat(int basis_dim)
Construct a new GaussianMultivariateRegressionSuffStat object.
Definition leaf_model.h:687
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:702
Marginal likelihood and posterior computation for the gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:622
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &left_stat, GaussianUnivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being split.
Sufficient statistic and associated operations for the gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:525
GaussianUnivariateRegressionSuffStat()
Construct a new GaussianUnivariateRegressionSuffStat object, setting all sufficient statistics to zero.
Definition leaf_model.h:533
void SubtractSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and those provided by rhs.
Definition leaf_model.h:592
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:602
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:610
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:616
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:560
void AddSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:581
void AddSuffStatInplace(GaussianUnivariateRegressionSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:570
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:547
Marginal likelihood and posterior computation for heteroskedastic log-linear variance model.
Definition leaf_model.h:930
double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being split.
double PosteriorParameterShape(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior shape parameter.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat &left_stat, LogLinearVarianceSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split.
double PosteriorParameterScale(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior scale parameter.
Sufficient statistic and associated operations for heteroskedastic log-linear variance model.
Definition leaf_model.h:848
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:918
void SubtractSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and those provided by rhs.
Definition leaf_model.h:901
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:872
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:924
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:865
void AddSuffStatInplace(LogLinearVarianceSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:881
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:910
void AddSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:891
Definition normal_sampler.h:25
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Decision tree data structure.
Definition tree.h:66
Definition normal_sampler.h:13
static SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim=0)
Factory function that creates a new SuffStat object for the specified model type.
Definition leaf_model.h:1204
std::variant< GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, CloglogOrdinalSuffStat > SuffStatVariant
Unifying layer for disparate sufficient statistic class types.
Definition leaf_model.h:1172
ModelType
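Leaf models for the forest sampler: Gaussian constant leaf, Gaussian univariate regression leaf, Gaussian multivariate regression leaf, log-linear variance, and complementary log-log (cloglog) ordinal.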
Definition leaf_model.h:352
static LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd &Sigma0, double a, double b)
Factory function that creates a new LeafModel object for the specified model type.
Definition leaf_model.h:1227
std::variant< GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, CloglogOrdinalLeafModel > LeafModelVariant
Unifying layer for disparate leaf model class types.
Definition leaf_model.h:1186
A collection of random number generation utilities.
Definition category_tracker.h:36