Class that stores a single ensemble of decision trees (often treated as the "active forest")
Source: R/forest.R
Wrapper around a C++ tree ensemble
Public fields
forest_ptr: External pointer to a C++ TreeEnsemble class
internal_forest_is_empty: Whether the forest has not yet been "initialized" such that its predict function can be called.
Methods
Method new()
Create a new Forest object.
Usage
Forest$new(
num_trees,
leaf_dimension = 1,
is_leaf_constant = FALSE,
is_exponentiated = FALSE
)
Method merge_forest()
Create a larger forest by merging the trees of this forest with those of another forest
Method add_constant()
Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, constant_value will be added to every dimension of the leaves.
Method multiply_constant()
Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, constant_multiple will be multiplied through every dimension of the leaves.
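As a rough illustration of what these two methods imply for predictions, the sketch below mimics a forest in base R as a list of per-tree leaf values. It assumes a standard sum-of-trees ensemble, where the forest prediction is the sum of one leaf value per tree. The `toy_forest` structure and the `toy_*` helpers are hypothetical and are not part of the package API.

```r
# Toy stand-in for an ensemble: each tree is a numeric vector of leaf values.
# (Hypothetical structure for illustration; not the package's internal representation.)
toy_forest <- list(c(0.5, -0.2), c(0.1, 0.3, -0.4), c(0.0))

# Add a constant to every leaf of every tree.
toy_add_constant <- function(forest, constant_value) {
  lapply(forest, function(leaves) leaves + constant_value)
}

# Multiply every leaf of every tree by a constant.
toy_multiply_constant <- function(forest, constant_multiple) {
  lapply(forest, function(leaves) leaves * constant_multiple)
}

# Sum-of-trees prediction for one observation, given the leaf it reaches in each tree.
toy_predict <- function(forest, leaf_index) {
  sum(mapply(function(leaves, idx) leaves[idx], forest, leaf_index))
}

leaf_index <- c(1, 3, 1)
base_pred <- toy_predict(toy_forest, leaf_index)
shifted_pred <- toy_predict(toy_add_constant(toy_forest, 2), leaf_index)
scaled_pred <- toy_predict(toy_multiply_constant(toy_forest, 3), leaf_index)
```

Under the sum-of-trees assumption, adding a constant to every leaf shifts each prediction by the constant times the number of trees, while multiplying every leaf by a constant scales each prediction by that constant.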
Method predict()
Predict forest on every sample in forest_dataset
Method predict_raw()
Predict "raw" leaf values (without being multiplied by basis) for every sample in forest_dataset
Returns
Array of predictions for each observation in forest_dataset and
each sample in the ForestSamples class with each prediction having the
dimensionality of the forests' leaf model. In the case of a constant leaf model
or univariate leaf regression, this array is a vector (length is the number of
observations). In the case of a multivariate leaf regression,
this array is a matrix (number of observations by leaf model dimension,
number of samples).
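The distinction between "raw" leaf values and final predictions can be sketched in base R. The example assumes a leaf regression model in which each observation's prediction is the inner product of its raw leaf values with its basis vector; the matrices `raw` and `basis` are made-up inputs, not output of the package.

```r
# Hypothetical raw leaf values: 3 observations by leaf dimension 2.
raw <- matrix(c(0.5, 1.0, -0.2,
                0.1, 0.3,  0.4), ncol = 2)

# Hypothetical basis with the same shape (first column acts as an intercept).
basis <- matrix(c(1.0, 1.0,  1.0,
                  2.0, 0.5, -1.0), ncol = 2)

# Prediction: multiply raw leaf values by the basis and sum across leaf dimensions.
pred <- rowSums(raw * basis)
```

With a constant leaf model there is no basis, so raw values and predictions coincide in this simplified picture.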
Method set_root_leaves()
Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.
Method prepare_for_sampler()
Set a constant predicted value for every tree in the ensemble and propagate this information to the ForestModel tracking structures. Stops program if any tree is more than a root node.
Arguments
dataset: ForestDataset. Dataset class (covariates, basis, etc.)
outcome: Outcome. Outcome class (residual / partial residual)
forest_model: ForestModel object storing tracking structures used in training / sampling
leaf_model_int: Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
leaf_value: Constant leaf value(s) to be fixed for each tree in the ensemble indexed by forest_num. Can be either a single number or a vector, depending on the forest's leaf dimension.
Method adjust_residual()
Adjusts residual based on the predictions of a forest
This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
Arguments
dataset: ForestDataset object storing the covariates and bases for a given forest
outcome: Outcome object storing the residuals to be updated based on forest predictions
forest_model: ForestModel object storing tracking structures used in training / sampling
requires_basis: Whether or not a forest requires a basis for prediction
add: Whether forest predictions should be added to or subtracted from residuals
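The residual adjustment described for this method can be illustrated with plain vectors. This is a minimal base-R sketch, assuming a sum-of-trees forest with no basis in which every tree is a single root node; the choice of `mean(y) / num_trees` as the root value is only an example, and none of the names below come from the package.

```r
y <- c(1.2, 0.7, -0.3, 2.1)              # toy outcome
num_trees <- 10
root_leaf_value <- mean(y) / num_trees   # same constant in every root node (illustrative choice)

# Sum-of-trees prediction when every tree is a single root node.
forest_pred <- rep(num_trees * root_leaf_value, length(y))

# Analogue of add = FALSE: subtract forest predictions from the residual.
residual <- y - forest_pred

# Analogue of add = TRUE: add them back, recovering the original outcome.
restored <- residual + forest_pred
```

With this particular root value the adjusted residual is centered at zero, which is why the subtraction is typically done once before sampling begins.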
Method add_numeric_split_tree()
Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble
Usage
Forest$add_numeric_split_tree(
tree_num,
leaf_num,
feature_num,
split_threshold,
left_leaf_value,
right_leaf_value
)
Arguments
tree_num: Index of the tree to be split
leaf_num: Leaf to be split
feature_num: Feature that defines the new split
split_threshold: Value that defines the cutoff of the new split
left_leaf_value: Value (or vector of values) to assign to the newly created left node
right_leaf_value: Value (or vector of values) to assign to the newly created right node
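The numeric split rule X[,i] <= c can be sketched in base R for a single-split tree (a "stump"). The data and variable names below are hypothetical, and the sketch uses R's 1-based column indexing; the indexing convention expected by tree_num, leaf_num, and feature_num is not stated here, so check it before mapping this onto the method's arguments.

```r
# Toy covariate matrix: 5 observations, 2 features.
X <- matrix(c(0.1, 0.9, 0.4, 0.7, 0.5,
              1.0, 2.0, 3.0, 4.0, 5.0), ncol = 2)

feature_index <- 1        # 1-based column index, for this sketch only
split_threshold <- 0.5
left_leaf_value <- -1.0
right_leaf_value <- 1.0

# Observations with X[, i] <= c go to the left node, all others to the right node.
goes_left <- X[, feature_index] <= split_threshold
stump_pred <- ifelse(goes_left, left_leaf_value, right_leaf_value)
```

Note that an observation exactly at the threshold (0.5 in the last row) is routed to the left node, since the rule uses `<=`.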
Method get_tree_leaves()
Retrieve a vector of indices of leaf nodes for a given tree in a given forest
Method get_tree_split_counts()
Retrieve a vector of split counts for every training set variable in a given tree in the forest
Method get_forest_split_counts()
Retrieve a vector of split counts for every training set variable in the forest
Method is_empty()
When a forest object is created, it is "empty" in the sense that none
of its component trees have leaves with values. There are two ways to
"initialize" a Forest object. First, the set_root_leaves() method
simply initializes every tree in the forest to a single node carrying
the same (user-specified) leaf value. Second, the prepare_for_sampler()
method initializes every tree in the forest to a single node with the
same value and also propagates this information through to a ForestModel
object, which must be synchronized with a Forest during a forest
sampler loop.