sampler.ForestSampler
sampler.ForestSampler(dataset, global_config, forest_config)
Wrapper around many of the core C++ sampling data structures and algorithms.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset | stochtree dataset object storing covariates / bases / weights | required |
| global_config | GlobalModelConfig | GlobalModelConfig object containing global model parameters and settings | required |
| forest_config | ForestModelConfig | ForestModelConfig object containing forest model parameters and settings | required |
Methods
| Name | Description |
|---|---|
| reconstitute_from_forest | Re-initialize a forest sampler's tracking data structures from a specific forest in a ForestContainer |
| sample_one_iteration | Sample one iteration of a forest using the specified model and tree sampling algorithm |
| prepare_for_sampler | Initialize forest and tracking data structures with constant root values before running a sampler |
| adjust_residual | Adjusts the residual used to train tree ensembles by either adding each tree's prediction to, or subtracting it from, the existing residual. |
| propagate_basis_update | Propagates a basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree, and (c) subtracting the new predictions from the residual |
| get_cached_forest_predictions | Extract an internally cached prediction of a forest on the training dataset in a sampler. |
| update_alpha | Update alpha in the tree prior |
| update_beta | Update beta in the tree prior |
| update_min_samples_leaf | Update min_samples_leaf in the tree prior |
| update_max_depth | Update max_depth in the tree prior |
reconstitute_from_forest
sampler.ForestSampler.reconstitute_from_forest(
    forest,
    dataset,
    residual,
    is_mean_model,
)
Re-initialize a forest sampler's tracking data structures from a specific forest in a ForestContainer.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| forest | Forest | stochtree object storing tree ensemble | required |
| dataset | Dataset | stochtree dataset object storing covariates / bases / weights | required |
| residual | Residual | stochtree object storing continuously updated partial / full residual | required |
| is_mean_model | bool | Indicator of whether the model being updated is a conditional mean model (True) or a conditional variance model (False) | required |
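Conceptually, re-initializing from a forest means recomputing which leaf every training observation falls into in each tree, and from that the forest's per-observation predictions. The pure-Python sketch below illustrates only that bookkeeping; the nested-tuple tree encoding and the helpers `leaf_value` and `cached_predictions` are hypothetical and do not reflect stochtree's C++ data structures.

```python
from typing import List, Sequence, Union

# Hypothetical tree encoding (illustration only): an internal node is
# (feature_index, threshold, left_subtree, right_subtree); a leaf is a float.
Tree = Union[float, tuple]


def leaf_value(tree: Tree, x: Sequence[float]) -> float:
    """Route one observation down a tree and return the value of its leaf."""
    node = tree
    while isinstance(node, tuple):
        feature, threshold, left, right = node
        node = left if x[feature] <= threshold else right
    return node


def cached_predictions(forest: List[Tree], X: List[Sequence[float]]) -> List[float]:
    """Forest prediction per observation: the sum of its leaf values across trees."""
    return [sum(leaf_value(tree, x) for tree in forest) for x in X]


forest = [
    (0, 0.5, -1.0, 1.0),  # stump: feature 0 <= 0.5 goes left, otherwise right
    0.25,                 # root-only tree with a constant leaf value
]
X = [[0.2], [0.9]]
print(cached_predictions(forest, X))  # [-0.75, 1.25]
```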
sample_one_iteration
sampler.ForestSampler.sample_one_iteration(
    forest_container,
    forest,
    dataset,
    residual,
    rng,
    global_config,
    forest_config,
    keep_forest,
    gfr,
    num_threads=-1,
)
Sample one iteration of a forest using the specified model and tree sampling algorithm.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| forest_container | ForestContainer | stochtree object storing tree ensembles | required |
| forest | Forest | stochtree object storing the “active” forest being sampled | required |
| dataset | Dataset | stochtree dataset object storing covariates / bases / weights | required |
| residual | Residual | stochtree object storing continuously updated partial / full residual | required |
| rng | RNG | stochtree object storing the C++ random number generator used by the sampling algorithm | required |
| global_config | GlobalModelConfig | GlobalModelConfig object containing global model parameters and settings | required |
| forest_config | ForestModelConfig | ForestModelConfig object containing forest model parameters and settings | required |
| keep_forest | bool | Whether the resulting forest is retained in forest_container or discarded (for example, during burn-in or thinning) | required |
| gfr | bool | Whether to run the “grow-from-root” (GFR) sampler. With leaf_model_int=0, gfr=True is equivalent to XBART and gfr=False is equivalent to the original BART | required |
| num_threads | int | Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user’s system, this defaults to 1; otherwise, it defaults to the maximum number of available threads. | -1 |
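Because `keep_forest` is passed on every call, burn-in and thinning are controlled by the caller's sampling loop. A minimal sketch of one way to build that schedule, assuming a hypothetical `keep_schedule` helper (not part of stochtree) that discards the first `num_burnin` draws and then keeps every `thin`-th draw:

```python
from typing import List


def keep_schedule(num_iterations: int, num_burnin: int, thin: int) -> List[bool]:
    """Hypothetical helper: one keep_forest flag per sampling iteration.

    Iterations before num_burnin are discarded; afterwards every
    `thin`-th draw is retained.
    """
    if thin < 1:
        raise ValueError("thin must be >= 1")
    return [
        i >= num_burnin and (i - num_burnin) % thin == 0
        for i in range(num_iterations)
    ]


flags = keep_schedule(num_iterations=10, num_burnin=4, thin=2)
print(flags)
# [False, False, False, False, True, False, True, False, True, False]
```

In a sampling loop, `flags[i]` would be passed as `keep_forest` on iteration `i`.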
prepare_for_sampler
sampler.ForestSampler.prepare_for_sampler(
    dataset,
    residual,
    forest,
    leaf_model,
    initial_values,
)
Initialize forest and tracking data structures with constant root values before running a sampler.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset | stochtree dataset object storing covariates / bases / weights | required |
| residual | Residual | stochtree object storing continuously updated partial / full residual | required |
| forest | Forest | stochtree object storing the “active” forest being sampled | required |
| leaf_model | int | Integer encoding the leaf model type | required |
| initial_values | np.array | Constant root node value(s) at which to initialize the forest prediction. Internally, this value is divided by the number of trees; it is typically 0 for mean models and 1 for variance models. | required |
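As a worked example of the root-value split described for `initial_values`, the sketch below (a hypothetical helper, additive mean-model case only) divides a constant initial prediction evenly across the trees so that the forest's summed prediction equals the requested value. How a variance forest combines its trees is not covered here.

```python
from typing import List


def root_leaf_values(initial_value: float, num_trees: int) -> List[float]:
    """Hypothetical helper: split a constant initial prediction evenly across root nodes."""
    if num_trees < 1:
        raise ValueError("num_trees must be >= 1")
    per_tree = initial_value / num_trees
    return [per_tree] * num_trees


leaves = root_leaf_values(initial_value=2.0, num_trees=4)
print(leaves)       # [0.5, 0.5, 0.5, 0.5]
print(sum(leaves))  # 2.0
```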
adjust_residual
sampler.ForestSampler.adjust_residual(
    dataset,
    residual,
    forest,
    requires_basis,
    add,
)
Adjusts the residual used to train tree ensembles by either adding each tree's prediction to, or subtracting it from, the existing residual.
This is typically run once at the beginning of a forest sampling algorithm: after trees are initialized with constant root node predictions, those root predictions are subtracted from the residual.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset | stochtree dataset object storing covariates / bases / weights | required |
| residual | Residual | stochtree object storing continuously updated partial / full residual | required |
| forest | Forest | stochtree object storing the “active” forest being sampled | required |
| requires_basis | bool | Whether or not the forest requires a basis dot product when predicting | required |
| add | bool | Whether the predictions of each tree are added to (add=True) or subtracted from (add=False) the residual to form the new residual | required |
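The arithmetic behind `add` and `requires_basis` can be sketched in pure Python. This is an illustration, not stochtree's implementation: each tree is represented by the leaf value every observation receives, and when a basis is present it is assumed to be univariate, so a tree's prediction for observation `i` is its leaf value times `basis[i]`.

```python
from typing import List, Optional


def adjust_residual(
    residual: List[float],
    tree_leaf_values: List[List[float]],
    basis: Optional[List[float]] = None,
    add: bool = False,
) -> List[float]:
    """Add (add=True) or subtract (add=False) every tree's prediction.

    tree_leaf_values[t][i] is the leaf value observation i receives in tree t.
    With a (univariate) basis, a tree predicts leaf value * basis[i];
    otherwise it predicts the leaf value itself.
    """
    sign = 1.0 if add else -1.0
    out = list(residual)
    for leaf_values in tree_leaf_values:
        for i, leaf in enumerate(leaf_values):
            pred = leaf * basis[i] if basis is not None else leaf
            out[i] += sign * pred
    return out


y = [3.0, 5.0]
roots = [[0.5, 0.5], [0.5, 0.5]]  # two trees, constant root value 0.5
resid = adjust_residual(y, roots, add=False)
print(resid)     # [2.0, 4.0]
restored = adjust_residual(resid, roots, add=True)
print(restored)  # [3.0, 5.0]
```

Subtracting the root predictions and then adding them back restores the original values, which is why one method with an `add` flag covers both directions.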
propagate_basis_update
sampler.ForestSampler.propagate_basis_update(dataset, residual, forest)
Propagates a basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (cached on the C++ side), and (c) subtracting the new predictions from the residual.
This is useful when a basis (e.g. for leaf regression) is updated outside of a tree sampler, as with adaptive coding for binary-treatment BCF. Once the basis has been updated, the overall “function” represented by a tree model has changed, and this change should be propagated to the residual before the next sampling loop is run.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset | stochtree dataset object storing covariates / bases / weights | required |
| residual | Residual | stochtree object storing continuously updated partial / full residual | required |
| forest | Forest | stochtree object storing the “active” forest being sampled | required |
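A pure-Python sketch of steps (a) to (c) for a single basis update. It is illustrative only: the function name mirrors the method but is hypothetical, leaf values are stored per observation, the basis is assumed univariate (a tree predicts leaf value times `basis[i]`), and the basis values in the usage example are made up rather than taken from BCF's actual adaptive coding.

```python
from typing import List


def propagate_basis_update(
    residual: List[float],
    tree_leaf_values: List[List[float]],
    old_basis: List[float],
    new_basis: List[float],
) -> List[float]:
    """Swap each tree's old basis-weighted prediction for its new one in the residual.

    tree_leaf_values[t][i] is the leaf value observation i receives in tree t.
    """
    out = list(residual)
    for leaf_values in tree_leaf_values:
        for i, leaf in enumerate(leaf_values):
            out[i] += leaf * old_basis[i]   # (a) add back the previous prediction
            new_pred = leaf * new_basis[i]  # (b) recompute the prediction
            out[i] -= new_pred              # (c) subtract the new prediction
    return out


y = [2.0, 4.0]
leaves = [[1.0, 1.0]]   # one tree; both observations get leaf value 1.0
old_basis = [0.0, 1.0]  # illustrative values only
new_basis = [-0.5, 0.5]
residual = [y[i] - leaves[0][i] * old_basis[i] for i in range(2)]  # [2.0, 3.0]
result = propagate_basis_update(residual, leaves, old_basis, new_basis)
print(result)  # [2.5, 3.5]
```

The result matches recomputing `y - leaf * new_basis` from scratch, which is the invariant the update is meant to preserve.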
get_cached_forest_predictions
sampler.ForestSampler.get_cached_forest_predictions()
Extract an internally cached prediction of a forest on the training dataset in a sampler.
Returns
| Name | Type | Description |
|---|---|---|
|  | np.array | NumPy 1D array with as many elements as observations in the training dataset |
update_alpha
sampler.ForestSampler.update_alpha(alpha)
Update alpha in the tree prior.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | New value of alpha to be used | required |
update_beta
sampler.ForestSampler.update_beta(beta)
Update beta in the tree prior.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| beta | float | New value of beta to be used | required |
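Assuming `alpha` and `beta` follow the standard BART tree prior, a node at depth d (root at depth 0) splits with prior probability alpha * (1 + d)^(-beta): `alpha` sets the split probability at the root and `beta` controls how quickly it decays with depth. A small sketch with a hypothetical helper name, using example values:

```python
def split_probability(depth: int, alpha: float, beta: float) -> float:
    """Prior probability that a node at `depth` splits: alpha * (1 + depth) ** (-beta)."""
    return alpha * (1 + depth) ** (-beta)


# Split probability shrinks with depth; a larger beta shrinks it faster.
for depth in range(4):
    print(depth, split_probability(depth, alpha=0.95, beta=2.0))
```

Raising `beta` (or lowering `alpha`) pushes the prior toward shallower trees.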
update_min_samples_leaf
sampler.ForestSampler.update_min_samples_leaf(min_samples_leaf)
Update min_samples_leaf in the tree prior.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| min_samples_leaf | int | New value of min_samples_leaf to be used | required |
update_max_depth
sampler.ForestSampler.update_max_depth(max_depth)
Update max_depth in the tree prior.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| max_depth | int | New value of max_depth to be used | required |
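`min_samples_leaf` and `max_depth` act as hard constraints on which splits are allowed. The sketch below is a hypothetical check, not stochtree's code, and assumes the root sits at depth 0 and that a node may split only while its depth is below `max_depth`.

```python
def split_allowed(
    depth: int,
    n_left: int,
    n_right: int,
    min_samples_leaf: int,
    max_depth: int,
) -> bool:
    """Hypothetical check of the min_samples_leaf and max_depth constraints.

    Assumes the root is at depth 0 and that a node may split only while its
    depth is below max_depth.
    """
    if depth >= max_depth:
        return False  # splitting would create children deeper than max_depth
    # Both children must keep at least min_samples_leaf observations.
    return n_left >= min_samples_leaf and n_right >= min_samples_leaf


print(split_allowed(0, 6, 4, min_samples_leaf=5, max_depth=3))    # False
print(split_allowed(1, 6, 5, min_samples_leaf=5, max_depth=3))    # True
print(split_allowed(3, 50, 50, min_samples_leaf=5, max_depth=3))  # False
```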