BART

stochtree.bart.BARTModel

Class that handles sampling, storage, and serialization of stochastic forest models for supervised learning. The class takes its name from Bayesian Additive Regression Trees, an MCMC sampler originally developed in Chipman, George, McCulloch (2010), but supports several sampling algorithms:

  • MCMC: The "classic" sampler defined in Chipman, George, McCulloch (2010). In order to run the MCMC sampler, set num_gfr = 0 (explained below) and then define a sampler according to several parameters:
    • num_burnin: the number of iterations to run before "retaining" samples for further analysis. These "burned in" samples are helpful for allowing a sampler to converge before retaining samples.
    • num_chains: the number of independent sequences of MCMC samples to generate (typically referred to in the literature as "chains")
    • num_mcmc: the number of "retained" samples of the posterior distribution
    • keep_every: after a sampler has "burned in", the sampler is run for keep_every * num_mcmc iterations, retaining one of every keep_every iterations in a chain.
  • GFR (Grow-From-Root): A fast, greedy approximation of the BART MCMC sampling algorithm introduced in He and Hahn (2021). GFR sampler iterations are governed by the num_gfr parameter, and there are two primary ways to use this sampler:
    • Standalone: setting num_gfr > 0 and both num_burnin = 0 and num_mcmc = 0 will only run and retain GFR samples of the posterior. This is typically referred to as "XBART" (accelerated BART).
    • Initializer for MCMC: setting num_gfr > 0 and num_mcmc > 0 will use ensembles from the GFR algorithm to initialize num_chains independent MCMC BART samplers, which are run for num_mcmc iterations. This is typically referred to as "warm start BART".
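
The parameter combinations above can be summarized in a few lines of Python. This is only a sketch and not part of stochtree: `sampler_plan` is a hypothetical helper that restates the rules in this list. The choice of which post-burn-in iterations are retained (the keep_every-th, 2*keep_every-th, and so on) is an assumption made for illustration.

```python
def sampler_plan(num_gfr, num_burnin, num_mcmc, keep_every=1):
    """Summarize which sampler the iteration counts select (illustrative only)."""
    if num_gfr == 0 and num_mcmc == 0:
        raise ValueError("nothing to sample: num_gfr and num_mcmc are both 0")
    if num_gfr == 0:
        mode = "MCMC"
    elif num_mcmc == 0:
        mode = "GFR only (XBART)"
    else:
        mode = "warm-start BART"

    plan = {"mode": mode, "retained_mcmc_samples": num_mcmc}
    if mode == "MCMC":
        # Burn-in first, then keep_every * num_mcmc further iterations,
        # of which one in every keep_every is retained.
        plan["iterations_per_chain"] = num_burnin + keep_every * num_mcmc
        # Assumption: retained draws are the keep_every-th, 2*keep_every-th, ...
        # iterations after burn-in (1-indexed over the whole chain).
        plan["retained_iterations"] = [
            num_burnin + keep_every * i for i in range(1, num_mcmc + 1)
        ]
    return plan


print(sampler_plan(num_gfr=0, num_burnin=100, num_mcmc=3, keep_every=5))
print(sampler_plan(num_gfr=10, num_burnin=0, num_mcmc=0)["mode"])
print(sampler_plan(num_gfr=5, num_burnin=0, num_mcmc=100)["mode"])
```

For the warm-start case the sketch reports only the mode and the number of retained samples, since the text above does not spell out how burn-in and thinning interact with GFR initialization.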

In addition to enabling multiple samplers, we support a broad set of models. First, note that the original BART model of Chipman, George, and McCulloch (2010) is:

\[\begin{aligned} y &= f(X) + \epsilon\\ f(X) &\sim \text{BART}(\cdot)\\ \epsilon &\sim N(0, \sigma^2)\\ \sigma^2 &\sim IG(\nu, \nu\lambda) \end{aligned}\]

In words, there is a nonparametric mean function governed by a tree ensemble with a BART prior and an additive (mean-zero) Gaussian error term, whose variance is parameterized with an inverse gamma prior.
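
To make the generative model concrete, the sketch below simulates data from it using only the standard library. Several pieces are assumptions for illustration: `f` is a single step function standing in for a tree ensemble (it is not a draw from a BART prior), the values of `nu` and `lam` are arbitrary, and IG(a, b) is read as an inverse gamma with shape a and scale b.

```python
import random


def draw_inverse_gamma(shape, scale, rng):
    # If G ~ Gamma(shape, rate=scale), then 1/G ~ IG(shape, scale).
    # random.gammavariate(alpha, beta) treats beta as a scale, i.e. 1/rate.
    return 1.0 / rng.gammavariate(shape, 1.0 / scale)


def f(x):
    # Stand-in for the tree ensemble: one split at 0.5 with constant leaves.
    return 1.0 if x > 0.5 else -1.0


rng = random.Random(2024)
nu, lam = 3.0, 1.0  # illustrative hyperparameters, not library defaults

# sigma^2 ~ IG(nu, nu * lambda)
sigma2 = draw_inverse_gamma(nu, nu * lam, rng)

# y = f(X) + epsilon, with epsilon ~ N(0, sigma^2)
X = [rng.random() for _ in range(200)]
y = [f(x) + rng.gauss(0.0, sigma2 ** 0.5) for x in X]

print(f"sigma2 = {sigma2:.3f}, first outcomes = {[round(v, 2) for v in y[:3]]}")
```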

The BARTModel class supports the following extensions of this model:

  • Leaf Regression: Rather than letting f(X) define a standard decision tree ensemble, in which each tree uses X to partition the data and then serve up constant predictions, we allow for models f(X,Z) in which X and Z together define a partitioned linear model (X partitions the data and Z serves as the basis for regression models). This model can be run by specifying basis_train in the sample method.
  • Heteroskedasticity: Rather than define \(\epsilon\) parametrically, we can let a forest \(\sigma^2(X)\) model a conditional error variance function. This can be done by setting num_trees_variance > 0 in the params dictionary passed to the sample method.
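
The leaf regression idea can be illustrated with a single hand-built "tree". The sketch below is not stochtree code: `leaf_regression_predict`, the split point, and the leaf coefficients are all hypothetical. It only shows the division of labor described above, where x selects a leaf and z is the basis for that leaf's linear model.

```python
def leaf_regression_predict(x, z, split=0.5,
                            beta_left=(1.0, 2.0), beta_right=(-1.0, 0.5)):
    """One-split partitioned linear model: x picks the leaf, z is the basis."""
    # Partition step: the covariate x decides which leaf the observation is in.
    beta = beta_left if x <= split else beta_right
    # Regression step: the leaf holds coefficients applied to the basis z.
    return sum(b * zj for b, zj in zip(beta, z))


# Same basis, different leaves, different linear predictions.
print(leaf_regression_predict(0.2, (1.0, 3.0)))
print(leaf_regression_predict(0.9, (1.0, 3.0)))

# A constant-leaf tree is the special case of a one-column basis equal to 1.
print(leaf_regression_predict(0.2, (1.0,), beta_left=(4.2,), beta_right=(0.0,)))
```

A full ensemble would sum predictions of this form over many trees.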

sample(X_train, y_train, basis_train=None, X_test=None, basis_test=None, num_gfr=5, num_burnin=0, num_mcmc=100, params=None)

Runs a BART sampler on the provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis.

Parameters:

X_train (array, required): Training set covariates on which trees may be partitioned.

y_train (array, required): Training set outcome.

basis_train (array, default None): Optional training set basis vector used to define a regression to be run in the leaves of each tree.

X_test (array, default None): Optional test set covariates.

basis_test (array, default None): Optional test set basis vector used to define a regression to be run in the leaves of each tree. Must be included / omitted consistently (i.e. if basis_train is provided, then basis_test must be provided alongside X_test).

num_gfr (int, default 5): Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021).

num_burnin (int, default 0): Number of "burn-in" iterations of the MCMC sampler. Ignored if num_gfr > 0.

num_mcmc (int, default 100): Number of "retained" iterations of the MCMC sampler. If this is set to 0, GFR (XBART) samples will be retained.

params (dict): Dictionary of model parameters, each of which has a default value.

  • cutpoint_grid_size (int): Maximum number of cutpoints to consider for each feature. Defaults to 100.
  • sigma_leaf (float): Scale parameter on the (conditional mean) leaf node regression model.
  • alpha_mean (float): Prior probability of splitting for a tree of depth 0 in the conditional mean model. Tree split prior combines alpha_mean and beta_mean via alpha_mean*(1+node_depth)^-beta_mean.
  • beta_mean (float): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional mean model. Tree split prior combines alpha_mean and beta_mean via alpha_mean*(1+node_depth)^-beta_mean.
  • min_samples_leaf_mean (int): Minimum allowable size of a leaf, in terms of training samples, in the conditional mean model. Defaults to 5.
  • max_depth_mean (int): Maximum depth of any tree in the ensemble in the conditional mean model. Defaults to 10. Can be overridden with -1, which does not enforce any depth limits on trees.
  • alpha_variance (float): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines alpha_variance and beta_variance via alpha_variance*(1+node_depth)^-beta_variance.
  • beta_variance (float): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines alpha_variance and beta_variance via alpha_variance*(1+node_depth)^-beta_variance.
  • min_samples_leaf_variance (int): Minimum allowable size of a leaf, in terms of training samples, in the conditional variance model. Defaults to 5.
  • max_depth_variance (int): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to 10. Can be overridden with -1, which does not enforce any depth limits on trees.
  • a_global (float): Shape parameter in the IG(a_global, b_global) global error variance model. Defaults to 0.
  • b_global (float): Scale parameter in the IG(a_global, b_global) global error variance prior. Defaults to 0.
  • a_leaf (float): Shape parameter in the IG(a_leaf, b_leaf) leaf node parameter variance model. Defaults to 3.
  • b_leaf (float): Scale parameter in the IG(a_leaf, b_leaf) leaf node parameter variance model. Calibrated internally as 0.5/num_trees_mean if not set here.
  • a_forest (float): Shape parameter in the [optional] IG(a_forest, b_forest) conditional error variance forest (which is only sampled if num_trees_variance > 0). Calibrated internally as num_trees_variance / 1.5^2 + 0.5 if not set here.
  • b_forest (float): Scale parameter in the [optional] IG(a_forest, b_forest) conditional error variance forest (which is only sampled if num_trees_variance > 0). Calibrated internally as num_trees_variance / 1.5^2 if not set here.
  • sigma2_init (float): Starting value of global variance parameter. Set internally as a percentage of the standardized outcome variance if not set here.
  • variance_forest_leaf_init (float): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as np.log(pct_var_variance_forest_init*np.var((y-np.mean(y))/np.std(y)))/num_trees_variance if not set.
  • pct_var_sigma2_init (float): Percentage of standardized outcome variance used to initialize the global error variance parameter. Superseded by sigma2_init. Defaults to 1.
  • pct_var_variance_forest_init (float): Percentage of standardized outcome variance used to initialize the conditional error variance forest. Defaults to 1. Superseded by variance_forest_init.
  • variance_scale (float): Variance after the data have been scaled. Default: 1.
  • variable_weights_mean (np.array): Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of X_train if not provided.
  • variable_weights_variance (np.array): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of X_train if not provided.
  • num_trees_mean (int): Number of trees in the ensemble for the conditional mean model. Defaults to 200. If num_trees_mean = 0, the conditional mean will not be modeled using a forest and the function will only proceed if num_trees_variance > 0.
  • num_trees_variance (int): Number of trees in the ensemble for the conditional variance model. Defaults to 0. Variance is only modeled using a tree / forest if num_trees_variance > 0.
  • sample_sigma_global (bool): Whether or not to update the sigma^2 global error variance parameter based on IG(a_global, b_global). Defaults to True.
  • sample_sigma_leaf (bool): Whether or not to update the tau leaf scale variance parameter based on IG(a_leaf, b_leaf). Cannot (currently) be set to True if basis_train has more than one column. Defaults to False.
  • random_seed (int): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to std::random_device.
  • keep_burnin (bool): Whether or not "burnin" samples should be included in predictions. Defaults to False. Ignored if num_mcmc == 0.
  • keep_gfr (bool): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to False. Ignored if num_mcmc == 0.
  • num_chains (int): How many independent MCMC chains should be sampled. If num_mcmc = 0, this is ignored. If num_gfr = 0, then each chain is run from root for num_mcmc * keep_every + num_burnin iterations, with num_mcmc samples retained. If num_gfr > 0, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that num_gfr >= num_chains. Default: 1.
  • keep_every (int): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to 1. Setting keep_every = k for some k > 1 will "thin" the MCMC samples by retaining every k-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
Default for params: None
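
Several entries in the params list are defined by short formulas, which the sketch below evaluates directly. The function names are hypothetical and not part of stochtree. The values alpha = 0.95 and beta = 2 are illustrative choices only, since this section does not state the defaults for alpha_mean or beta_mean.

```python
def split_probability(alpha, beta, node_depth):
    """Tree split prior: alpha * (1 + node_depth) ** -beta."""
    return alpha * (1 + node_depth) ** (-beta)


def calibrated_defaults(num_trees_mean, num_trees_variance):
    """Internal calibrations quoted in the params list (illustrative restatement)."""
    defaults = {"b_leaf": 0.5 / num_trees_mean}
    if num_trees_variance > 0:
        # The variance forest prior is only relevant when that forest is sampled.
        defaults["a_forest"] = num_trees_variance / 1.5**2 + 0.5
        defaults["b_forest"] = num_trees_variance / 1.5**2
    return defaults


# Split probability shrinks quickly with depth, which favors shallow trees.
for depth in range(4):
    print(depth, round(split_probability(0.95, 2.0, depth), 4))

print(calibrated_defaults(num_trees_mean=200, num_trees_variance=0))
print(calibrated_defaults(num_trees_mean=200, num_trees_variance=45))
```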

predict(covariates, basis=None)

Returns predictions from every sampled forest (mean, variance, or both). The return value is a single array of predictions if the BART model includes only a mean term or only a variance term, or a tuple of prediction arrays if it includes both.

Parameters:

covariates (array, required): Test set covariates.

basis (array, default None): Optional test set basis vector; must be provided if the model was trained with a leaf regression basis.

Returns:

mu_x (array, optional): Mean forest predictions.

sigma2_x (array, optional): Variance forest predictions.
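
Because the return type depends on which forests the model contains, calling code may want to normalize it. The helper below is a hypothetical sketch, not a stochtree function. It assumes that a tuple result is ordered as (mu_x, sigma2_x), following the order of the Returns list, and it uses plain lists as stand-ins for prediction arrays.

```python
def unpack_predictions(result):
    """Normalize the two possible return shapes described above."""
    if isinstance(result, tuple):
        # Assumption: both forests present, ordered (mean, variance).
        mu_x, sigma2_x = result
        return {"mean": mu_x, "variance": sigma2_x}
    # Only one forest present; the caller knows whether it is mean or variance.
    return {"single": result}


# Stand-in values in place of real prediction arrays.
both = unpack_predictions(([0.1, 0.2], [1.0, 1.5]))
only_one = unpack_predictions([0.1, 0.2])
print(both)
print(only_one)
```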

predict_mean(covariates, basis=None)

Predict expected conditional outcome from a BART model.

Parameters:

covariates (array, required): Test set covariates.

basis (array, default None): Optional test set basis vector; must be provided if the model was trained with a leaf regression basis.

Returns:

array: Mean forest predictions.

predict_variance(covariates)

Predict expected conditional variance from a BART model.

Parameters:

covariates (array, required): Test set covariates.

Returns:

array: Variance forest predictions.

to_json()

Converts a sampled BART model to a JSON string representation, which can then be saved to a file or processed using the json library.

Returns:

str: JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests.

from_json(json_string)

Converts a JSON string to an in-memory BART model.

Parameters:

json_string (str, required): JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests.
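
Since to_json() returns a plain string and from_json() accepts one, persisting a model comes down to ordinary file I/O. The sketch below shows that round trip with the standard library only. The `json_string` value is a small stand-in, not real stochtree output, and the file name is arbitrary.

```python
import json
import os
import tempfile

# Stand-in for the string a sampled model's to_json() would return.
json_string = json.dumps({"example_field": [1, 2, 3]})

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "bart_model.json")

    # Save the JSON string to disk.
    with open(path, "w") as f:
        f.write(json_string)

    # Read it back; this string is what would be passed to from_json(...).
    with open(path) as f:
        reloaded = f.read()

# The string can also be inspected with the json library.
parsed = json.loads(reloaded)
print(sorted(parsed.keys()))
```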

is_sampled()

Whether or not a BART model has been sampled.

Returns:

bool: True if a BART model has been sampled, False otherwise.