
Sampler API#

stochtree.sampler.ForestSampler #

Wrapper around many of the core C++ sampling data structures and algorithms.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Dataset` | stochtree dataset object storing covariates / bases / weights | required |
| `feature_types` | `array` | Array of integer-coded values indicating the column type of each feature in `dataset`. Integer codes map 0 to "numeric" (continuous), 1 to "ordered categorical", and 2 to "unordered categorical". | required |
| `num_trees` | `int` | Number of trees in the forest model that this sampler class will fit. | required |
| `num_obs` | `int` | Number of observations ("rows") in `dataset`. | required |
| `alpha` | `float` | Prior probability of splitting for a tree of depth 0 in a forest model. The tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. | required |
| `beta` | `float` | Exponent that decreases split probabilities for nodes of depth > 0 in a forest model. The tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. | required |
| `min_samples_leaf` | `int` | Minimum allowable size of a leaf, in terms of training samples, in a forest model. | required |
| `max_depth` | `int` | Maximum depth of any tree in the ensemble in a forest model. | `-1` |
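A minimal construction sketch, assuming `Dataset` and `ForestSampler` are importable from the package root, that `Dataset` exposes an `add_covariates` method for loading a covariate matrix, and that the constructor accepts the parameters in the order documented above; the data and settings are illustrative only.

```python
import numpy as np
from stochtree import Dataset, ForestSampler  # assumed package-root re-exports

# Illustrative data: 100 observations, 3 numeric covariates
rng_np = np.random.default_rng(1234)
X = rng_np.uniform(size=(100, 3))

dataset = Dataset()
dataset.add_covariates(X)  # assumed Dataset method for loading covariates

feature_types = np.zeros(X.shape[1], dtype=int)  # 0 = numeric for every column

# Tree split prior: P(split at depth d) = alpha * (1 + d)^(-beta)
forest_sampler = ForestSampler(
    dataset,        # Dataset with covariates loaded
    feature_types,  # per-column type codes
    50,             # num_trees
    X.shape[0],     # num_obs
    0.95,           # alpha
    2.0,            # beta
    5,              # min_samples_leaf
    10,             # max_depth
)
```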

reconstitute_from_forest(forest, dataset, residual, is_mean_model) #

Re-initialize a forest sampler's tracking data structures from a specific forest in a ForestContainer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Dataset` | stochtree dataset object storing covariates / bases / weights | required |
| `residual` | `Residual` | stochtree object storing the continuously updated partial / full residual | required |
| `forest` | `Forest` | stochtree object storing the tree ensemble | required |
| `is_mean_model` | `bool` | Indicator of whether the model being updated is a conditional mean model (True) or a conditional variance model (False) | required |
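A hedged usage sketch: restore the sampler's internal state from a forest that was retained in a `ForestContainer`, for example before continuing an MCMC run. Here `forest`, `dataset`, and `residual` are assumed to be already-constructed objects of the documented types, with `forest` holding the draw to resume from.

```python
# Rebuild the sampler's tracking structures (leaf membership, cached
# predictions, residual bookkeeping) so that subsequent sampling iterations
# continue from the state of `forest`. The final argument indicates a
# conditional mean model.
forest_sampler.reconstitute_from_forest(forest, dataset, residual, True)
```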

sample_one_iteration(forest_container, forest, dataset, residual, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, gfr, pre_initialized) #

Sample one iteration of a forest using the specified model and tree sampling algorithm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `forest_container` | `ForestContainer` | stochtree object storing tree ensembles | required |
| `forest` | `Forest` | stochtree object storing the "active" forest being sampled | required |
| `dataset` | `Dataset` | stochtree dataset object storing covariates / bases / weights | required |
| `residual` | `Residual` | stochtree object storing the continuously updated partial / full residual | required |
| `rng` | `RNG` | stochtree object storing the C++ random number generator used by the sampling algorithm | required |
| `feature_types` | `array` | Array of integer-coded feature types (0 = numeric, 1 = ordered categorical, 2 = unordered categorical) | required |
| `cutpoint_grid_size` | `int` | Maximum size of the grid of available cutpoints (which thins the number of possible splits; particularly useful in the grow-from-root algorithm) | required |
| `leaf_model_scale_input` | `array` | Numpy array containing the leaf model scale parameter (if the leaf model is univariate, this is essentially a scalar and is used as such in the C++ source, but it is stored as a numpy array) | required |
| `variable_weights` | `array` | Numpy array containing sampling probabilities for each feature | required |
| `a_forest` | `float` | Shape parameter of the inverse gamma outcome model for a heteroskedasticity forest | required |
| `b_forest` | `float` | Scale parameter of the inverse gamma outcome model for a heteroskedasticity forest | required |
| `global_variance` | `float` | Current value of the global error variance parameter | required |
| `leaf_model_int` | `int` | Integer encoding the leaf model type (0 = constant Gaussian leaf mean model, 1 = univariate Gaussian leaf regression mean model, 2 = multivariate Gaussian leaf regression mean model, 3 = univariate inverse gamma constant leaf variance model) | required |
| `keep_forest` | `bool` | Whether or not the resulting forest should be retained in `forest_container` or discarded (due to burn-in or thinning, for example) | required |
| `gfr` | `bool` | Whether or not the "grow-from-root" (GFR) sampler is run (if True with `leaf_model_int=0` this is equivalent to XBART; if False with `leaf_model_int=0` this is equivalent to the original BART) | required |
| `pre_initialized` | `bool` | Whether or not the forest being sampled has already been initialized | required |
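A hedged sketch of a single sampling step for a constant-leaf Gaussian mean model (`leaf_model_int=0`), reusing `X`, `feature_types`, `dataset`, and `forest_sampler` from the constructor sketch above and assuming `forest_container`, `forest`, `residual`, and `rng` have already been constructed via the package's forest and data APIs. The scale, weight, and variance values are illustrative, and the 2-D shape of the leaf scale array is an assumption.

```python
import numpy as np

num_trees = 50
cutpoint_grid_size = 100
leaf_model_scale = np.array([[1.0 / num_trees]])  # stored as a numpy array even when effectively scalar
variable_weights = np.repeat(1.0 / X.shape[1], X.shape[1])  # uniform feature sampling probabilities
a_forest, b_forest = 1.0, 1.0  # heteroskedasticity-forest prior; unused by the mean model but still passed
global_variance = 1.0          # current draw of the global error variance

forest_sampler.sample_one_iteration(
    forest_container, forest, dataset, residual, rng, feature_types,
    cutpoint_grid_size, leaf_model_scale, variable_weights,
    a_forest, b_forest, global_variance,
    0,      # leaf_model_int: constant Gaussian leaf mean model
    True,   # keep_forest: retain this draw in forest_container
    True,   # gfr: run the grow-from-root (XBART-style) sampler
    True,   # pre_initialized: forest was set up via prepare_for_sampler
)
```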

prepare_for_sampler(dataset, residual, forest, leaf_model, initial_values) #

Initialize the forest and its tracking data structures with constant root values before running a sampler.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Dataset` | stochtree dataset object storing covariates / bases / weights | required |
| `residual` | `Residual` | stochtree object storing the continuously updated partial / full residual | required |
| `forest` | `Forest` | stochtree object storing the "active" forest being sampled | required |
| `leaf_model` | `int` | Integer encoding the leaf model type | required |
| `initial_values` | `array` | Constant root node value(s) at which to initialize the forest prediction (internally, this is divided by the number of trees; it is typically 0 for mean models and 1 for variance models). | required |
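A hedged sketch of initializing a mean-model forest before its first sampling iteration, continuing the objects assumed in the earlier sketches; `y` stands for the outcome vector used to build `residual`, and initializing at its mean is a common (but not required) choice.

```python
import numpy as np

# Constant root value for every tree; internally divided across the trees so
# the initial forest prediction equals mean(y). Variance models typically
# initialize at 1 instead.
init_value = np.array([np.mean(y)])

# 0 = constant Gaussian leaf mean model
forest_sampler.prepare_for_sampler(dataset, residual, forest, 0, init_value)
```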

adjust_residual(dataset, residual, forest, requires_basis, add) #

Method that "adjusts" the residual used for training tree ensembles by either adding each tree's prediction to, or subtracting it from, the existing residual.

This is typically run just once at the beginning of a forest sampling algorithm: after trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Dataset` | stochtree dataset object storing covariates / bases / weights | required |
| `residual` | `Residual` | stochtree object storing the continuously updated partial / full residual | required |
| `forest` | `Forest` | stochtree object storing the "active" forest being sampled | required |
| `requires_basis` | `bool` | Whether or not the forest requires a basis dot product when predicting | required |
| `add` | `bool` | Whether the predictions of each tree are added to (`add=True`) or subtracted from (`add=False`) the outcome to form the new residual | required |
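A hedged sketch of the typical use at the start of a run: subtract the freshly initialized root predictions from the outcome so `residual` holds the partial residual the tree sampler expects. Whether this call is needed in addition to `prepare_for_sampler` depends on how the surrounding sampling loop is written.

```python
# Subtract each tree's current (constant root) prediction from the outcome
# stored in `residual`. requires_basis=False because a constant-leaf model
# predicts without a basis; add=False subtracts rather than adds.
forest_sampler.adjust_residual(dataset, residual, forest, False, False)
```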

propagate_basis_update(dataset, residual, forest) #

Propagates a basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching on the C++ side), and (c) subtracting the new predictions from the residual.

This is useful in cases where a basis (e.g., for leaf regression) is updated outside of a tree sampler (as with, e.g., adaptive coding of a binary treatment in BCF). Once a basis has been updated, the overall "function" represented by a tree model has changed, and this change should be reflected in the residual before the next sampling loop is run.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Dataset` | stochtree dataset object storing covariates / bases / weights | required |
| `residual` | `Residual` | stochtree object storing the continuously updated partial / full residual | required |
| `forest` | `Forest` | stochtree object storing the "active" forest being sampled | required |
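A hedged sketch for a leaf regression model whose basis is updated outside the tree sampler (e.g., BCF-style adaptive coding of a binary treatment). `new_basis` is illustrative, and the `update_basis` method on `Dataset` is an assumption about the data API.

```python
# Swap the recomputed basis into the dataset, then push the change through to
# the residual so the next sampling iteration sees predictions under the new
# basis.
dataset.update_basis(new_basis)  # assumed Dataset method for replacing the stored basis
forest_sampler.propagate_basis_update(dataset, residual, forest)
```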

update_alpha(alpha) #

Update `alpha` in the tree prior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `alpha` | `float` | New value of `alpha` to be used | required |

update_beta(beta) #

Update `beta` in the tree prior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `beta` | `float` | New value of `beta` to be used | required |

update_min_samples_leaf(min_samples_leaf) #

Update `min_samples_leaf` in the tree prior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `min_samples_leaf` | `int` | New value of `min_samples_leaf` to be used | required |

update_max_depth(max_depth) #

Update `max_depth` in the tree prior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `max_depth` | `int` | New value of `max_depth` to be used | required |
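The four updaters above can be used together to change the tree prior between sampling phases (for example, after a set of grow-from-root warm-start iterations); the values below are illustrative.

```python
# Tighten the split prior and depth restrictions for the MCMC phase.
forest_sampler.update_alpha(0.95)
forest_sampler.update_beta(2.0)
forest_sampler.update_min_samples_leaf(5)
forest_sampler.update_max_depth(10)
```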

stochtree.sampler.GlobalVarianceModel #

Wrapper around methods / functions for sampling a "global" error variance model with inverse gamma prior.

sample_one_iteration(residual, rng, a, b) #

Sample one iteration of a global error variance parameter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `residual` | `Residual` | stochtree object storing the continuously updated partial / full residual | required |
| `rng` | `RNG` | stochtree object storing the C++ random number generator used by the sampling algorithm | required |
| `a` | `float` | Shape parameter for the inverse gamma error variance model | required |
| `b` | `float` | Scale parameter for the inverse gamma error variance model | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | One draw from a Gibbs sampler for the error variance model, which depends on the rest of the model only through the "full" residual stored in a `Residual` object (net of the predictions of any mean term, such as a forest or an additive parametric fixed / random effect term). |
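A hedged sketch of drawing the global error variance inside a sampling loop, assuming the wrapper is constructed without arguments and that `residual` and `rng` are the same objects passed to the forest sampler; the IG(a, b) prior values are illustrative.

```python
from stochtree import GlobalVarianceModel  # assumed package-root re-export

global_var_model = GlobalVarianceModel()

# One Gibbs draw of the global error variance, conditional on the current
# full residual, under an IG(a, b) prior.
a_global, b_global = 1.0, 1.0
global_variance = global_var_model.sample_one_iteration(residual, rng, a_global, b_global)
```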

stochtree.sampler.LeafVarianceModel #

Wrapper around methods / functions for sampling a "leaf scale" model for the variance term of a Gaussian leaf model with inverse gamma prior.

sample_one_iteration(forest, rng, a, b) #

Sample one iteration of a forest leaf model's variance parameter (assuming a location-scale leaf model, most commonly N(0, tau)).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `forest` | `Forest` | stochtree object storing the "active" forest being sampled | required |
| `rng` | `RNG` | stochtree object storing the C++ random number generator used by the sampling algorithm | required |
| `a` | `float` | Shape parameter for the inverse gamma leaf scale model | required |
| `b` | `float` | Scale parameter for the inverse gamma leaf scale model | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | One draw from a Gibbs sampler for the leaf scale model, which depends on the rest of the model only through its respective forest. |
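A hedged sketch of drawing the leaf scale parameter after a forest update, assuming a zero-argument constructor and reusing the `forest` and `rng` objects from the sketches above; the prior values are illustrative. The returned draw would typically be fed back into the forest sampler's `leaf_model_scale_input` at the next iteration.

```python
from stochtree import LeafVarianceModel  # assumed package-root re-export

leaf_var_model = LeafVarianceModel()

# One Gibbs draw of the leaf scale tau in an N(0, tau) leaf model, conditional
# on the leaf values stored in `forest`.
a_leaf, b_leaf = 3.0, 0.5
leaf_scale = leaf_var_model.sample_one_iteration(forest, rng, a_leaf, b_leaf)
```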