
Wrapper around a C++ class that stores a single ensemble of decision trees (often treated as the "active forest", i.e. the current state of a forest term in an R sampling loop)

This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking is performed, so users are responsible for providing correct inputs. For tutorials on the proper usage of stochtree's advanced workflow, see the vignettes at stochtree.ai.

Public fields

forest_ptr

External pointer to a C++ TreeEnsemble object

internal_forest_is_empty

Whether the forest has not yet been "initialized" such that its predict function can be called.

Methods


Method new()

Create a new Forest object.

Usage

Forest$new(
  num_trees,
  leaf_dimension = 1,
  is_leaf_constant = FALSE,
  is_exponentiated = FALSE
)

Arguments

num_trees

Number of trees in the forest

leaf_dimension

Dimensionality of the leaf model (the size of the parameter stored in each leaf node)

is_leaf_constant

Whether the leaf model is constant (as opposed to a regression on a basis)

is_exponentiated

Whether forest predictions should be exponentiated before being returned

Returns

A new Forest object.


Method merge_forest()

Create a larger forest by merging the trees of this forest with those of another forest

Usage

Forest$merge_forest(forest)

Arguments

forest

Forest to be merged into this forest
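As a rough mental model, assuming each tree contributes additively to the forest prediction (is_exponentiated = FALSE), merging two forests amounts to concatenating their trees, so the merged forest predicts the sum of the two original predictions (for exponentiated forests, the product of the two exponentiated predictions). The pure-R sketch below uses a hypothetical toy representation in which each tree is reduced to a single root leaf value; it does not call stochtree.

```r
# Toy model (hypothetical, not the stochtree API): a forest of root-only
# trees is a numeric vector of root leaf values, one entry per tree.
toy_predict <- function(forest) sum(forest)

# Merging appends the trees of `other` to the trees of `forest`
toy_merge <- function(forest, other) c(forest, other)

forest_a <- c(0.1, 0.2, 0.3)
forest_b <- c(-0.5, 1.0)
merged <- toy_merge(forest_a, forest_b)

length(merged)       # 5 trees
toy_predict(merged)  # equals toy_predict(forest_a) + toy_predict(forest_b)
```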


Method add_constant()

Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, constant_value will be added to every dimension of the leaves.

Usage

Forest$add_constant(constant_value)

Arguments

constant_value

Value that will be added to every leaf of every tree
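Because an additive (non-exponentiated) forest sums one leaf per tree, adding constant_value to every leaf shifts each prediction by num_trees * constant_value, not by constant_value alone. The toy sketch below is a hypothetical pure-R model (not stochtree code) in which root-only trees with multi-dimensional leaves are stored as a matrix with one row per tree and one column per leaf dimension.

```r
# Toy forest (hypothetical): 4 root-only trees with 2-dimensional leaves (rows = trees)
leaves <- matrix(c(0.1, 0.2,
                   0.3, 0.4,
                   0.5, 0.6,
                   0.7, 0.8), ncol = 2, byrow = TRUE)

# Raw forest output: sum of leaf vectors across trees (one value per dimension)
toy_predict_raw <- function(leaves) colSums(leaves)

# Mirrors the documented behavior: the constant is added to every dimension of every leaf
toy_add_constant <- function(leaves, constant_value) leaves + constant_value

before <- toy_predict_raw(leaves)
after <- toy_predict_raw(toy_add_constant(leaves, 0.25))
after - before  # each dimension shifts by nrow(leaves) * 0.25 = 1
```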


Method multiply_constant()

Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, every dimension of the leaves is multiplied by constant_multiple.

Usage

Forest$multiply_constant(constant_multiple)

Arguments

constant_multiple

Value by which every leaf of every tree will be multiplied
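For an additive forest, multiplying every leaf by constant_multiple scales the prediction by the same factor. For an exponentiated forest (is_exponentiated = TRUE), the same operation raises the exponentiated prediction to the power constant_multiple, since exp(c * s) = exp(s)^c. A hypothetical pure-R toy using root-only trees:

```r
# Toy forest (hypothetical): root leaf value of each of three root-only trees
leaves <- c(0.2, -0.1, 0.4)
s <- sum(leaves)

# Multiply every leaf by a constant
toy_multiply_constant <- function(leaves, constant_multiple) leaves * constant_multiple
scaled <- toy_multiply_constant(leaves, 3)

sum(scaled)       # additive forest: 3 * s
exp(sum(scaled))  # exponentiated forest: exp(s)^3
```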


Method predict()

Predict the forest for every observation in forest_dataset

Usage

Forest$predict(forest_dataset)

Arguments

forest_dataset

ForestDataset R class

Returns

Vector of predictions with one entry per observation in forest_dataset
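In the constant-leaf case, the prediction for one observation is the sum, over trees, of the value in the leaf that observation falls into; when is_exponentiated = TRUE, that sum is exponentiated before being returned. The sketch below illustrates this logic in pure R using a hypothetical table of depth-1 trees ("stumps") rather than stochtree's internal representation.

```r
# Toy forest (hypothetical): each row is a stump splitting on X[, feature] <= threshold
stumps <- data.frame(
  feature   = c(1, 2, 1),
  threshold = c(0.5, 0.0, 0.8),
  left      = c(-1.0, 0.3, 0.2),
  right     = c( 1.0, -0.3, 0.6)
)

toy_predict <- function(stumps, X, is_exponentiated = FALSE) {
  pred <- numeric(nrow(X))
  for (t in seq_len(nrow(stumps))) {
    # Route every observation to the left or right leaf of tree t and add its value
    goes_left <- X[, stumps$feature[t]] <= stumps$threshold[t]
    pred <- pred + ifelse(goes_left, stumps$left[t], stumps$right[t])
  }
  if (is_exponentiated) exp(pred) else pred
}

X <- matrix(c(0.4,  1.0,
              0.9, -2.0), ncol = 2, byrow = TRUE)
toy_predict(stumps, X)        # c(-1.1, 1.9)
toy_predict(stumps, X, TRUE)  # exp(c(-1.1, 1.9))
```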


Method predict_raw()

Predict "raw" leaf values (without multiplying by the basis) for every observation in forest_dataset

Usage

Forest$predict_raw(forest_dataset)

Arguments

forest_dataset

ForestDataset R class

Returns

Vector or matrix of raw predictions, one per observation in forest_dataset, with each prediction having the dimensionality of the forest's leaf model. In the case of a constant leaf model or univariate leaf regression, this is a vector whose length is the number of observations. In the case of a multivariate leaf regression, this is a matrix with one row per observation and one column per leaf model dimension.
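To make the distinction between predict() and predict_raw() concrete for a multivariate leaf regression: predict_raw() returns the leaf parameters summed across trees (an observations by leaf dimension matrix), while predict() returns, for each observation, the inner product of its basis vector with that summed parameter. The sketch below is a hypothetical pure-R model of this relationship, assuming raw values are summed across trees as in predict(); it uses root-only trees so every observation lands in the same leaf.

```r
# Toy leaf-regression forest (hypothetical): two root-only trees, each holding a
# 2-dimensional leaf parameter (rows = trees, columns = leaf dimensions)
leaf_params <- matrix(c(1.0, 0.5,
                        0.5, 0.25), ncol = 2, byrow = TRUE)

# Basis W: one row per observation, one column per leaf dimension
W <- matrix(c(1,  2,
              1,  0,
              1, -1), ncol = 2, byrow = TRUE)
n <- nrow(W)

# Every observation lands in the root, so its raw output is the sum of leaf vectors
raw_per_obs <- colSums(leaf_params)                           # c(1.5, 0.75)
raw <- matrix(raw_per_obs, nrow = n, ncol = 2, byrow = TRUE)  # n x leaf_dimension

# Basis-weighted prediction: row-wise inner product of W with the raw output
pred <- rowSums(W * raw)                                      # c(3.0, 1.5, 0.75)
```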


Method set_root_leaves()

Set a constant predicted value for every tree in the ensemble. Stops with an error if any tree consists of more than a root node.

Usage

Forest$set_root_leaves(leaf_value)

Arguments

leaf_value

Constant leaf value(s) to be assigned to the root of every tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
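For an additive forest of root-only trees, the initial prediction is num_trees * leaf_value, so one natural choice is to divide a target starting value (for example, the outcome mean) by the number of trees. The toy sketch below is a hypothetical pure-R model, not stochtree code; it also mirrors the documented check that stops when a tree is more than a root node.

```r
# Toy forest (hypothetical): each tree is a list with a node count and a root value
toy_set_root_leaves <- function(forest, leaf_value) {
  for (t in seq_along(forest)) {
    # Mirrors the documented behavior of stopping if a tree is more than a root node
    if (forest[[t]]$num_nodes > 1) stop("tree ", t, " is more than a root node")
    forest[[t]]$root_value <- leaf_value
  }
  forest
}

num_trees <- 50
forest <- replicate(num_trees, list(num_nodes = 1, root_value = NA_real_),
                    simplify = FALSE)

y <- c(2.0, 3.5, 4.5)
forest <- toy_set_root_leaves(forest, mean(y) / num_trees)
initial_pred <- sum(vapply(forest, function(tree) tree$root_value, numeric(1)))
initial_pred  # equals mean(y), up to floating-point error
```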


Method prepare_for_sampler()

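Initialize every tree in the ensemble to a single root node with a constant predicted value and propagate this initialization to the ForestModel object, which must stay synchronized with the Forest during a sampling loop. Stops with an error if any tree consists of more than a root node.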

Usage

Forest$prepare_for_sampler(
  dataset,
  outcome,
  forest_model,
  leaf_model_int,
  leaf_value
)

Arguments

dataset

ForestDataset object storing the covariates, basis, etc.

outcome

Outcome object storing the residual / partial residual

forest_model

ForestModel object storing tracking structures used in training / sampling

leaf_model_int

Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).

leaf_value

Constant leaf value(s) to be assigned to the root of every tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.


Method adjust_residual()

Adjusts the residual based on the predictions of a forest.

This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.

Usage

Forest$adjust_residual(dataset, outcome, forest_model, requires_basis, add)

Arguments

dataset

ForestDataset object storing the covariates and bases for a given forest

outcome

Outcome object storing the residuals to be updated based on forest predictions

forest_model

ForestModel object storing tracking structures used in training / sampling

requires_basis

Whether or not a forest requires a basis for prediction

add

Whether forest predictions should be added to or subtracted from residuals
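Assuming the residual is defined as outcome minus forest prediction, add = FALSE subtracts the forest's predictions (as in the initialization step described above) and add = TRUE adds them back, so the two calls undo each other. A toy sketch on plain numeric vectors, using a hypothetical helper rather than the stochtree implementation:

```r
# Hypothetical helper mirroring the documented add/subtract semantics
toy_adjust_residual <- function(residual, forest_pred, add) {
  if (add) residual + forest_pred else residual - forest_pred
}

y <- c(1.0, 2.0, 3.0)
forest_pred <- rep(mean(y), length(y))  # e.g. a root-only forest that predicts mean(y)

resid <- toy_adjust_residual(y, forest_pred, add = FALSE)        # c(-1, 0, 1)
restored <- toy_adjust_residual(resid, forest_pred, add = TRUE)  # back to y
```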


Method num_trees()

Return the number of trees in a Forest object

Usage

Forest$num_trees()

Returns

Tree count


Method leaf_dimension()

Return the output (leaf) dimension of the trees in a Forest object

Usage

Forest$leaf_dimension()

Returns

Leaf node parameter size


Method is_constant_leaf()

Return constant leaf status of trees in a Forest object

Usage

Forest$is_constant_leaf()

Returns

TRUE if leaves are constant, FALSE otherwise


Method is_exponentiated()

Return exponentiation status of trees in a Forest object

Usage

Forest$is_exponentiated()

Returns

TRUE if leaf predictions must be exponentiated, FALSE otherwise


Method add_numeric_split_tree()

Add a numeric split (i.e. X[,i] <= c) to a given tree in the ensemble

Usage

Forest$add_numeric_split_tree(
  tree_num,
  leaf_num,
  feature_num,
  split_threshold,
  left_leaf_value,
  right_leaf_value
)

Arguments

tree_num

Index of the tree to be split

leaf_num

Index of the leaf node to be split

feature_num

Index of the feature that defines the new split

split_threshold

Value that defines the cutoff of the new split

left_leaf_value

Value (or vector of values) to assign to the newly created left node

right_leaf_value

Value (or vector of values) to assign to the newly created right node
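The effect of a numeric split can be sketched with a hypothetical pure-R tree stored as parallel vectors indexed by node id: the chosen leaf becomes an internal node with rule X[, feature_num] <= split_threshold, observations satisfying the rule go to the new left child, and the rest go to the new right child. This toy uses R's 1-based node and feature ids and scalar leaf values only; stochtree's own indexing convention and internal layout are not shown here and may differ.

```r
# Toy tree (hypothetical): parallel vectors indexed by node id (1 = root)
new_toy_tree <- function(root_value) {
  list(left = NA_integer_, right = NA_integer_, feature = NA_integer_,
       threshold = NA_real_, value = root_value)
}

# Turn leaf `leaf_num` into an internal node with rule X[, feature_num] <= split_threshold
toy_add_numeric_split <- function(tree, leaf_num, feature_num, split_threshold,
                                  left_leaf_value, right_leaf_value) {
  stopifnot(is.na(tree$left[leaf_num]))  # node must currently be a leaf
  l <- length(tree$value) + 1L
  r <- l + 1L
  tree$left[c(leaf_num, l, r)] <- c(l, NA, NA)
  tree$right[c(leaf_num, l, r)] <- c(r, NA, NA)
  tree$feature[c(leaf_num, l, r)] <- c(feature_num, NA, NA)
  tree$threshold[c(leaf_num, l, r)] <- c(split_threshold, NA, NA)
  tree$value[c(leaf_num, l, r)] <- c(NA, left_leaf_value, right_leaf_value)
  tree
}

# Route one observation x from the root to a leaf and return that leaf's value
toy_tree_predict <- function(tree, x) {
  node <- 1L
  while (!is.na(tree$left[node])) {
    node <- if (x[tree$feature[node]] <= tree$threshold[node]) tree$left[node] else tree$right[node]
  }
  tree$value[node]
}

tree <- new_toy_tree(0)
tree <- toy_add_numeric_split(tree, 1, feature_num = 2, split_threshold = 0.5,
                              left_leaf_value = -1, right_leaf_value = 1)
toy_tree_predict(tree, c(10, 0.2))  # -1, since x[2] = 0.2 <= 0.5 goes left
toy_tree_predict(tree, c(10, 0.9))  #  1
```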


Method get_tree_leaves()

Retrieve a vector of indices of leaf nodes for a given tree in the forest

Usage

Forest$get_tree_leaves(tree_num)

Arguments

tree_num

Index of the tree for which leaf indices will be retrieved
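A leaf is a node with no children, so in a hypothetical tree stored as parallel child vectors the leaf indices are simply the nodes whose left and right children are both missing. A pure-R toy (1-based node ids, not stochtree's representation):

```r
# Toy tree (hypothetical) indexed by node id (1 = root); NA children mark a leaf.
# Root splits into nodes 2 and 3; node 3 splits into nodes 4 and 5.
left  <- c(2L, NA, 4L, NA, NA)
right <- c(3L, NA, 5L, NA, NA)

# Leaf indices are the ids of nodes without children
toy_get_tree_leaves <- function(left, right) which(is.na(left) & is.na(right))

toy_get_tree_leaves(left, right)  # 2 4 5
```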


Method get_tree_split_counts()

Retrieve a vector of split counts for every training set variable in a given tree in the forest

Usage

Forest$get_tree_split_counts(tree_num, num_features)

Arguments

tree_num

Index of the tree for which split counts will be retrieved

num_features

Total number of features in the training set
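num_features fixes the length of the returned vector, so features that a tree never splits on still appear with a count of zero. A hypothetical pure-R toy that records the split feature of each node (NA for leaves, 1-based feature ids) and tabulates the internal nodes:

```r
# Toy tree (hypothetical): split feature of each node, NA for leaves.
# Root splits on feature 3, its left child on feature 1, a grandchild on feature 3 again.
split_feature <- c(3L, 1L, NA, 3L, NA, NA, NA)

# Count how many internal nodes split on each of the num_features training variables
toy_get_tree_split_counts <- function(split_feature, num_features) {
  tabulate(split_feature[!is.na(split_feature)], nbins = num_features)
}

toy_get_tree_split_counts(split_feature, num_features = 4)  # 1 0 2 0
```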


Method get_forest_split_counts()

Retrieve a vector of split counts for every training set variable in the forest

Usage

Forest$get_forest_split_counts(num_features)

Arguments

num_features

Total number of features in the training set
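Forest-level split counts can be understood as the element-wise sum of the per-tree counts. A hypothetical pure-R toy (one vector of split features per tree, NA for leaves, 1-based feature ids):

```r
# Toy forest (hypothetical): split feature of each node per tree, NA for leaves
forest_split_features <- list(
  c(3L, 1L, NA, 3L, NA, NA, NA),  # tree 1
  c(NA_integer_),                 # tree 2: root-only
  c(2L, NA, NA)                   # tree 3
)
num_features <- 4

# Per-tree counts, then their element-wise sum across the forest
tree_counts <- lapply(forest_split_features,
                      function(f) tabulate(f[!is.na(f)], nbins = num_features))
forest_counts <- Reduce(`+`, tree_counts)
forest_counts  # 1 1 2 0
```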


Method tree_max_depth()

Return the maximum depth of a specific tree in the forest

Usage

Forest$tree_max_depth(tree_num)

Arguments

tree_num

Tree index within forest

Returns

Maximum leaf depth
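A node's depth is the number of edges between it and the root, and a tree's maximum depth is the largest such value over its nodes. The pure-R toy below stores a hypothetical tree as a vector of parent ids and assumes the convention that a root-only tree has depth 0; stochtree's own convention is not confirmed here.

```r
# Toy tree (hypothetical): parent id of each node, NA for the root.
# Root -> nodes 2, 3; node 2 -> nodes 4, 5; node 4 -> nodes 6, 7.
parent <- c(NA, 1L, 1L, 2L, 2L, 4L, 4L)

# Depth of a node = number of edges between it and the root (root has depth 0)
node_depth <- function(parent, node) {
  depth <- 0L
  while (!is.na(parent[node])) {
    node <- parent[node]
    depth <- depth + 1L
  }
  depth
}

toy_tree_max_depth <- function(parent) {
  max(vapply(seq_along(parent), function(i) node_depth(parent, i), integer(1)))
}

toy_tree_max_depth(parent)       # 3
toy_tree_max_depth(NA_integer_)  # 0 for a root-only tree
```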


Method average_max_depth()

Return the average, across all trees in the forest, of each tree's maximum depth

Usage

Forest$average_max_depth()

Returns

Average maximum depth
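Note that this averages each tree's deepest leaf, which is generally different from the mean depth over all leaves in the forest. A hypothetical pure-R toy that lists the leaf depths of each tree (root-only tree = depth 0):

```r
# Toy forest (hypothetical): depths of the leaves of each tree
leaf_depths <- list(
  c(0),          # root-only tree
  c(1, 2, 2),    # tree with 3 leaves
  c(1, 3, 3, 2)  # tree with 4 leaves
)

# Mean over trees of each tree's deepest leaf: (0 + 2 + 3) / 3
avg_max_depth <- mean(vapply(leaf_depths, max, numeric(1)))

# Not the same as the mean depth over all leaves: 14 / 8
avg_leaf_depth <- mean(unlist(leaf_depths))
```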


Method is_empty()

When a forest object is created, it is "empty" in the sense that none of its component trees have leaves with values. There are two ways to "initialize" a Forest object. First, the set_root_leaves() method simply initializes every tree in the forest to a single node carrying the same (user-specified) leaf value. Second, the prepare_for_sampler() method initializes every tree in the forest to a single node with the same value and also propagates this information through to a ForestModel object, which must be synchronized with a Forest during a forest sampler loop.

Usage

Forest$is_empty()

Returns

TRUE if the Forest has not yet been initialized with a constant root value, FALSE if it has already been initialized or grown.