StochTree 0.0.1
Forest Sampler API

Functions for sampling from a forest. The core interface of these functions, as used by the R, Python, and standalone C++ programs, is defined by MCMCSampleOneIter, which runs one iteration of the MCMC sampler for a given forest, and GFRSampleOneIter, which runs one iteration of the grow-from-root (GFR) algorithm for a given forest. All other functions are helpers called by these sampling functions; they are documented here to make extending the C++ codebase more straightforward.

Functions

static void StochTree::VarSplitRange (ForestTracker &tracker, ForestDataset &dataset, int tree_num, int leaf_split, int feature_split, double &var_min, double &var_max)
 Compute the range of available split values for a continuous variable, given the current structure of a tree.
 
static bool StochTree::NodesNonConstantAfterSplit (ForestDataset &dataset, ForestTracker &tracker, TreeSplit &split, int tree_num, int leaf_split, int feature_split)
 Determines whether a proposed split creates two leaf nodes with constant values for every feature (thus ensuring that the tree cannot split further).
 
template<typename LeafModel , typename LeafSuffStat , typename... LeafSuffStatConstructorArgs>
static void StochTree::GFRSampleOneIter (TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, double global_variance, std::vector< FeatureType > &feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
 Runs one iteration of the "grow-from-root" (GFR) sampler for a tree ensemble model.
 
template<typename LeafModel , typename LeafSuffStat , typename... LeafSuffStatConstructorArgs>
static void StochTree::MCMCSampleOneIter (TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
 Runs one iteration of the MCMC sampler for a tree ensemble model.
 

Function Documentation

◆ VarSplitRange()

static void StochTree::VarSplitRange (ForestTracker &tracker, ForestDataset &dataset, int tree_num, int leaf_split, int feature_split, double &var_min, double &var_max)
inline static

Compute the range of available split values for a continuous variable, given the current structure of a tree.

Parameters
tracker: Tracking data structures that speed up sampler operations.
dataset: Data object containing training data, including covariates, leaf regression bases, and case weights.
tree_num: Index of the tree for which a split is proposed.
leaf_split: Index of the leaf in tree_num for which a split is proposed.
feature_split: Index of the feature whose available range is queried.
var_min: Current minimum feature value (passed by reference and modified by this function).
var_max: Current maximum feature value (passed by reference and modified by this function).
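
The idea behind the range query can be illustrated with a minimal, self-contained sketch. It is not StochTree's implementation: FeatureRangeInLeaf, feature_column, and leaf_rows are hypothetical stand-ins for the lookups that tracker and dataset would provide, and only the by-reference update of var_min and var_max mirrors the interface documented above.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper (not part of StochTree): `feature_column` holds one covariate
// for every observation, `leaf_rows` lists the observations that fall in the leaf.
inline void FeatureRangeInLeaf(const std::vector<double>& feature_column,
                               const std::vector<std::size_t>& leaf_rows,
                               double& var_min, double& var_max) {
  // Start from an "empty" range so that any observed value tightens it.
  var_min = std::numeric_limits<double>::max();
  var_max = std::numeric_limits<double>::lowest();
  for (std::size_t row : leaf_rows) {
    const double value = feature_column.at(row);
    if (value < var_min) var_min = value;
    if (value > var_max) var_max = value;
  }
}
```

In this sketch, a leaf with var_max equal to var_min has no room for a split on that feature.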

◆ NodesNonConstantAfterSplit()

static bool StochTree::NodesNonConstantAfterSplit (ForestDataset &dataset, ForestTracker &tracker, TreeSplit &split, int tree_num, int leaf_split, int feature_split)
inline static

Determines whether a proposed split creates two leaf nodes with constant values for every feature (thus ensuring that the tree cannot split further).

Parameters
dataset: Data object containing training data, including covariates, leaf regression bases, and case weights.
tracker: Tracking data structures that speed up sampler operations.
split: Proposed split of tree tree_num at node leaf_split.
tree_num: Index of the tree for which a split is proposed.
leaf_split: Index of the leaf in tree_num for which a split is proposed.
feature_split: Index of the feature to which split will be applied.
Returns
true if split creates two nodes with constant values for every feature in dataset, false otherwise.
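
The check described above can be sketched on plain vectors. This is an illustration that follows the documented return value, not the library's code: ChildrenConstantAfterSplit and ConstantTracker are hypothetical names, covariates are assumed to be stored as one vector per feature, and a numeric split is assumed to send values less than or equal to split_value to the left child.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Records whether every value observed so far is identical.
struct ConstantTracker {
  bool seen = false;
  bool constant = true;
  double first = 0.0;
  void Observe(double v) {
    if (!seen) {
      seen = true;
      first = v;
    } else if (v != first) {
      constant = false;
    }
  }
};

// Hypothetical helper: true if, after splitting the leaf's rows on `split_feature`
// at `split_value`, every feature is constant within each of the two children.
inline bool ChildrenConstantAfterSplit(const std::vector<std::vector<double>>& columns,
                                       const std::vector<std::size_t>& leaf_rows,
                                       std::size_t split_feature, double split_value) {
  for (const auto& column : columns) {
    ConstantTracker left, right;
    for (std::size_t row : leaf_rows) {
      const bool goes_left = columns.at(split_feature).at(row) <= split_value;
      (goes_left ? left : right).Observe(column.at(row));
    }
    if (!left.constant || !right.constant) return false;
  }
  return true;
}
```

An empty child counts as constant in this sketch; a real sampler would normally rule out such splits through the minimum leaf size in the tree prior.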

◆ GFRSampleOneIter()

template<typename LeafModel , typename LeafSuffStat , typename... LeafSuffStatConstructorArgs>
static void StochTree::GFRSampleOneIter (TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, double global_variance, std::vector< FeatureType > &feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
inline static

Runs one iteration of the "grow-from-root" (GFR) sampler for a tree ensemble model, which consists of two steps for every tree in a forest:

  1. Growing a tree by recursively sampling cutpoints via the GFR algorithm
  2. Sampling leaf node parameters, conditional on the updated tree, via a Gibbs sampler
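
The Gibbs step for leaf parameters can be illustrated in the simplest conjugate case. The sketch below assumes a constant Gaussian leaf with prior mu ~ N(0, tau) and residuals r_i ~ N(mu, sigma2) within the leaf, so the conditional posterior of mu is Gaussian with variance 1 / (1/tau + n/sigma2) and mean variance * sum(r_i) / sigma2. LeafPosterior, GaussianLeafPosterior, and SampleLeafValue are hypothetical names rather than StochTree classes, and sigma2 plays the role of global_variance.

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

struct LeafPosterior {
  double mean;
  double variance;
};

// Conjugate posterior of the leaf mean mu, assuming r_i ~ N(mu, sigma2), mu ~ N(0, tau).
inline LeafPosterior GaussianLeafPosterior(const std::vector<double>& leaf_residuals,
                                           double sigma2, double tau) {
  double sum = 0.0;
  for (double r : leaf_residuals) sum += r;
  const double n = static_cast<double>(leaf_residuals.size());
  const double variance = 1.0 / (1.0 / tau + n / sigma2);
  const double mean = variance * sum / sigma2;
  return {mean, variance};
}

// One Gibbs draw of the leaf value from its conditional posterior.
inline double SampleLeafValue(const LeafPosterior& post, std::mt19937& gen) {
  std::normal_distribution<double> dist(post.mean, std::sqrt(post.variance));
  return dist(gen);
}
```

With three residuals 1, 2, 3 and sigma2 = tau = 1, the posterior variance is 1/4 and the posterior mean is 1.5, which shows the shrinkage of the sample mean (2) toward the prior mean (0).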
Template Parameters
LeafModel: Leaf model type (e.g. GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel).
LeafSuffStat: Leaf sufficient statistic type (e.g. GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat).
LeafSuffStatConstructorArgs: Type of the constructor arguments used to initialize the LeafSuffStat class. For GaussianMultivariateRegressionSuffStat, this is int; each of the other three sufficient statistic classes takes no constructor argument.
Parameters
active_forest: Current state of an ensemble from the sampler's perspective. This is managed through an "active forest" class, as distinct from a "forest container" class of stored ensemble samples, because we often wish to update model state without saving the result (e.g. during burn-in or thinning of an MCMC sampler).
tracker: Tracking data structures that speed up sampler operations, kept synchronized with the state of active_forest.
forests: Container of "stored" forests.
leaf_model: Leaf model object; its type is determined by the template argument LeafModel.
dataset: Data object containing training data, including covariates, leaf regression bases, and case weights.
residual: Data object containing the residual used in training. The state of residual is updated by this function (the prior predictions of active_forest are added to the residual and the updated predictions from active_forest are subtracted back out).
tree_prior: Configuration for the tree prior (i.e. max depth, min samples in a leaf, depth-defined split probability).
gen: Random number generator for the sampler.
variable_weights: Vector of selection weights for each variable in dataset.
global_variance: Current value of the (possibly stochastic) global error variance parameter.
feature_types: Enum-coded vector of feature types (see FeatureType) for each feature in dataset.
cutpoint_grid_size: Maximum size of a grid of potential cutpoints (the grow-from-root algorithm evaluates a series of potential cutpoints for each feature, and this parameter "thins" the cutpoint candidates for numeric variables).
keep_forest: Whether or not active_forest should be retained in forests.
pre_initialized: Whether or not active_forest has already been initialized (note: this parameter will be refactored out soon).
backfitting: Whether the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via their effect on the residual) rather than the more general "blocked MCMC" (wherein the state of other trees must be considered more explicitly).
leaf_suff_stat_args: Any arguments which must be supplied to initialize a LeafSuffStat object.
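
The "thinning" role of cutpoint_grid_size can be illustrated with a small sketch. It assumes one simple scheme, evenly spaced picks from the sorted unique values of a numeric feature; StochTree's actual grid construction may differ, and ThinCutpoints is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: keep at most `cutpoint_grid_size` candidate cutpoints,
// taken at evenly spaced positions of the sorted, de-duplicated feature values.
inline std::vector<double> ThinCutpoints(const std::vector<double>& sorted_values,
                                         std::size_t cutpoint_grid_size) {
  const std::size_t n = sorted_values.size();
  // Few enough candidates already: evaluate all of them.
  if (n <= cutpoint_grid_size) return sorted_values;
  std::vector<double> grid;
  grid.reserve(cutpoint_grid_size);
  for (std::size_t k = 0; k < cutpoint_grid_size; ++k) {
    // Index k * n / grid_size is strictly increasing in k because n > grid_size.
    grid.push_back(sorted_values[k * n / cutpoint_grid_size]);
  }
  return grid;
}
```

The trade-off is the usual one: a larger grid evaluates more candidate splits per feature at a higher cost per iteration.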

◆ MCMCSampleOneIter()

template<typename LeafModel , typename LeafSuffStat , typename... LeafSuffStatConstructorArgs>
static void StochTree::MCMCSampleOneIter (TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
inline static

Runs one iteration of the MCMC sampler for a tree ensemble model, which consists of two steps for every tree in a forest:

  1. Sampling "birth-death" tree modifications via the Metropolis-Hastings algorithm
  2. Sampling leaf node parameters, conditional on a (possibly-updated) tree, via a Gibbs sampler
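
The Metropolis-Hastings step rests on the standard accept/reject rule, sketched below on the log scale. AcceptProposal is a hypothetical helper, not a StochTree function, and the computation of the log acceptance ratio for a birth or death move (likelihood, prior, and proposal terms) is deliberately left out.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <random>

// Hypothetical helper: accept a proposed tree modification with probability
// min(1, exp(log_acceptance_ratio)), using the sampler's random number generator.
inline bool AcceptProposal(double log_acceptance_ratio, std::mt19937& gen) {
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  const double u = unif(gen);
  // log(u) <= 0, so any non-negative log ratio is always accepted.
  return std::log(u) < log_acceptance_ratio;
}
```

Comparing on the log scale avoids overflow when the ratio is extreme; a rejected proposal simply leaves the tree unchanged before the leaf parameters are resampled.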
Template Parameters
LeafModel: Leaf model type (e.g. GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel).
LeafSuffStat: Leaf sufficient statistic type (e.g. GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat).
LeafSuffStatConstructorArgs: Type of the constructor arguments used to initialize the LeafSuffStat class. For GaussianMultivariateRegressionSuffStat, this is int; each of the other three sufficient statistic classes takes no constructor argument.
Parameters
active_forest: Current state of an ensemble from the sampler's perspective. This is managed through an "active forest" class, as distinct from a "forest container" class of stored ensemble samples, because we often wish to update model state without saving the result (e.g. during burn-in or thinning of an MCMC sampler).
tracker: Tracking data structures that speed up sampler operations, kept synchronized with the state of active_forest.
forests: Container of "stored" forests.
leaf_model: Leaf model object; its type is determined by the template argument LeafModel.
dataset: Data object containing training data, including covariates, leaf regression bases, and case weights.
residual: Data object containing the residual used in training. The state of residual is updated by this function (the prior predictions of active_forest are added to the residual and the updated predictions from active_forest are subtracted back out).
tree_prior: Configuration for the tree prior (i.e. max depth, min samples in a leaf, depth-defined split probability).
gen: Random number generator for the sampler.
variable_weights: Vector of selection weights for each variable in dataset.
global_variance: Current value of the (possibly stochastic) global error variance parameter.
keep_forest: Whether or not active_forest should be retained in forests.
pre_initialized: Whether or not active_forest has already been initialized (note: this parameter will be refactored out soon).
backfitting: Whether the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via their effect on the residual) rather than the more general "blocked MCMC" (wherein the state of other trees must be considered more explicitly).
leaf_suff_stat_args: Any arguments which must be supplied to initialize a LeafSuffStat object.
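
The residual bookkeeping described for the residual parameter can be sketched for a single tree under backfitting. This is an illustration on plain vectors rather than the ColumnVector class: UpdateResidual, old_tree_pred, and new_tree_pred are hypothetical names for the tree's predictions before and after it is resampled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: add the tree's previous predictions back into the residual,
// then subtract its updated predictions, so the residual again excludes this tree's
// current fit.
inline void UpdateResidual(std::vector<double>& residual,
                           const std::vector<double>& old_tree_pred,
                           const std::vector<double>& new_tree_pred) {
  assert(residual.size() == old_tree_pred.size());
  assert(residual.size() == new_tree_pred.size());
  for (std::size_t i = 0; i < residual.size(); ++i) {
    residual[i] += old_tree_pred[i];  // partial residual: outcome minus all other trees
    residual[i] -= new_tree_pred[i];  // remove the resampled tree's fit
  }
}
```

Between the two steps the residual holds the partial residual that the tree is fit against, which is why a backfitting sampler only needs the other trees through their effect on the residual.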