BART
stochtree.bart.BARTModel
Class that handles sampling, storage, and serialization of stochastic forest models for supervised learning. The class takes its name from Bayesian Additive Regression Trees, an MCMC sampler originally developed in Chipman, George, McCulloch (2010), but supports several sampling algorithms:
- MCMC: The "classic" sampler defined in Chipman, George, McCulloch (2010). To run the MCMC sampler, set `num_gfr = 0` (explained below) and then configure the sampler with the following parameters:
    - `num_burnin`: the number of iterations to run before "retaining" samples for further analysis. These "burned-in" iterations give the sampler time to converge before any samples are retained.
    - `num_chains`: the number of independent sequences of MCMC samples to generate (typically referred to in the literature as "chains").
    - `num_mcmc`: the number of "retained" samples of the posterior distribution.
    - `keep_every`: after a sampler has "burned in", it is run for `keep_every * num_mcmc` iterations, retaining one out of every `keep_every` iterations in each chain.
- GFR (Grow-From-Root): A fast, greedy approximation of the BART MCMC sampling algorithm introduced in He and Hahn (2021). GFR sampler iterations are governed by the `num_gfr` parameter, and there are two primary ways to use this sampler:
    - Standalone: setting `num_gfr > 0` and both `num_burnin = 0` and `num_mcmc = 0` will only run and retain GFR samples of the posterior. This is typically referred to as "XBART" (accelerated BART).
    - Initializer for MCMC: setting `num_gfr > 0` and `num_mcmc > 0` will use ensembles from the GFR algorithm to initialize `num_chains` independent MCMC BART samplers, which are run for `num_mcmc` iterations. This is typically referred to as "warm-start BART".
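
The configuration rules above can be summarized in plain Python. Both helpers below are hypothetical (they are not part of stochtree): `sampler_mode` classifies a `(num_gfr, num_burnin, num_mcmc)` combination according to the descriptions above, and `retained_iterations` lists which iterations a single MCMC chain keeps under `keep_every` thinning, assuming that the last iteration of each block of `keep_every` is the one retained.

```python
def sampler_mode(num_gfr: int, num_burnin: int, num_mcmc: int) -> str:
    """Classify a sampler configuration (hypothetical helper, not a stochtree API)."""
    if num_gfr == 0 and num_mcmc > 0:
        return "MCMC"  # classic BART sampler
    if num_gfr > 0 and num_burnin == 0 and num_mcmc == 0:
        return "XBART"  # standalone GFR: only GFR samples are retained
    if num_gfr > 0 and num_mcmc > 0:
        return "warm-start BART"  # GFR ensembles initialize the MCMC chains
    raise ValueError("this configuration is not one of the documented modes")


def retained_iterations(num_burnin: int, num_mcmc: int, keep_every: int = 1) -> list:
    """0-based iteration indices kept in one MCMC chain.

    The chain runs num_burnin + keep_every * num_mcmc iterations in total.
    Assumption: the last iteration of each block of keep_every is retained.
    """
    return [num_burnin + (k + 1) * keep_every - 1 for k in range(num_mcmc)]


print(sampler_mode(0, 1000, 100))  # MCMC
print(sampler_mode(10, 0, 0))      # XBART
print(sampler_mode(10, 0, 100))    # warm-start BART
print(retained_iterations(num_burnin=2, num_mcmc=3, keep_every=2))  # [3, 5, 7]
```

With `num_burnin=2` and `keep_every=2`, the chain runs 8 iterations (indices 0 to 7) and keeps every second iteration after burn-in.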
In addition to enabling multiple samplers, we support a broad set of models. First, note that the original BART model of Chipman, George, McCulloch (2010) is

\[
\begin{aligned}
y &= f(X) + \epsilon\\
f(X) &\sim \text{BART}(\cdot)\\
\epsilon &\sim N(0, \sigma^2)\\
\sigma^2 &\sim \text{IG}(\cdot, \cdot)
\end{aligned}
\]

In words, there is a nonparametric mean function governed by a tree ensemble with a BART prior and an additive (mean-zero) Gaussian error term, whose variance is parameterized with an inverse gamma prior.
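
To make the generative model concrete, the sketch below simulates one outcome under \(y = f(X) + \epsilon\). As a simplifying assumption, each "tree" is a depth-one stump with constant leaf values, and \(f(X)\) is the sum of the stumps' predictions; the stumps and the noise variance are illustrative values, not draws from the BART prior, and none of these helpers are stochtree APIs.

```python
import random

# Illustrative stand-in for a tree: (feature index, cutpoint, left leaf value, right leaf value).
Stump = tuple


def stump_predict(stump: Stump, x: list) -> float:
    """Route x to a leaf using a single split and return that leaf's constant value."""
    feature, cutpoint, left_value, right_value = stump
    return left_value if x[feature] <= cutpoint else right_value


def ensemble_mean(stumps: list, x: list) -> float:
    """f(X): the sum of constant leaf predictions across all trees."""
    return sum(stump_predict(s, x) for s in stumps)


def simulate_outcome(stumps: list, x: list, sigma2: float, rng: random.Random) -> float:
    """Draw y = f(X) + eps with eps ~ N(0, sigma2)."""
    return ensemble_mean(stumps, x) + rng.gauss(0.0, sigma2 ** 0.5)


stumps = [(0, 0.5, -1.0, 1.0), (1, 0.0, 0.25, 0.75)]
x = [0.2, 1.0]
print(ensemble_mean(stumps, x))  # -0.25
print(simulate_outcome(stumps, x, sigma2=0.1, rng=random.Random(1)))
```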
The `BARTModel` class supports the following extensions of this model:

- Leaf Regression: Rather than letting \(f(X)\) define a standard decision tree ensemble, in which each tree uses \(X\) to partition the data and then serves constant predictions, we allow for models \(f(X, Z)\) in which \(X\) and \(Z\) together define a partitioned linear model (\(X\) partitions the data and \(Z\) serves as the basis for regression models). This model can be run by specifying `basis_train` in the `sample` method.
- Heteroskedasticity: Rather than define \(\epsilon\) parametrically with a single variance \(\sigma^2\), we can let a forest \(\sigma^2(X)\) model a conditional error variance function. This can be done by setting `num_trees_variance > 0` in the `params` dictionary passed to the `sample` method.
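
The leaf regression extension can be illustrated in the same style. In this hypothetical sketch each stump splits on the covariates \(X\), and each leaf stores a coefficient vector that is applied to the basis \(Z\), so a tree's prediction is \(Z^\top \beta_\ell\) for the leaf \(\ell\) that \(X\) falls into. The stump representation is an assumption for illustration, not stochtree's internal tree format.

```python
def leaf_regression_predict(stump: tuple, x: list, z: list) -> float:
    """One tree: X selects the leaf, Z is the regression basis within that leaf."""
    feature, cutpoint, beta_left, beta_right = stump
    beta = beta_left if x[feature] <= cutpoint else beta_right
    return sum(b * zj for b, zj in zip(beta, z))


def ensemble_leaf_regression(stumps: list, x: list, z: list) -> float:
    """f(X, Z): the sum of per-tree leaf regressions."""
    return sum(leaf_regression_predict(s, x, z) for s in stumps)


# One stump splitting on X[0] at 0.5, with a 2-dimensional coefficient vector per leaf.
stumps = [(0, 0.5, (1.0, 0.0), (0.0, 2.0))]
print(ensemble_leaf_regression(stumps, x=[0.7], z=[3.0, 4.0]))  # 8.0
print(ensemble_leaf_regression(stumps, x=[0.1], z=[3.0, 4.0]))  # 3.0
```

Setting every leaf's basis to the constant \(Z = 1\) recovers the constant-leaf model, which is one way to see leaf regression as a generalization of the standard ensemble.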
sample(X_train, y_train, basis_train=None, X_test=None, basis_test=None, num_gfr=5, num_burnin=0, num_mcmc=100, params=None)

Runs a BART sampler on the provided training set. Predictions are cached for the training set and (if provided) the test set. A leaf regression basis is optional.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`X_train` | `array` | Training set covariates on which trees may be partitioned. | required |
`y_train` | `array` | Training set outcome. | required |
`basis_train` | `array` | Optional training set basis vector used to define a regression to be run in the leaves of each tree. | `None` |
`X_test` | `array` | Optional test set covariates. | `None` |
`basis_test` | `array` | Optional test set basis vector used to define a regression to be run in the leaves of each tree. Must be included / omitted consistently (i.e. if `basis_train` is provided, then `basis_test` must be provided alongside `X_test`). | `None` |
`num_gfr` | `int` | Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). | `5` |
`num_burnin` | `int` | Number of "burn-in" iterations of the MCMC sampler. | `0` |
`num_mcmc` | `int` | Number of "retained" iterations of the MCMC sampler. | `100` |
`params` | `dict` | Dictionary of model parameters, each of which has a default value. | `None` |
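
The basis rules in the table above can be checked before calling `sample`. The function below is a hypothetical pre-flight check written against plain Python sequences (it is not part of stochtree); the requirement that covariates, outcome, and basis have matching row counts is an assumption made for illustration.

```python
def check_sample_inputs(X_train, y_train, basis_train=None, X_test=None, basis_test=None):
    """Hypothetical validation mirroring the documented basis_train / basis_test rules."""
    n_train = len(X_train)
    if len(y_train) != n_train:
        raise ValueError("X_train and y_train must have the same number of rows")
    if basis_train is not None and len(basis_train) != n_train:
        raise ValueError("basis_train must have one row per training observation")
    if basis_test is not None and X_test is None:
        raise ValueError("basis_test was given without X_test")
    if X_test is not None:
        # The basis must be included / omitted consistently across train and test.
        if basis_train is not None and basis_test is None:
            raise ValueError("basis_test must be provided alongside X_test when basis_train is used")
        if basis_train is None and basis_test is not None:
            raise ValueError("basis_test was given but there is no training basis")
        if basis_test is not None and len(basis_test) != len(X_test):
            raise ValueError("basis_test must have one row per test observation")


check_sample_inputs([[0.1], [0.9]], [1.0, 2.0])  # passes: no basis anywhere
try:
    check_sample_inputs([[0.1]], [1.0], basis_train=[[1.0]], X_test=[[0.5]])
except ValueError as err:
    print(err)
```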
predict(covariates, basis=None)

Return predictions from every sampled forest (mean, variance, or both). The return type is a single array of predictions if the BART model includes only a mean term or only a variance term, or a tuple of prediction arrays if it includes both.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`covariates` | `array` | Test set covariates. | required |
`basis` | `array` | Optional test set basis vector; must be provided if the model was trained with a leaf regression basis. | `None` |

Returns:

Name | Type | Description |
---|---|---|
`mu_x` | `array`, optional | Mean forest predictions. |
`sigma2_x` | `array`, optional | Variance forest predictions. |
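
Because `predict` returns either a single array or a tuple, downstream code may want to normalize its output. A minimal sketch, assuming the caller already knows whether the model includes a mean forest and/or a variance forest; the `unpack_predictions` helper and its `has_mean` / `has_variance` flags are hypothetical, and plain lists stand in for real prediction arrays.

```python
def unpack_predictions(result, has_mean: bool, has_variance: bool):
    """Return (mu_x, sigma2_x), using None for a term the model does not include."""
    if has_mean and has_variance:
        mu_x, sigma2_x = result  # tuple of prediction arrays
        return mu_x, sigma2_x
    if has_mean:
        return result, None  # single array of mean predictions
    if has_variance:
        return None, result  # single array of variance predictions
    raise ValueError("a BART model must include a mean or a variance forest")


# Stand-in values in place of real predict() output:
print(unpack_predictions(([1.0], [0.5]), has_mean=True, has_variance=True))  # ([1.0], [0.5])
print(unpack_predictions([1.0], has_mean=True, has_variance=False))          # ([1.0], None)
```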
predict_mean(covariates, basis=None)

Predict the expected conditional outcome from a BART model.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`covariates` | `array` | Test set covariates. | required |
`basis` | `array` | Optional test set basis vector; must be provided if the model was trained with a leaf regression basis. | `None` |

Returns:

Type | Description |
---|---|
`array` | Mean forest predictions. |
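
Since predictions come from every sampled forest, a common next step is to summarize them across samples. The sketch below assumes, as an illustration only, that the prediction array is laid out with one row per observation and one column per retained sample (check the actual shape returned by your stochtree version); `summarize_draws` is a hypothetical helper using only the standard library.

```python
from statistics import fmean, stdev


def summarize_draws(draws):
    """Per-observation posterior mean and standard deviation.

    Assumption: draws[i][s] is the prediction for observation i from sample s.
    """
    means = [fmean(row) for row in draws]
    sds = [stdev(row) for row in draws]  # requires at least two samples per row
    return means, sds


draws = [[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]]
means, sds = summarize_draws(draws)
print(means)  # [2.0, 4.0]
print(sds)    # [1.0, 0.0]
```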
predict_variance(covariates)

Predict the expected conditional variance from a BART model.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`covariates` | `array` | Test set covariates. | required |

Returns:

Type | Description |
---|---|
`array` | Variance forest predictions. |
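
Under the Gaussian error model above, a mean prediction and a variance prediction for the same observation can be combined into an approximate predictive interval. A minimal sketch for a single pair of values, assuming \(\epsilon \sim N(0, \sigma^2(X))\) and ignoring uncertainty across posterior samples; `predictive_interval` is a hypothetical helper, not a stochtree method.

```python
from math import sqrt
from statistics import NormalDist


def predictive_interval(mu: float, sigma2: float, level: float = 0.95):
    """Central interval of N(mu, sigma2) with the given coverage level."""
    z = NormalDist().inv_cdf(0.5 + level / 2)  # about 1.96 for level = 0.95
    half_width = z * sqrt(sigma2)
    return mu - half_width, mu + half_width


print(predictive_interval(mu=0.0, sigma2=1.0))
```

A fuller treatment would compute such an interval per posterior sample and then pool them, rather than plugging in a single mean and variance.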
to_json()

Converts a sampled BART model to a JSON string representation (which can then be saved to a file or processed using the `json` library).
Returns:

Type | Description |
---|---|
`str` | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests |
from_json(json_string)
Converts a JSON string to an in-memory BART model.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`json_string` | `str` | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests | required |
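
`to_json` and `from_json` are meant to be combined with ordinary file I/O. The sketch below shows only the file side, using the standard `json` module; the JSON content is a placeholder (a real string would come from `to_json()` on a sampled model, and its schema is not described here), and `save_model_json` / `load_model_json` are hypothetical helpers.

```python
import json
import os
import tempfile


def save_model_json(json_string: str, path: str) -> None:
    """Write a model's JSON string (e.g. from to_json()) to disk."""
    json.loads(json_string)  # fail early if the string is not valid JSON
    with open(path, "w", encoding="utf-8") as f:
        f.write(json_string)


def load_model_json(path: str) -> str:
    """Read the JSON string back, ready to be passed to from_json()."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


placeholder = json.dumps({"note": "placeholder for to_json() output"})
path = os.path.join(tempfile.mkdtemp(), "bart_model.json")
save_model_json(placeholder, path)
restored = load_model_json(path)
print(restored == placeholder)  # True
```

With a real model, the restored string would then be passed to `from_json` to rebuild the in-memory BART model.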
is_sampled()
Whether or not a BART model has been sampled.
Returns:

Type | Description |
---|---|
`bool` | `True` if the model has been sampled, `False` otherwise. |