5#ifndef STOCHTREE_LEAF_MODEL_H_
6#define STOCHTREE_LEAF_MODEL_H_
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/gamma_sampler.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/meta.h>
15#include <stochtree/normal_sampler.h>
16#include <stochtree/openmp_utils.h>
17#include <stochtree/partition_tracker.h>
18#include <stochtree/prior.h>
19#include <stochtree/tree.h>
353 kConstantLeafGaussian,
354 kUnivariateRegressionLeafGaussian,
355 kMultivariateRegressionLeafGaussian,
390 sum_yw += outcome(row_idx, 0);
408 sum_w += suff_stat.sum_w;
409 sum_yw += suff_stat.sum_yw;
419 sum_w = lhs.sum_w + rhs.sum_w;
420 sum_yw = lhs.sum_yw + rhs.sum_yw;
430 sum_w = lhs.sum_w - rhs.sum_w;
431 sum_yw = lhs.sum_yw - rhs.sum_yw;
439 return n > threshold;
447 return n >= threshold;
554 sum_yxw += outcome(row_idx, 0)*dataset.
BasisValue(row_idx, 0);
572 sum_xxw += suff_stat.sum_xxw;
573 sum_yxw += suff_stat.sum_yxw;
583 sum_xxw = lhs.sum_xxw + rhs.sum_xxw;
584 sum_yxw = lhs.sum_yxw + rhs.sum_yxw;
594 sum_xxw = lhs.sum_xxw - rhs.sum_xxw;
595 sum_yxw = lhs.sum_yxw - rhs.sum_yxw;
603 return n > threshold;
611 return n >= threshold;
668 void SetScale(
double tau) {tau_ = tau;}
669 inline bool RequiresBasis() {
return true;}
672 UnivariateNormalSampler normal_sampler_;
680 Eigen::MatrixXd XtWX;
681 Eigen::MatrixXd ytWX;
689 XtWX = Eigen::MatrixXd::Zero(basis_dim, basis_dim);
690 ytWX = Eigen::MatrixXd::Zero(1, basis_dim);
708 XtWX += dataset.
GetBasis()(row_idx, Eigen::all).transpose()*dataset.
GetBasis()(row_idx, Eigen::all);
709 ytWX += (outcome(row_idx, 0)*(dataset.
GetBasis()(row_idx, Eigen::all)));
717 XtWX = Eigen::MatrixXd::Zero(p, p);
718 ytWX = Eigen::MatrixXd::Zero(1, p);
727 XtWX += suff_stat.XtWX;
728 ytWX += suff_stat.ytWX;
738 XtWX = lhs.XtWX + rhs.XtWX;
739 ytWX = lhs.ytWX + rhs.ytWX;
749 XtWX = lhs.XtWX - rhs.XtWX;
750 ytWX = lhs.ytWX - rhs.ytWX;
758 return n > threshold;
766 return n >= threshold;
828 void SetScale(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0;}
829 inline bool RequiresBasis() {
return true;}
831 Eigen::MatrixXd Sigma_0_;
832 MultivariateNormalSampler multivariate_normal_sampler_;
839 double weighted_sum_ei;
842 weighted_sum_ei = 0.0;
855 weighted_sum_ei += std::exp(std::log(outcome(row_idx)*outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx));
862 weighted_sum_ei = 0.0;
871 weighted_sum_ei += suff_stat.weighted_sum_ei;
881 weighted_sum_ei = lhs.weighted_sum_ei + rhs.weighted_sum_ei;
891 weighted_sum_ei = lhs.weighted_sum_ei - rhs.weighted_sum_ei;
899 return n > threshold;
907 return n >= threshold;
965 void SetPriorShape(
double a) {a_ = a;}
966 void SetPriorRate(
double b) {b_ = b;}
967 inline bool RequiresBasis() {
return false;}
971 GammaSampler gamma_sampler_;
1004 unsigned int y =
static_cast<unsigned int>(outcome(row_idx));
1007 double Z = dataset.GetAuxiliaryDataValue(0, row_idx);
1008 double lambda_minus = dataset.GetAuxiliaryDataValue(1, row_idx);
1011 const std::vector<double>& gamma = dataset.GetAuxiliaryDataVectorConst(2);
1012 const std::vector<double>& seg = dataset.GetAuxiliaryDataVectorConst(3);
1014 int K = gamma.size() + 1;
1017 other_sum += std::exp(lambda_minus) * seg[y];
1019 sum_Y_less_K += 1.0;
1020 other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]);
1040 sum_Y_less_K += suff_stat.sum_Y_less_K;
1041 other_sum += suff_stat.other_sum;
1052 sum_Y_less_K = lhs.sum_Y_less_K + rhs.sum_Y_less_K;
1053 other_sum = lhs.other_sum + rhs.other_sum;
1064 sum_Y_less_K = lhs.sum_Y_less_K - rhs.sum_Y_less_K;
1065 other_sum = lhs.other_sum - rhs.other_sum;
1074 return n > threshold;
1083 return n >= threshold;
1141 inline bool RequiresBasis() {
return false;}
1176template<
typename SuffStatType,
typename... SuffStatConstructorArgs>
1177static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) {
1178 return SuffStatType(leaf_suff_stat_args...);
1181template<
typename LeafModelType,
typename... LeafModelConstructorArgs>
1182static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_model_args) {
1183 return LeafModelType(leaf_model_args...);
1193 if (model_type == kConstantLeafGaussian) {
1194 return createSuffStat<GaussianConstantSuffStat>();
1195 }
else if (model_type == kUnivariateRegressionLeafGaussian) {
1196 return createSuffStat<GaussianUnivariateRegressionSuffStat>();
1197 }
else if (model_type == kMultivariateRegressionLeafGaussian) {
1198 return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
1199 }
else if (model_type == kLogLinearVariance) {
1200 return createSuffStat<LogLinearVarianceSuffStat>();
1202 return createSuffStat<CloglogOrdinalSuffStat>();
1216 if (model_type == kConstantLeafGaussian) {
1217 return createLeafModel<GaussianConstantLeafModel, double>(tau);
1218 }
else if (model_type == kUnivariateRegressionLeafGaussian) {
1219 return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
1220 }
else if (model_type == kMultivariateRegressionLeafGaussian) {
1221 return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
1222 }
else if (model_type == kLogLinearVariance) {
1223 return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
1225 return createLeafModel<CloglogOrdinalLeafModel, double, double>(a, b);
1229template<
typename SuffStatType,
typename... SuffStatConstructorArgs>
1230static inline void AccumulateSuffStatProposed(
1231 SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1232 ColumnVector& residual,
double global_variance, TreeSplit& split,
int tree_num,
int leaf_num,
int split_feature,
int num_threads,
1233 SuffStatConstructorArgs&... suff_stat_args
1236 int node_begin_index = tracker.UnsortedNodeBegin(tree_num, leaf_num);
1237 int node_end_index = tracker.UnsortedNodeEnd(tree_num, leaf_num);
1240 UnsortedNodeSampleTracker* unsorted_node_sample_tracker = tracker.GetUnsortedNodeSampleTracker();
1241 FeatureUnsortedPartition* feature_partition = unsorted_node_sample_tracker->GetFeaturePartition(tree_num);
1244 int chunk_size = (node_end_index - node_begin_index) / num_threads;
1245 if (chunk_size < 100) {
1247 chunk_size = node_end_index - node_begin_index;
1250 if (num_threads > 1) {
1252 std::vector<std::pair<int, int>> thread_ranges(num_threads);
1253 std::vector<SuffStatType> thread_suff_stats_node;
1254 std::vector<SuffStatType> thread_suff_stats_left;
1255 std::vector<SuffStatType> thread_suff_stats_right;
1256 for (
int i = 0; i < num_threads; i++) {
1257 thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size,
1258 node_begin_index + (i + 1) * chunk_size);
1259 thread_suff_stats_node.emplace_back(suff_stat_args...);
1260 thread_suff_stats_left.emplace_back(suff_stat_args...);
1261 thread_suff_stats_right.emplace_back(suff_stat_args...);
1265 StochTree::ParallelFor(0, num_threads, num_threads, [&](
int i) {
1266 int start_idx = thread_ranges[i].first;
1267 int end_idx = thread_ranges[i].second;
1268 for (
int idx = start_idx; idx < end_idx; idx++) {
1269 int obs_num = feature_partition->indices_[idx];
1270 double feature_value = dataset.CovariateValue(obs_num, split_feature);
1271 thread_suff_stats_node[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1272 if (split.SplitTrue(feature_value)) {
1273 thread_suff_stats_left[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1275 thread_suff_stats_right[i].IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1281 for (
int i = 0; i < num_threads; i++) {
1282 node_suff_stat.AddSuffStatInplace(thread_suff_stats_node[i]);
1283 left_suff_stat.AddSuffStatInplace(thread_suff_stats_left[i]);
1284 right_suff_stat.AddSuffStatInplace(thread_suff_stats_right[i]);
1287 for (
int idx = node_begin_index; idx < node_end_index; idx++) {
1288 int obs_num = feature_partition->indices_[idx];
1289 double feature_value = dataset.CovariateValue(obs_num, split_feature);
1290 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1291 if (split.SplitTrue(feature_value)) {
1292 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1294 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, obs_num, tree_num);
1300template<
typename SuffStatType>
1301static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
1302 ColumnVector& residual,
double global_variance,
int tree_num,
int split_node_id,
int left_node_id,
int right_node_id) {
1304 auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
1305 auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id);
1306 auto right_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, right_node_id);
1307 auto right_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, right_node_id);
1310 for (
auto i = left_node_begin_iter; i != left_node_end_iter; i++) {
1312 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1313 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1317 for (
auto i = right_node_begin_iter; i != right_node_end_iter; i++) {
1319 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1320 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1324template<
typename SuffStatType,
bool sorted>
1325static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual,
int tree_num,
int node_id) {
1327 std::vector<data_size_t>::iterator node_begin_iter;
1328 std::vector<data_size_t>::iterator node_end_iter;
1331 node_begin_iter = tracker.SortedNodeBeginIterator(node_id, 0);
1332 node_end_iter = tracker.SortedNodeEndIterator(node_id, 0);
1334 node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
1335 node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
1339 for (
auto i = node_begin_iter; i != node_end_iter; i++) {
1341 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
1345template<
typename SuffStatType>
1346static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
1347 ForestDataset& dataset, ColumnVector& residual,
double global_variance,
int tree_num,
int node_id,
1348 int feature_num,
int cutpoint_num) {
1350 auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
1351 auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
1354 data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
1357 data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num, feature_num);
1358 data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_num, feature_num);
1359 data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num + 1, feature_num);
1363 auto cutpoint_begin_iter = node_begin_iter + (current_bin_begin - node_begin);
1364 auto cutpoint_end_iter = node_begin_iter + (next_bin_begin - node_begin);
1367 for (
auto i = cutpoint_begin_iter; i != cutpoint_end_iter; i++) {
1369 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
Marginal likelihood and posterior computation for complementary log-log ordinal BART model.
Definition leaf_model.h:1095
CloglogOrdinalLeafModel(double a, double b)
Construct a new CloglogOrdinalLeafModel object.
Definition leaf_model.h:1104
double SplitLogMarginalLikelihood(CloglogOrdinalSuffStat &left_stat, CloglogOrdinalSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Helper function to compute log marginal likelihood from sufficient statistics.
double PosteriorParameterShape(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Posterior shape parameter for leaf node log-gamma distribution.
double PosteriorParameterRate(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Posterior rate parameter for leaf node log-gamma distribution.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for complementary log-log ordinal BART model.
Definition leaf_model.h:976
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:1073
void AddSuffStat(CloglogOrdinalSuffStat &lhs, CloglogOrdinalSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:1050
CloglogOrdinalSuffStat()
Construct a new CloglogOrdinalSuffStat object, setting all sufficient statistics to zero.
Definition leaf_model.h:985
void SubtractSuffStat(CloglogOrdinalSuffStat &lhs, CloglogOrdinalSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:1062
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:1027
void AddSuffStatInplace(CloglogOrdinalSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:1038
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:1089
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:1082
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:1000
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:193
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:271
double BasisValue(data_size_t row, int col)
Returns a dataset's basis value stored at (row, col)
Definition data.h:371
Eigen::MatrixXd & GetBasis()
Return a reference to the raw Eigen::MatrixXd storing the basis data.
Definition data.h:389
double VarWeightValue(data_size_t row)
Returns a dataset's variance weight stored at element row
Definition data.h:377
bool HasVarWeights()
Whether or not a ForestDataset has (yet) loaded variance weights.
Definition data.h:351
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:46
Definition gamma_sampler.h:10
Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model.
Definition leaf_model.h:458
double SplitLogMarginalLikelihood(GaussianConstantSuffStat &left_stat, GaussianConstantSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterVariance(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
bool RequiresBasis()
Whether this model requires a basis vector for posterior inference and prediction.
Definition leaf_model.h:518
GaussianConstantLeafModel(double tau)
Construct a new GaussianConstantLeafModel object.
Definition leaf_model.h:465
double PosteriorParameterMean(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
void SetScale(double tau)
Set a new value for the leaf node scale parameter.
Definition leaf_model.h:514
Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model...
Definition leaf_model.h:361
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:452
GaussianConstantSuffStat()
Construct a new GaussianConstantSuffStat object, setting all sufficient statistics to zero.
Definition leaf_model.h:369
void SubtractSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:428
void AddSuffStat(GaussianConstantSuffStat &lhs, GaussianConstantSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:417
void AddSuffStatInplace(GaussianConstantSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:406
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:446
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:438
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:383
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:396
Marginal likelihood and posterior computation for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:777
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &left_stat, GaussianMultivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd &Sigma_0)
Construct a new GaussianMultivariateRegressionLeafModel object.
Definition leaf_model.h:784
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:676
void AddSuffStatInplace(GaussianMultivariateRegressionSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:725
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:757
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:771
void AddSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:736
void SubtractSuffStat(GaussianMultivariateRegressionSuffStat &lhs, GaussianMultivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:747
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:765
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:715
GaussianMultivariateRegressionSuffStat(int basis_dim)
Construct a new GaussianMultivariateRegressionSuffStat object.
Definition leaf_model.h:687
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:702
Marginal likelihood and posterior computation for gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:622
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &left_stat, GaussianUnivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:525
GaussianUnivariateRegressionSuffStat()
Construct a new GaussianUnivariateRegressionSuffStat object, setting all sufficient statistics to zer...
Definition leaf_model.h:533
void SubtractSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:592
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:602
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:610
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:616
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:560
void AddSuffStat(GaussianUnivariateRegressionSuffStat &lhs, GaussianUnivariateRegressionSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:581
void AddSuffStatInplace(GaussianUnivariateRegressionSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:570
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:547
Marginal likelihood and posterior computation for heteroskedastic log-linear variance model.
Definition leaf_model.h:918
double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
double PosteriorParameterShape(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior shape parameter.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat &left_stat, LogLinearVarianceSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterScale(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior scale parameter.
Sufficient statistic and associated operations for heteroskedastic log-linear variance model.
Definition leaf_model.h:836
bool SampleGreaterThanEqual(data_size_t threshold)
Check whether accumulated sample size, n, is greater than or equal to some threshold.
Definition leaf_model.h:906
void SubtractSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the difference between the values provided by lhs and t...
Definition leaf_model.h:889
void ResetSuffStat()
Reset all of the sufficient statistics to zero.
Definition leaf_model.h:860
data_size_t SampleSize()
Return the sample size accumulated by a sufficient stat object.
Definition leaf_model.h:912
void IncrementSuffStat(ForestDataset &dataset, Eigen::VectorXd &outcome, ForestTracker &tracker, data_size_t row_idx, int tree_idx)
Accumulate data from observation row_idx into the sufficient statistics.
Definition leaf_model.h:853
void AddSuffStatInplace(LogLinearVarianceSuffStat &suff_stat)
Increment the value of each sufficient statistic by the values provided by suff_stat
Definition leaf_model.h:869
bool SampleGreaterThan(data_size_t threshold)
Check whether accumulated sample size, n, is greater than some threshold.
Definition leaf_model.h:898
void AddSuffStat(LogLinearVarianceSuffStat &lhs, LogLinearVarianceSuffStat &rhs)
Set the value of each sufficient statistic to the sum of the values provided by lhs and rhs
Definition leaf_model.h:879
Definition normal_sampler.h:25
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Decision tree data structure.
Definition tree.h:66
Definition normal_sampler.h:13
static SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim=0)
Factory function that creates a new SuffStat object for the specified model type.
Definition leaf_model.h:1192
std::variant< GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, CloglogOrdinalSuffStat > SuffStatVariant
Unifying layer for disparate sufficient statistic class types.
Definition leaf_model.h:1160
ModelType
Leaf models for the forest sampler:
Definition leaf_model.h:352
static LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd &Sigma0, double a, double b)
Factory function that creates a new LeafModel object for the specified model type.
Definition leaf_model.h:1215
std::variant< GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, CloglogOrdinalLeafModel > LeafModelVariant
Unifying layer for disparate leaf model class types.
Definition leaf_model.h:1174
A collection of random number generation utilities.
Definition category_tracker.h:36