StochTree 0.0.1
Loading...
Searching...
No Matches
leaf_model.h
1
5#ifndef STOCHTREE_LEAF_MODEL_H_
6#define STOCHTREE_LEAF_MODEL_H_
7
8#include <Eigen/Dense>
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/gamma_sampler.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/meta.h>
15#include <stochtree/normal_sampler.h>
16#include <stochtree/partition_tracker.h>
17#include <stochtree/prior.h>
18#include <stochtree/tree.h>
19
20#include <random>
21#include <tuple>
22#include <variant>
23
24namespace StochTree {
25
352 kConstantLeafGaussian,
353 kUnivariateRegressionLeafGaussian,
354 kMultivariateRegressionLeafGaussian,
355 kLogLinearVariance
356};
357
360 public:
361 data_size_t n;
362 double sum_w;
363 double sum_yw;
365 n = 0;
366 sum_w = 0.0;
367 sum_yw = 0.0;
368 }
369 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
370 n += 1;
371 if (dataset.HasVarWeights()) {
372 sum_w += 1/dataset.VarWeightValue(row_idx);
373 sum_yw += outcome(row_idx, 0)/dataset.VarWeightValue(row_idx);
374 } else {
375 sum_w += 1.0;
376 sum_yw += outcome(row_idx, 0);
377 }
378 }
379 void ResetSuffStat() {
380 n = 0;
381 sum_w = 0.0;
382 sum_yw = 0.0;
383 }
384 void AddSuffStat(GaussianConstantSuffStat& lhs, GaussianConstantSuffStat& rhs) {
385 n = lhs.n + rhs.n;
386 sum_w = lhs.sum_w + rhs.sum_w;
387 sum_yw = lhs.sum_yw + rhs.sum_yw;
388 }
389 void SubtractSuffStat(GaussianConstantSuffStat& lhs, GaussianConstantSuffStat& rhs) {
390 n = lhs.n - rhs.n;
391 sum_w = lhs.sum_w - rhs.sum_w;
392 sum_yw = lhs.sum_yw - rhs.sum_yw;
393 }
394 bool SampleGreaterThan(data_size_t threshold) {
395 return n > threshold;
396 }
397 bool SampleGreaterThanEqual(data_size_t threshold) {
398 return n >= threshold;
399 }
400 data_size_t SampleSize() {
401 return n;
402 }
403};
404
407 public:
408 GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
417 double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
424 double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
431 double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
438 double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance);
450 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
451 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
452 void SetScale(double tau) {tau_ = tau;}
453 inline bool RequiresBasis() {return false;}
454 private:
455 double tau_;
456 UnivariateNormalSampler normal_sampler_;
457};
458
461 public:
462 data_size_t n;
463 double sum_xxw;
464 double sum_yxw;
466 n = 0;
467 sum_xxw = 0.0;
468 sum_yxw = 0.0;
469 }
470 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
471 n += 1;
472 if (dataset.HasVarWeights()) {
473 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
474 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0)/dataset.VarWeightValue(row_idx);
475 } else {
476 sum_xxw += dataset.BasisValue(row_idx, 0)*dataset.BasisValue(row_idx, 0);
477 sum_yxw += outcome(row_idx, 0)*dataset.BasisValue(row_idx, 0);
478 }
479 }
480 void ResetSuffStat() {
481 n = 0;
482 sum_xxw = 0.0;
483 sum_yxw = 0.0;
484 }
486 n = lhs.n + rhs.n;
487 sum_xxw = lhs.sum_xxw + rhs.sum_xxw;
488 sum_yxw = lhs.sum_yxw + rhs.sum_yxw;
489 }
491 n = lhs.n - rhs.n;
492 sum_xxw = lhs.sum_xxw - rhs.sum_xxw;
493 sum_yxw = lhs.sum_yxw - rhs.sum_yxw;
494 }
495 bool SampleGreaterThan(data_size_t threshold) {
496 return n > threshold;
497 }
498 bool SampleGreaterThanEqual(data_size_t threshold) {
499 return n >= threshold;
500 }
501 data_size_t SampleSize() {
502 return n;
503 }
504};
505
508 public:
509 GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
525 double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
532 double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
539 double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
551 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
552 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
553 void SetScale(double tau) {tau_ = tau;}
554 inline bool RequiresBasis() {return true;}
555 private:
556 double tau_;
557 UnivariateNormalSampler normal_sampler_;
558};
559
562 public:
563 data_size_t n;
564 int p;
565 Eigen::MatrixXd XtWX;
566 Eigen::MatrixXd ytWX;
568 n = 0;
569 XtWX = Eigen::MatrixXd::Zero(basis_dim, basis_dim);
570 ytWX = Eigen::MatrixXd::Zero(1, basis_dim);
571 p = basis_dim;
572 }
573 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
574 n += 1;
575 if (dataset.HasVarWeights()) {
576 XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all)/dataset.VarWeightValue(row_idx);
577 ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all)))/dataset.VarWeightValue(row_idx);
578 } else {
579 XtWX += dataset.GetBasis()(row_idx, Eigen::all).transpose()*dataset.GetBasis()(row_idx, Eigen::all);
580 ytWX += (outcome(row_idx, 0)*(dataset.GetBasis()(row_idx, Eigen::all)));
581 }
582 }
583 void ResetSuffStat() {
584 n = 0;
585 XtWX = Eigen::MatrixXd::Zero(p, p);
586 ytWX = Eigen::MatrixXd::Zero(1, p);
587 }
589 n = lhs.n + rhs.n;
590 XtWX = lhs.XtWX + rhs.XtWX;
591 ytWX = lhs.ytWX + rhs.ytWX;
592 }
594 n = lhs.n - rhs.n;
595 XtWX = lhs.XtWX - rhs.XtWX;
596 ytWX = lhs.ytWX - rhs.ytWX;
597 }
598 bool SampleGreaterThan(data_size_t threshold) {
599 return n > threshold;
600 }
601 bool SampleGreaterThanEqual(data_size_t threshold) {
602 return n >= threshold;
603 }
604 data_size_t SampleSize() {
605 return n;
606 }
607};
608
611 public:
617 GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();}
640 Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
647 Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
659 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
660 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
661 void SetScale(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0;}
662 inline bool RequiresBasis() {return true;}
663 private:
664 Eigen::MatrixXd Sigma_0_;
665 MultivariateNormalSampler multivariate_normal_sampler_;
666};
667
670 public:
671 data_size_t n;
672 double weighted_sum_ei;
674 n = 0;
675 weighted_sum_ei = 0.0;
676 }
677 void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
678 n += 1;
679 weighted_sum_ei += std::exp(std::log(outcome(row_idx)*outcome(row_idx)) - tracker.GetSamplePrediction(row_idx) + tracker.GetTreeSamplePrediction(row_idx, tree_idx));
680 }
681 void ResetSuffStat() {
682 n = 0;
683 weighted_sum_ei = 0.0;
684 }
685 void AddSuffStat(LogLinearVarianceSuffStat& lhs, LogLinearVarianceSuffStat& rhs) {
686 n = lhs.n + rhs.n;
687 weighted_sum_ei = lhs.weighted_sum_ei + rhs.weighted_sum_ei;
688 }
689 void SubtractSuffStat(LogLinearVarianceSuffStat& lhs, LogLinearVarianceSuffStat& rhs) {
690 n = lhs.n - rhs.n;
691 weighted_sum_ei = lhs.weighted_sum_ei - rhs.weighted_sum_ei;
692 }
693 bool SampleGreaterThan(data_size_t threshold) {
694 return n > threshold;
695 }
696 bool SampleGreaterThanEqual(data_size_t threshold) {
697 return n >= threshold;
698 }
699 data_size_t SampleSize() {
700 return n;
701 }
702};
703
706 public:
707 LogLinearVarianceLeafModel(double a, double b) {a_ = a; b_ = b; gamma_sampler_ = GammaSampler();}
716 double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance);
723 double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
724 double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance);
731 double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance);
738 double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance);
750 void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
751 void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
752 void SetPriorShape(double a) {a_ = a;}
753 void SetPriorRate(double b) {b_ = b;}
754 inline bool RequiresBasis() {return false;}
755 private:
756 double a_;
757 double b_;
758 GammaSampler gamma_sampler_;
759};
760
761using SuffStatVariant = std::variant<GaussianConstantSuffStat,
762 GaussianUnivariateRegressionSuffStat,
763 GaussianMultivariateRegressionSuffStat,
764 LogLinearVarianceSuffStat>;
765
766using LeafModelVariant = std::variant<GaussianConstantLeafModel,
767 GaussianUnivariateRegressionLeafModel,
768 GaussianMultivariateRegressionLeafModel,
769 LogLinearVarianceLeafModel>;
770
771template<typename SuffStatType, typename... SuffStatConstructorArgs>
772static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) {
773 return SuffStatType(leaf_suff_stat_args...);
774}
775
776template<typename LeafModelType, typename... LeafModelConstructorArgs>
777static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_model_args) {
778 return LeafModelType(leaf_model_args...);
779}
780
781static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
782 if (model_type == kConstantLeafGaussian) {
783 return createSuffStat<GaussianConstantSuffStat>();
784 } else if (model_type == kUnivariateRegressionLeafGaussian) {
785 return createSuffStat<GaussianUnivariateRegressionSuffStat>();
786 } else if (model_type == kMultivariateRegressionLeafGaussian) {
787 return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
788 } else {
789 return createSuffStat<LogLinearVarianceSuffStat>();
790 }
791}
792
793static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
794 if (model_type == kConstantLeafGaussian) {
795 return createLeafModel<GaussianConstantLeafModel, double>(tau);
796 } else if (model_type == kUnivariateRegressionLeafGaussian) {
797 return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
798 } else if (model_type == kMultivariateRegressionLeafGaussian) {
799 return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
800 } else if (model_type == kLogLinearVariance) {
801 return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
802 } else {
803 Log::Fatal("Incompatible model type provided to leaf model factory");
804 }
805}
806
807template<typename SuffStatType>
808static inline void AccumulateSuffStatProposed(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
809 ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature) {
810 // Acquire iterators
811 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_num);
812 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_num);
813
814 // Accumulate sufficient statistics
815 for (auto i = node_begin_iter; i != node_end_iter; i++) {
816 auto idx = *i;
817 double feature_value = dataset.CovariateValue(idx, split_feature);
818 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
819 if (split.SplitTrue(feature_value)) {
820 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
821 } else {
822 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
823 }
824 }
825}
826
827template<typename SuffStatType>
828static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
829 ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) {
830 // Acquire iterators
831 auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
832 auto left_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, left_node_id);
833 auto right_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, right_node_id);
834 auto right_node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, right_node_id);
835
836 // Accumulate sufficient statistics for the left and split nodes
837 for (auto i = left_node_begin_iter; i != left_node_end_iter; i++) {
838 auto idx = *i;
839 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
840 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
841 }
842
843 // Accumulate sufficient statistics for the right and split nodes
844 for (auto i = right_node_begin_iter; i != right_node_end_iter; i++) {
845 auto idx = *i;
846 right_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
847 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
848 }
849}
850
851template<typename SuffStatType, bool sorted>
852static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, int tree_num, int node_id) {
853 // Acquire iterators
854 std::vector<data_size_t>::iterator node_begin_iter;
855 std::vector<data_size_t>::iterator node_end_iter;
856 if (sorted) {
857 // Default to the first feature if we're using the presort tracker
858 node_begin_iter = tracker.SortedNodeBeginIterator(node_id, 0);
859 node_end_iter = tracker.SortedNodeEndIterator(node_id, 0);
860 } else {
861 node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
862 node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
863 }
864
865 // Accumulate sufficient statistics
866 for (auto i = node_begin_iter; i != node_end_iter; i++) {
867 auto idx = *i;
868 node_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
869 }
870}
871
872template<typename SuffStatType>
873static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
874 ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id,
875 int feature_num, int cutpoint_num) {
876 // Acquire iterators
877 auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
878 auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
879
880 // Determine node start point
881 data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
882
883 // Determine cutpoint bin start and end points
884 data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num, feature_num);
885 data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_num, feature_num);
886 data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_num + 1, feature_num);
887
888 // Cutpoint specific iterators
889 // TODO: fix the hack of having to subtract off node_begin, probably by cleaning up the CutpointGridContainer interface
890 auto cutpoint_begin_iter = node_begin_iter + (current_bin_begin - node_begin);
891 auto cutpoint_end_iter = node_begin_iter + (next_bin_begin - node_begin);
892
893 // Accumulate sufficient statistics
894 for (auto i = cutpoint_begin_iter; i != cutpoint_end_iter; i++) {
895 auto idx = *i;
896 left_suff_stat.IncrementSuffStat(dataset, residual.GetData(), tracker, idx, tree_num);
897 }
898}
899
// end of leaf_model_group
901
902} // namespace StochTree
903
904#endif // STOCHTREE_LEAF_MODEL_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:194
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:272
double BasisValue(data_size_t row, int col)
Returns a dataset's basis value stored at (row, col)
Definition data.h:372
Eigen::MatrixXd & GetBasis()
Return a reference to the raw Eigen::MatrixXd storing the basis data.
Definition data.h:390
double VarWeightValue(data_size_t row)
Returns a dataset's variance weight stored at element row
Definition data.h:378
bool HasVarWeights()
Whether or not a ForestDataset has (yet) loaded variance weights.
Definition data.h:352
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:50
Definition gamma_sampler.h:9
Marginal likelihood and posterior computation for gaussian homoskedastic constant leaf outcome model.
Definition leaf_model.h:406
double SplitLogMarginalLikelihood(GaussianConstantSuffStat &left_stat, GaussianConstantSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterVariance(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
double PosteriorParameterMean(GaussianConstantSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model...
Definition leaf_model.h:359
Marginal likelihood and posterior computation for gaussian homoskedastic multivariate regression leaf outcome model.
Definition leaf_model.h:610
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &left_stat, GaussianMultivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd &Sigma_0)
Construct a new GaussianMultivariateRegressionLeafModel object.
Definition leaf_model.h:617
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic multivariate regression leaf outcome model...
Definition leaf_model.h:561
Marginal likelihood and posterior computation for gaussian homoskedastic univariate regression leaf outcome model.
Definition leaf_model.h:507
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &left_stat, GaussianUnivariateRegressionSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior variance.
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Leaf node posterior mean.
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
Sufficient statistic and associated operations for gaussian homoskedastic univariate regression leaf outcome model...
Definition leaf_model.h:460
Marginal likelihood and posterior computation for heteroskedastic log-linear variance model.
Definition leaf_model.h:705
double NoSplitLogMarginalLikelihood(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Log marginal likelihood of a node, evaluated only for observations that fall into the node being spli...
double PosteriorParameterShape(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior shape parameter.
void SampleLeafParameters(ForestDataset &dataset, ForestTracker &tracker, ColumnVector &residual, Tree *tree, int tree_num, double global_variance, std::mt19937 &gen)
Draw new parameters for every leaf node in tree, using a Gibbs update that conditions on the data,...
double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat &left_stat, LogLinearVarianceSuffStat &right_stat, double global_variance)
Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node...
double PosteriorParameterScale(LogLinearVarianceSuffStat &suff_stat, double global_variance)
Leaf node posterior scale parameter.
Sufficient statistic and associated operations for heteroskedastic log-linear variance model.
Definition leaf_model.h:669
Definition normal_sampler.h:24
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Decision tree data structure.
Definition tree.h:69
Definition normal_sampler.h:12
ModelType
Leaf models for the forest sampler:
Definition leaf_model.h:351
Definition category_tracker.h:40