Hosts the C++ data structures needed to sample an ensemble of decision trees, and exposes functionality to run a forest sampler (using either MCMC or the grow-from-root algorithm).
Public fields
tracker_ptr: External pointer to a C++ ForestTracker class
tree_prior_ptr: External pointer to a C++ TreePrior class
Methods
Method new()
Create a new ForestModel object.
Usage
ForestModel$new(
forest_dataset,
feature_types,
num_trees,
n,
alpha,
beta,
min_samples_leaf,
max_depth = -1
)
Arguments
forest_dataset: ForestDataset object, used to initialize forest sampling data structures
feature_types: Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
num_trees: Number of trees in the forest being sampled
n: Number of observations in forest_dataset
alpha: Root node split probability in tree prior
beta: Depth prior penalty in tree prior
min_samples_leaf: Minimum number of samples in a tree leaf
max_depth: Maximum depth that any tree can reach. Default: -1 (no depth limit enforced)
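As an illustrative sketch (not taken from the package documentation), the constructor might be called as follows. The createForestDataset() helper and all parameter values are assumptions; only the argument names follow the Usage block above.

library(stochtree)

# Simulated covariates (assumed setup, for illustration only)
n <- 500
p <- 4
X <- matrix(runif(n * p), ncol = p)
forest_dataset <- createForestDataset(X)  # assumed constructor for a ForestDataset

forest_model <- ForestModel$new(
  forest_dataset = forest_dataset,
  feature_types = rep(0, p),  # all features numeric
  num_trees = 100,
  n = n,
  alpha = 0.95,
  beta = 2,
  min_samples_leaf = 5,
  max_depth = 10
)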
Method sample_one_iteration()
Run a single iteration of the forest sampling algorithm (MCMC or GFR)
Usage
ForestModel$sample_one_iteration(
forest_dataset,
residual,
forest_samples,
active_forest,
rng,
forest_model_config,
global_model_config,
num_threads = -1,
keep_forest = TRUE,
gfr = TRUE
)
Arguments
forest_dataset: Dataset used to sample the forest
residual: Outcome used to sample the forest
forest_samples: Container of forest samples
active_forest: "Active" forest updated by the sampler in each iteration
rng: Wrapper around C++ random number generator
forest_model_config: ForestModelConfig object containing forest model parameters and settings
global_model_config: GlobalModelConfig object containing global model parameters and settings
num_threads: Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to 1; otherwise, it defaults to the maximum number of available threads.
keep_forest: (Optional) Whether the updated forest sample should be saved to forest_samples. Default: TRUE.
gfr: (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: TRUE.
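A minimal sketch of a typical sampling loop using this method is shown below. It assumes that forest_dataset, residual, forest_samples, active_forest, rng, forest_model_config, and global_model_config have already been created with the package's corresponding constructors; the number of GFR and MCMC iterations is arbitrary, and only the argument names follow the Usage block above.

num_gfr <- 10
num_mcmc <- 100
for (i in 1:(num_gfr + num_mcmc)) {
  # Use grow-from-root for the first num_gfr iterations, then switch to MCMC
  forest_model$sample_one_iteration(
    forest_dataset = forest_dataset,
    residual = residual,
    forest_samples = forest_samples,
    active_forest = active_forest,
    rng = rng,
    forest_model_config = forest_model_config,
    global_model_config = global_model_config,
    keep_forest = TRUE,
    gfr = (i <= num_gfr)
  )
}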
Method get_cached_forest_predictions()
Extract the internally cached predictions of the forest on the training dataset from the sampler.
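For example, after a call to sample_one_iteration(), the cached in-sample predictions could be retrieved as follows (illustrative only; the shape of the returned object is not documented here):

preds <- forest_model$get_cached_forest_predictions()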
Method propagate_basis_update()
Propagates a basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching them on the C++ side), and (c) subtracting the new predictions from the residual.
This is useful in cases where a basis (e.g. for leaf regression) is updated outside of the tree sampler (as with, for example, the adaptive coding of a binary treatment in BCF). Once a basis has been updated, the overall "function" represented by the tree model has changed, and this change should be reflected in the residual before the next sampling loop is run. A rough sketch of the intended workflow is given after this paragraph.
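The argument names of propagate_basis_update() and the update_basis() method on the dataset are assumptions (hypothetical, not documented above); only the general sequence follows the description.

# Hypothetical: replace the leaf-regression basis stored in the dataset
forest_dataset$update_basis(new_basis)  # assumed method name and argument
# Propagate the changed basis through to the residual before the next sampling loop
forest_model$propagate_basis_update(forest_dataset, residual, active_forest)  # assumed arguments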
Method propagate_residual_update()
Update the current state of the outcome (i.e. the partial residual) by subtracting the current predictions of each tree.
This function is run after the update_data method of the Outcome class, which overwrites the partial residual with an entirely new stream of outcome data.
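A hedged sketch of this sequence is shown below; the arguments passed to update_data() and propagate_residual_update() are assumptions based on the description above, not documented signatures.

# Hypothetical: overwrite the partial residual with new outcome data ...
residual$update_data(new_outcome)  # assumed argument
# ... then subtract the current tree predictions so the partial residual is consistent
forest_model$propagate_residual_update(residual)  # assumed argument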