sampler.ForestSampler

sampler.ForestSampler(dataset, global_config, forest_config)

Wrapper around many of the core C++ sampling data structures and algorithms.

Parameters

Name Type Description Default
dataset Dataset stochtree dataset object storing covariates / bases / weights required
global_config GlobalModelConfig GlobalModelConfig object containing global model parameters and settings required
forest_config ForestModelConfig ForestModelConfig object containing forest model parameters and settings required

Methods

Name Description
reconstitute_from_forest Re-initialize a forest sampler's tracking data structures from a specific forest in a ForestContainer
sample_one_iteration Sample one iteration of a forest using the specified model and tree sampling algorithm
prepare_for_sampler Initialize forest and tracking data structures with constant root values before running a sampler
adjust_residual Adjusts the residual used for training tree ensembles by either adding the prediction of each tree to the existing residual or subtracting it from the existing residual.
propagate_basis_update Propagates a basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching on the C++ side), and (c) subtracting the new predictions from the residual.
get_cached_forest_predictions Extract an internally-cached prediction of a forest on the training dataset in a sampler.
update_alpha Update alpha in the tree prior
update_beta Update beta in the tree prior
update_min_samples_leaf Update min_samples_leaf in the tree prior
update_max_depth Update max_depth in the tree prior

reconstitute_from_forest

sampler.ForestSampler.reconstitute_from_forest(
    forest,
    dataset,
    residual,
    is_mean_model,
)

Re-initialize a forest sampler's tracking data structures from a specific forest in a ForestContainer

Parameters

Name Type Description Default
dataset Dataset stochtree dataset object storing covariates / bases / weights required
residual Residual stochtree object storing continuously updated partial / full residual required
forest Forest stochtree object storing tree ensemble required
is_mean_model bool Indicator of whether the model being updated is a conditional mean model (True) or a conditional variance model (False) required

sample_one_iteration

sampler.ForestSampler.sample_one_iteration(
    forest_container,
    forest,
    dataset,
    residual,
    rng,
    global_config,
    forest_config,
    keep_forest,
    gfr,
    num_threads=-1,
)

Sample one iteration of a forest using the specified model and tree sampling algorithm

Parameters

Name Type Description Default
forest_container ForestContainer stochtree object storing tree ensembles required
forest Forest stochtree object storing the “active” forest being sampled required
dataset Dataset stochtree dataset object storing covariates / bases / weights required
residual Residual stochtree object storing continuously updated partial / full residual required
rng RNG stochtree object storing the C++ random number generator to be used in the sampling algorithm required
global_config GlobalModelConfig GlobalModelConfig object containing global model parameters and settings required
forest_config ForestModelConfig ForestModelConfig object containing forest model parameters and settings required
keep_forest bool Whether the resulting forest should be retained in forest_container or discarded (for example, due to burn-in or thinning) required
gfr bool Whether or not the “grow-from-root” (GFR) sampler is run (if this is True and leaf_model_int=0 this is equivalent to XBART, if this is False and leaf_model_int=0 this is equivalent to the original BART) required
num_threads int Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user’s system, this will default to 1, otherwise to the maximum number of available threads. -1
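
The keep_forest and gfr flags are usually driven by the iteration index. The following is a minimal sketch, assuming a run that starts with a few GFR warm-start draws and then switches to MCMC with burn-in and thinning. The helper iteration_flags and its arguments (num_gfr, num_burnin, num_mcmc, thin) are hypothetical names, not part of stochtree, and retaining every GFR draw is a choice made only for illustration.

```python
def iteration_flags(num_gfr, num_burnin, num_mcmc, thin=1):
    """Yield a (gfr, keep_forest) pair for each sampler iteration.

    Hypothetical helper: GFR draws come first and are all kept here;
    MCMC draws follow, with burn-in draws discarded and only every
    `thin`-th post-burn-in draw kept.
    """
    for _ in range(num_gfr):
        yield True, True
    for i in range(num_burnin + num_mcmc):
        past_burnin = i >= num_burnin
        on_thinning_grid = (i - num_burnin) % thin == 0
        yield False, past_burnin and on_thinning_grid


# 2 GFR draws, then 3 burn-in and 4 MCMC draws, keeping every 2nd MCMC draw
flags = list(iteration_flags(num_gfr=2, num_burnin=3, num_mcmc=4, thin=2))
num_kept = sum(keep for _, keep in flags)
```

Each (gfr, keep_forest) pair would then be passed to sample_one_iteration for the corresponding draw.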

prepare_for_sampler

sampler.ForestSampler.prepare_for_sampler(
    dataset,
    residual,
    forest,
    leaf_model,
    initial_values,
)

Initialize forest and tracking data structures with constant root values before running a sampler

Parameters

Name Type Description Default
dataset Dataset stochtree dataset object storing covariates / bases / weights required
residual Residual stochtree object storing continuously updated partial / full residual required
forest Forest stochtree object storing the “active” forest being sampled required
leaf_model int Integer encoding the leaf model type required
initial_values np.array Constant root node value(s) at which to initialize the forest prediction. Internally, the value is divided by the number of trees; it is typically 0 for mean models and 1 for variance models. required
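
The division by the number of trees can be illustrated with plain NumPy. This is a minimal sketch for a mean model only, assuming that every tree starts as a single root node and that the forest prediction is the sum of the tree predictions; it does not call stochtree, and the variable names are illustrative.

```python
import numpy as np

num_trees = 50
num_obs = 4
initial_values = np.array([2.0])  # nonzero only to make the arithmetic visible

# Each root node receives an equal share of the requested initial prediction
root_value = initial_values[0] / num_trees

# Every observation falls in the root of every tree: shape (num_obs, num_trees)
tree_predictions = np.full((num_obs, num_trees), root_value)

# Summing over trees recovers the requested initial forest prediction
forest_prediction = tree_predictions.sum(axis=1)
```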

adjust_residual

sampler.ForestSampler.adjust_residual(
    dataset,
    residual,
    forest,
    requires_basis,
    add,
)

Adjusts the residual used for training tree ensembles by either adding the prediction of each tree to the existing residual or subtracting it from the existing residual.

This is typically run just once at the beginning of a forest sampling algorithm: after trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.

Parameters

Name Type Description Default
dataset Dataset stochtree dataset object storing covariates / bases / weights required
residual Residual stochtree object storing continuously updated partial / full residual required
forest Forest stochtree object storing the “active” forest being sampled required
requires_basis bool Whether or not the forest requires a basis dot product when predicting required
add bool Whether the predictions of each tree are added to (add=True) or subtracted from (add=False) the outcome to form the new residual required
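
The arithmetic behind the add flag can be sketched in NumPy. This is an illustration under stated assumptions rather than the library's implementation: adjust_residual_sketch is a hypothetical function that returns a new array, whereas the real method updates the Residual object it is given.

```python
import numpy as np


def adjust_residual_sketch(residual, tree_predictions, add):
    """Add (add=True) or subtract (add=False) every tree's prediction.

    residual has shape (n,); tree_predictions has shape (n, num_trees).
    """
    sign = 1.0 if add else -1.0
    return residual + sign * tree_predictions.sum(axis=1)


outcome = np.array([1.0, 2.0, 3.0])
tree_predictions = np.full((3, 4), 0.25)  # 4 root-only trees, each predicting 0.25

# Start of sampling: subtract the constant root predictions from the outcome
residual = adjust_residual_sketch(outcome, tree_predictions, add=False)

# Adding the same predictions back restores the original outcome
restored = adjust_residual_sketch(residual, tree_predictions, add=True)
```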

propagate_basis_update

sampler.ForestSampler.propagate_basis_update(dataset, residual, forest)

Propagates a basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching on the C++ side), and (c) subtracting the new predictions from the residual.

This is useful in cases where a basis (e.g. for leaf regression) is updated outside of a tree sampler (e.g. with adaptive coding for binary treatment BCF). Once a basis has been updated, the overall “function” represented by a tree model has changed, and this change should be propagated to the residual before the next sampling loop is run.

Parameters

Name Type Description Default
dataset Dataset stochtree dataset object storing covariates / bases / weights required
residual Residual stochtree object storing continuously updated partial / full residual required
forest Forest stochtree object storing the “active” forest being sampled required
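
Steps (a) to (c) can be sketched in NumPy for the simplest case. The sketch assumes root-only trees with a univariate leaf regression, so that each tree predicts its leaf coefficient times the basis; the arrays and names are illustrative and nothing here calls stochtree.

```python
import numpy as np

outcome = np.array([1.0, 2.0, 3.0])
leaf_coefs = np.array([0.5, -0.25])     # one coefficient per root-only tree
old_basis = np.array([1.0, 0.0, 1.0])
new_basis = np.array([2.0, 1.0, 0.0])   # basis after an external update

# Cached per-tree predictions and residual under the old basis
cached = np.outer(old_basis, leaf_coefs)  # shape (n, num_trees)
residual = outcome - cached.sum(axis=1)

for j in range(leaf_coefs.size):
    residual += cached[:, j]                  # (a) add back the previous prediction
    cached[:, j] = new_basis * leaf_coefs[j]  # (b) recompute and re-cache
    residual -= cached[:, j]                  # (c) subtract the new prediction
```

After the loop, the residual equals the outcome minus the forest prediction under the new basis, which is what the next sampling iteration expects.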

get_cached_forest_predictions

sampler.ForestSampler.get_cached_forest_predictions()

Extract an internally-cached prediction of a forest on the training dataset in a sampler.

Returns

Name Type Description
np.array Numpy 1D array with as many elements as observations in the training dataset

update_alpha

sampler.ForestSampler.update_alpha(alpha)

Update alpha in the tree prior

Parameters

Name Type Description Default
alpha float New value of alpha to be used required

update_beta

sampler.ForestSampler.update_beta(beta)

Update beta in the tree prior

Parameters

Name Type Description Default
beta float New value of beta to be used required
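
To see how alpha and beta interact, the sketch below assumes the standard BART tree prior, in which a node at depth d (root at depth 0) splits with probability alpha * (1 + d) ** (-beta). The function split_probability is a hypothetical helper for illustration, not a stochtree call.

```python
def split_probability(depth, alpha, beta):
    """Prior probability that a node at `depth` splits (assumed BART prior)."""
    return alpha * (1.0 + depth) ** (-beta)


# With alpha=0.95 and beta=2.0, deeper nodes are increasingly unlikely to split
probs = [split_probability(d, alpha=0.95, beta=2.0) for d in range(4)]
```

Under this assumption, raising alpha makes splits more likely at every depth, while raising beta makes the probability decay faster with depth, favoring shallower trees.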

update_min_samples_leaf

sampler.ForestSampler.update_min_samples_leaf(min_samples_leaf)

Update min_samples_leaf in the tree prior

Parameters

Name Type Description Default
min_samples_leaf int New value of min_samples_leaf to be used required

update_max_depth

sampler.ForestSampler.update_max_depth(max_depth)

Update max_depth in the tree prior

Parameters

Name Type Description Default
max_depth int New value of max_depth to be used required