Wrapper around a C++ class that stores a single ensemble of decision trees (often treated as the "active forest" / current state of a forest term in a sampling loop in R).
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at stochtree.ai.
Public fields
forest_ptr: External pointer to a C++ TreeEnsemble class
internal_forest_is_empty: Whether the forest has not yet been "initialized" such that its predict function can be called.
Methods
Method new()
Create a new Forest object.
Usage
Forest$new(
num_trees,
leaf_dimension = 1,
is_leaf_constant = FALSE,
is_exponentiated = FALSE
)
Method merge_forest()
Create a larger forest by merging the trees of this forest with those of another forest
Method add_constant()
Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, constant_value will be added to every dimension of the leaves.
Method multiply_constant()
Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, constant_multiple will be multiplied through every dimension of the leaves.
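As a plain-R illustration of the two operations above (not the stochtree implementation), the sketch below stores hypothetical leaf values as a matrix with one row per leaf and one column per leaf dimension. The `leaves` matrix and the constants are assumptions made for the example.

```r
# Hypothetical leaf values: 3 leaves, each with a 2-dimensional leaf vector
# (column 1 = first leaf dimension, column 2 = second leaf dimension)
leaves <- matrix(c(0.5, -1.0, 2.0,
                   1.5,  0.0, 4.0), nrow = 3, ncol = 2)

# Analogue of add_constant(): the scalar is added to every dimension of every leaf
leaves_added <- leaves + 0.25

# Analogue of multiply_constant(): the scalar scales every dimension of every leaf
leaves_scaled <- leaves * 2
```

In both cases a single scalar is recycled across all leaf dimensions, which matches the multi-dimensional behavior described for these methods.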
Method predict()
Compute forest predictions for every observation in forest_dataset
Method predict_raw()
Predict "raw" leaf values (without being multiplied by basis) for every sample in forest_dataset
Returns
Array of predictions for each observation in forest_dataset, with each
prediction having the dimensionality of the forest's leaf model. In the case
of a constant leaf model or univariate leaf regression, this array is a vector
(length is the number of observations). In the case of a multivariate leaf
regression, this array is a matrix (number of observations by leaf model
dimension).
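The difference between predict() and predict_raw() can be sketched in plain R for a univariate leaf regression, assuming the prediction for an observation is its raw leaf value multiplied by its basis value. The vectors below are hypothetical stand-ins, not output from stochtree.

```r
# Hypothetical raw leaf values reached by 4 observations (already summed over trees)
raw_leaf_values <- c(0.2, 0.2, -0.5, 1.0)

# Hypothetical univariate basis, one value per observation
basis <- c(1.0, 2.0, 0.5, 3.0)

# Analogue of predict_raw(): leaf values returned without the basis
raw_pred <- raw_leaf_values

# Analogue of predict() under a univariate leaf regression: leaf value times basis
pred <- raw_leaf_values * basis
```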
Method set_root_leaves()
Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.
Method prepare_for_sampler()
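Set a constant predicted value for every tree in the ensemble and propagate this initialization through to the ForestModel object that tracks the forest during sampling. Stops program if any tree is more than a root node.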
Arguments
dataset: ForestDataset object (covariates, basis, etc.)
outcome: Outcome object (residual / partial residual)
forest_model: ForestModel object storing tracking structures used in training / sampling
leaf_model_int: Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
leaf_value: Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
Method adjust_residual()
Adjusts the residual based on the predictions of a forest.
This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
Arguments
dataset: ForestDataset object storing the covariates and bases for a given forest
outcome: Outcome object storing the residuals to be updated based on forest predictions
forest_model: ForestModel object storing tracking structures used in training / sampling
requires_basis: Whether or not a forest requires a basis for prediction
add: Whether forest predictions should be added to or subtracted from residuals
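The arithmetic behind this adjustment can be sketched in plain R. This is a conceptual stand-in rather than a call into stochtree: `outcome`, `num_trees`, and `root_value` are hypothetical, and the forest is assumed to consist of root-only trees that all carry the same value.

```r
# Hypothetical outcome vector for 4 observations
outcome <- c(1.2, 0.7, -0.3, 2.1)

# Hypothetical forest of 10 root-only trees, each predicting 0.05
num_trees <- 10
root_value <- 0.05

# With root-only trees, every observation receives the same forest prediction
forest_pred <- rep(num_trees * root_value, length(outcome))

# add = FALSE subtracts the forest predictions from the residual;
# add = TRUE would add them back in
add <- FALSE
residual <- if (add) outcome + forest_pred else outcome - forest_pred
```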
Method add_numeric_split_tree()
Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble
Usage
Forest$add_numeric_split_tree(
tree_num,
leaf_num,
feature_num,
split_threshold,
left_leaf_value,
right_leaf_value
)
Arguments
tree_num: Index of the tree to be split
leaf_num: Leaf to be split
feature_num: Feature that defines the new split
split_threshold: Value that defines the cutoff of the new split
left_leaf_value: Value (or vector of values) to assign to the newly created left node
right_leaf_value: Value (or vector of values) to assign to the newly created right node
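To make the split rule X[,i] <= c concrete, here is a plain-R sketch of how a single numeric split partitions observations between the two new leaves. The covariate matrix and leaf values are hypothetical, and `feature` uses R's 1-based column indexing purely for this illustration; the indexing convention expected by feature_num is not stated here.

```r
# Hypothetical covariates: 4 observations, 2 features
X <- matrix(c(0.1, 0.9, 0.4, 0.7,
              1.0, 2.0, 3.0, 4.0), ncol = 2)

feature <- 1            # column used for the split (1-based, illustration only)
split_threshold <- 0.5  # cutoff c in X[, i] <= c
left_leaf_value <- -1   # value assigned to the new left node
right_leaf_value <- 1   # value assigned to the new right node

# Observations satisfying X[, i] <= c fall in the left node, the rest in the right
goes_left <- X[, feature] <= split_threshold
pred <- ifelse(goes_left, left_leaf_value, right_leaf_value)
```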
Method get_tree_leaves()
Retrieve a vector of indices of leaf nodes for a given tree in a given forest
Method get_tree_split_counts()
Retrieve a vector of split counts for every training set variable in a given tree in the forest
Method get_forest_split_counts()
Retrieve a vector of split counts for every training set variable in the forest
Method is_empty()
When a forest object is created, it is "empty" in the sense that none
of its component trees have leaves with values. There are two ways to
"initialize" a Forest object. First, the set_root_leaves() method
simply initializes every tree in the forest to a single node carrying
the same (user-specified) leaf value. Second, the prepare_for_sampler()
method initializes every tree in the forest to a single node with the
same value and also propagates this information through to a ForestModel
object, which must be synchronized with a Forest during a forest
sampler loop.
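A minimal sketch of the first initialization path, using only the constructor signature and method names documented on this page. It assumes the stochtree package is installed and that the `Forest` class is exported; the chosen argument values are illustrative, and the code is skipped when the package is unavailable.

```r
if (requireNamespace("stochtree", quietly = TRUE)) {
  # Create a forest of 10 trees with a one-dimensional constant leaf model
  forest <- stochtree::Forest$new(
    num_trees = 10,
    leaf_dimension = 1,
    is_leaf_constant = TRUE,
    is_exponentiated = FALSE
  )

  # A newly created forest is expected to report itself as empty
  empty_before <- forest$is_empty()

  # Initialize every tree to a root node carrying the same leaf value
  forest$set_root_leaves(0.1)

  # After initialization the forest should no longer be empty
  empty_after <- forest$is_empty()
}
```

The second path, prepare_for_sampler(), additionally needs ForestDataset, Outcome, and ForestModel objects, so it is not shown here.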