BART
stochtree.bart.BARTModel
Class that handles sampling, storage, and serialization of stochastic forest models for supervised learning. The class takes its name from Bayesian Additive Regression Trees, an MCMC sampler originally developed in Chipman, George, McCulloch (2010), but supports several sampling algorithms:
- MCMC: The "classic" sampler defined in Chipman, George, McCulloch (2010). In order to run the MCMC sampler, set `num_gfr = 0` (explained below) and then define a sampler according to several parameters:
    - `num_burnin`: the number of iterations to run before "retaining" samples for further analysis. These "burned in" samples are helpful for allowing a sampler to converge before retaining samples.
    - `num_chains`: the number of independent sequences of MCMC samples to generate (typically referred to in the literature as "chains").
    - `num_mcmc`: the number of "retained" samples of the posterior distribution.
    - `keep_every`: after a sampler has "burned in", we will run the sampler for `keep_every * num_mcmc` iterations, retaining one out of every `keep_every` iterations in a chain.
- GFR (Grow-From-Root): A fast, greedy approximation of the BART MCMC sampling algorithm introduced in He and Hahn (2021). GFR sampler iterations are governed by the `num_gfr` parameter, and there are two primary ways to use this sampler:
    - Standalone: setting `num_gfr > 0` and both `num_burnin = 0` and `num_mcmc = 0` will only run and retain GFR samples of the posterior. This is typically referred to as "XBART" (accelerated BART).
    - Initializer for MCMC: setting `num_gfr > 0` and `num_mcmc > 0` will use ensembles from the GFR algorithm to initialize `num_chains` independent MCMC BART samplers, which are run for `num_mcmc` iterations. This is typically referred to as "warm start BART".
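The parameter combinations above can be summarized in a small helper. This is a minimal sketch in plain Python and not part of stochtree: `describe_sampler` is a hypothetical name, and the per-chain iteration count only restates the `keep_every * num_mcmc` rule from the text.

```python
def describe_sampler(num_gfr, num_burnin, num_mcmc, keep_every=1):
    """Classify a sampler configuration and count MCMC iterations per chain.

    Hypothetical helper for illustration; it restates the rules in the text
    and does not call stochtree.
    """
    if num_gfr == 0:
        mode = "MCMC"
    elif num_burnin == 0 and num_mcmc == 0:
        mode = "GFR only (XBART)"
    elif num_mcmc > 0:
        mode = "GFR-initialized MCMC (warm start)"
    else:
        mode = "not covered by the description above"

    # After burn-in, the sampler runs keep_every * num_mcmc iterations and
    # retains one out of every keep_every of them.
    post_burnin_iters = keep_every * num_mcmc
    return {
        "mode": mode,
        "mcmc_iters_per_chain": num_burnin + post_burnin_iters,
        "retained_per_chain": num_mcmc,
    }


print(describe_sampler(num_gfr=0, num_burnin=1000, num_mcmc=100, keep_every=5))
print(describe_sampler(num_gfr=10, num_burnin=0, num_mcmc=0))
print(describe_sampler(num_gfr=10, num_burnin=0, num_mcmc=100))
```

With `num_burnin=1000`, `num_mcmc=100`, and `keep_every=5`, each chain runs 1500 MCMC iterations and retains 100 of them.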
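In addition to enabling multiple samplers, we support a broad set of models. First, note that the original BART model of Chipman, George, McCulloch (2010) is

\[
\begin{aligned}
y &= f(X) + \epsilon \\
f(X) &\sim \text{BART}(\cdot) \\
\epsilon &\sim \mathcal{N}(0, \sigma^2) \\
\sigma^2 &\sim \text{IG}(\cdot)
\end{aligned}
\]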
In words, there is a nonparametric mean function governed by a tree ensemble with a BART prior and an additive (mean-zero) Gaussian error term, whose variance is parameterized with an inverse gamma prior.
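To make the structure of this model concrete, the sketch below simulates data with a piecewise-constant mean function (the kind of function a single tree represents) plus mean-zero Gaussian noise. The function names and the particular step function are illustrative assumptions, not anything defined by stochtree.

```python
import random


def step_mean(x):
    """A piecewise-constant mean function with one split at 0.5."""
    if x < 0.5:
        return -1.0
    return 2.0


def simulate(n, sigma, seed=0):
    """Draw (x, y) pairs with y = f(x) + eps, where eps ~ N(0, sigma^2)."""
    rng = random.Random(seed)
    xs = [rng.random() for _ in range(n)]
    ys = [step_mean(x) + rng.gauss(0.0, sigma) for x in xs]
    return xs, ys


xs, ys = simulate(n=5, sigma=0.1)
for x, y in zip(xs, ys):
    print(f"x={x:.3f}  f(x)={step_mean(x):+.1f}  y={y:+.3f}")
```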
The BARTModel class supports the following extensions of this model:
- Leaf Regression: Rather than letting `f(X)` define a standard decision tree ensemble, in which each tree uses `X` to partition the data and then serve up constant predictions, we allow for models `f(X,Z)` in which `X` and `Z` together define a partitioned linear model (`X` partitions the data and `Z` serves as the basis for regression models). This model can be run by specifying `leaf_basis_train` in the `sample` method.
- Heteroskedasticity: Rather than define \(\epsilon\) parametrically, we can let a forest \(\sigma^2(X)\) model a conditional error variance function. This can be done by setting `num_trees_variance > 0` in the `params` dictionary passed to the `sample` method.
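The difference between constant leaves and leaf regression can be illustrated with a single split. This is a sketch with hypothetical function names, one tree, and a scalar basis; an actual ensemble sums the contributions of many trees.

```python
def constant_leaf_predict(x, split=0.5, leaf_values=(-1.0, 2.0)):
    """Standard tree: x picks a leaf, and the leaf returns a constant."""
    leaf = 0 if x < split else 1
    return leaf_values[leaf]


def regression_leaf_predict(x, z, split=0.5, leaf_coefs=(0.5, -3.0)):
    """Leaf regression: x picks a leaf, and that leaf's coefficient multiplies the basis z."""
    leaf = 0 if x < split else 1
    return leaf_coefs[leaf] * z


print(constant_leaf_predict(0.2))         # -1.0
print(regression_leaf_predict(0.2, 4.0))  # 2.0
print(regression_leaf_predict(0.9, 4.0))  # -12.0
```

In both cases `x` decides which leaf applies; only in the second does `z` change the prediction within a leaf.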
sample(X_train, y_train, leaf_basis_train=None, rfx_group_ids_train=None, rfx_basis_train=None, X_test=None, leaf_basis_test=None, rfx_group_ids_test=None, rfx_basis_test=None, num_gfr=5, num_burnin=0, num_mcmc=100, general_params=None, mean_forest_params=None, variance_forest_params=None, random_effects_params=None, previous_model_json=None, previous_model_warmstart_sample_num=None)
Runs a BART sampler on the provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `X_train` | `array` | Training set covariates on which trees may be partitioned. | required |
| `y_train` | `array` | Training set outcome. | required |
| `leaf_basis_train` | `array` | Optional training set basis vector used to define a regression to be run in the leaves of each tree. | `None` |
| `rfx_group_ids_train` | `array` | Optional group labels used for an additive random effects model. | `None` |
| `rfx_basis_train` | `array` | Optional basis for "random-slope" regression in an additive random effects model. | `None` |
| `X_test` | `array` | Optional test set covariates. | `None` |
| `leaf_basis_test` | `array` | Optional test set basis vector used to define a regression to be run in the leaves of each tree. Must be included / omitted consistently (i.e. if `leaf_basis_train` is provided, then `leaf_basis_test` must be provided alongside `X_test`). | `None` |
| `rfx_group_ids_test` | `array` | Optional test set group labels used for an additive random effects model. Test set evaluation for group labels that were not in the training set is not currently supported, but is planned for the near future. | `None` |
| `rfx_basis_test` | `array` | Optional test set basis for "random-slope" regression in an additive random effects model. | `None` |
| `num_gfr` | `int` | Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `5`. | `5` |
| `num_burnin` | `int` | Number of "burn-in" iterations of the MCMC sampler. Defaults to `0`. | `0` |
| `num_mcmc` | `int` | Number of "retained" iterations of the MCMC sampler. Defaults to `100`. | `100` |
| `general_params` | `dict` | Dictionary of general model parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `mean_forest_params` | `dict` | Dictionary of mean forest model parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `variance_forest_params` | `dict` | Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `random_effects_params` | `dict` | Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `previous_model_json` | `str` | JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`. | `None` |
| `previous_model_warmstart_sample_num` | `int` | Sample number from | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `self` | `BARTModel` | Sampled BART Model. |
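A minimal sketch of how a call to `sample` might be wrapped, using only keyword arguments that appear in the signature above. The helper names `check_test_inputs` and `fit_bart` are hypothetical, the check is one reading of the consistency rule documented for `leaf_basis_test`, and constructing the model with `BARTModel()` (no arguments) is an assumption to verify against your installed version.

```python
def check_test_inputs(leaf_basis_train=None, X_test=None, leaf_basis_test=None):
    """Hypothetical pre-flight check for the train / test basis consistency rule."""
    if X_test is None:
        return
    if (leaf_basis_train is None) != (leaf_basis_test is None):
        raise ValueError(
            "leaf_basis_test must accompany X_test if and only if "
            "leaf_basis_train is provided"
        )


def fit_bart(X_train, y_train, X_test=None, leaf_basis_train=None,
             leaf_basis_test=None, num_gfr=5, num_burnin=0, num_mcmc=100):
    """Fit a BART model, passing only arguments from the documented signature."""
    check_test_inputs(leaf_basis_train, X_test, leaf_basis_test)
    # Imported lazily so that the check above works without stochtree installed.
    from stochtree.bart import BARTModel

    model = BARTModel()  # assumption: the constructor takes no arguments
    model.sample(
        X_train=X_train,
        y_train=y_train,
        leaf_basis_train=leaf_basis_train,
        X_test=X_test,
        leaf_basis_test=leaf_basis_test,
        num_gfr=num_gfr,
        num_burnin=num_burnin,
        num_mcmc=num_mcmc,
    )
    return model
```

Calling `fit_bart(X, y, num_gfr=10, num_mcmc=100)` would then request a GFR-initialized MCMC run, while `num_gfr=0` would request MCMC only.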
predict(covariates, basis=None, rfx_group_ids=None, rfx_basis=None, type='posterior', terms='all', scale='linear')
Return predictions from every forest sampled (either / both of mean and variance). Return type is either a single array of predictions, if a BART model only includes a mean or variance term, or a tuple of prediction arrays, if a BART model includes both.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `covariates` | `array` | Test set covariates. | required |
| `basis` | `array` | Optional test set basis vector, must be provided if the model was trained with a leaf regression basis. | `None` |
| `rfx_group_ids` | `array` | Optional group labels used for an additive random effects model. | `None` |
| `rfx_basis` | `array` | Optional basis for "random-slope" regression in an additive random effects model. | `None` |
| `type` | `str` | Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". | `'posterior'` |
| `terms` | `str` | Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return | `'all'` |
| `scale` | `str` | Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing | `'linear'` |
Returns:
| Type | Description |
|---|---|
| `dict` or `array` | Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. |
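The relationship between `type="posterior"` and `type="mean"` can be sketched in plain Python: the "mean" result is the average of the posterior matrix over draws. The layout below (rows index observations, columns index posterior draws) is an assumption for illustration, so check the shape returned by your stochtree version. `posterior_mean` is a hypothetical helper.

```python
def posterior_mean(draws):
    """Average each row of a matrix of posterior draws.

    Assumes rows index observations and columns index posterior draws.
    """
    return [sum(row) / len(row) for row in draws]


posterior = [
    [1.0, 2.0, 3.0],  # observation 0, three posterior draws
    [0.0, 0.5, 1.0],  # observation 1, three posterior draws
]
print(posterior_mean(posterior))  # [2.0, 0.5]
```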
compute_contrast(covariates_0, covariates_1, basis_0=None, basis_1=None, rfx_group_ids_0=None, rfx_group_ids_1=None, rfx_basis_0=None, rfx_basis_1=None, type='posterior', scale='linear')
Compute a contrast using a BART model by making two sets of outcome predictions and taking their
difference. This function provides the flexibility to compute any contrast of interest by specifying
covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
For simplicity, we refer to the subtrahend of the contrast as the "control" or Y0 term and the minuend
of the contrast as the Y1 term, though the requested contrast need not match the "control vs treatment"
terminology of a classic two-treatment causal inference problem. We mirror the function calls and
terminology of the `predict` method, labeling each prediction data term with a 1 to denote
its contribution to the treatment prediction of a contrast and 0 to denote inclusion in the control prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `covariates_0` | `array` or `DataFrame` | Covariates used for prediction in the "control" case. Must be a numpy array or dataframe. | required |
| `covariates_1` | `array` or `DataFrame` | Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe. | required |
| `basis_0` | `array` | Bases used for prediction in the "control" case (by e.g. dot product with leaf values). | `None` |
| `basis_1` | `array` | Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). | `None` |
| `rfx_group_ids_0` | `array` | Test set group labels used for prediction from an additive random effects model in the "control" case. Test set evaluation for group labels that were not in the training set is not currently supported, but is planned for the near future. Must be a numpy array. | `None` |
| `rfx_group_ids_1` | `array` | Test set group labels used for prediction from an additive random effects model in the "treatment" case. Test set evaluation for group labels that were not in the training set is not currently supported, but is planned for the near future. Must be a numpy array. | `None` |
| `rfx_basis_0` | `array` | Test set basis used for prediction from an additive random effects model in the "control" case. | `None` |
| `rfx_basis_1` | `array` | Test set basis used for prediction from an additive random effects model in the "treatment" case. | `None` |
| `type` | `str` | Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". | `'posterior'` |
| `scale` | `str` | Scale of the contrast. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing | `'linear'` |
Returns:
| Type | Description |
|---|---|
| `array` | Array, either 1d or 2d depending on whether `type` is "mean" or "posterior". |
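The arithmetic behind a contrast is a draw-by-draw difference between the Y1 and Y0 predictions, optionally averaged over draws. The sketch below works on small hand-written matrices and assumes rows index observations and columns index posterior draws; `contrast` and `row_means` are hypothetical helpers, not stochtree functions.

```python
def contrast(pred_1, pred_0):
    """Element-wise difference Y1 - Y0 between two matrices of posterior predictions."""
    return [[a - b for a, b in zip(r1, r0)] for r1, r0 in zip(pred_1, pred_0)]


def row_means(matrix):
    """Average over draws (columns) for each observation (row)."""
    return [sum(row) / len(row) for row in matrix]


y1 = [[3.0, 4.0], [1.0, 1.5]]  # predictions under the "treatment" inputs
y0 = [[1.0, 1.0], [0.5, 0.5]]  # predictions under the "control" inputs

tau = contrast(y1, y0)
print(tau)             # [[2.0, 3.0], [0.5, 1.0]]
print(row_means(tau))  # [2.5, 0.75]
```

The first printed matrix corresponds to what `type="posterior"` describes, and the second to `type="mean"`.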
compute_posterior_interval(terms='all', scale='linear', level=0.95, covariates=None, basis=None, rfx_group_ids=None, rfx_basis=None)
Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `terms` | `str` | Character string specifying the model term(s) for which to compute intervals. Options for BART models are | `'all'` |
| `scale` | `str` | Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing | `'linear'` |
| `level` | `float` | A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. | `0.95` |
| `covariates` | `array` | Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). | `None` |
| `basis` | `array` | Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. | `None` |
| `rfx_group_ids` | `array` | Optional vector of group IDs for random effects. Required if the requested term includes random effects. | `None` |
| `rfx_basis` | `array` | Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. | `None` |
Returns:
| Type | Description |
|---|---|
| `dict` | A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. |
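As an illustration of how a `level` maps to interval bounds, the sketch below computes an equal-tailed interval from posterior draws for a single observation. Whether stochtree uses equal-tailed quantiles, and the `"lower"` / `"upper"` key names, are assumptions made for this example; `quantile` and `credible_interval` are hypothetical helpers.

```python
def quantile(sorted_vals, q):
    """Linearly interpolated quantile of an already sorted list."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def credible_interval(draws, level=0.95):
    """Equal-tailed interval: cut (1 - level) / 2 of the mass from each tail."""
    alpha = 1.0 - level
    ordered = sorted(draws)
    return {
        "lower": quantile(ordered, alpha / 2),
        "upper": quantile(ordered, 1 - alpha / 2),
    }


draws = [float(i) for i in range(101)]  # stand-in posterior draws 0, 1, ..., 100
print(credible_interval(draws, level=0.90))
```

For these evenly spaced draws, a 90% interval runs from approximately 5 to approximately 95.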
sample_posterior_predictive(covariates=None, basis=None, rfx_group_ids=None, rfx_basis=None, num_draws_per_sample=None)
Sample from the posterior predictive distribution for outcomes modeled by BART.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `covariates` | `array` | An array or data frame of covariates at which to draw posterior predictive samples. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). | `None` |
| `basis` | `array` | An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. | `None` |
| `rfx_group_ids` | `array` | An array of group IDs for random effects. Required if the BART model includes random effects. | `None` |
| `rfx_basis` | `array` | An array of basis function evaluations for random effects. Required if the BART model includes random effects. | `None` |
| `num_draws_per_sample` | `int` | The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). | `None` |
Returns:
| Type | Description |
|---|---|
| `array` | A matrix of posterior predictive samples. If |
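The default for `num_draws_per_sample` and the sampling step can be sketched as follows. This is one reading of the heuristic described above (rounding up so that at least 1000 draws result), applied to a single observation with a Gaussian likelihood; the helper names are hypothetical and the exact rule used by stochtree may differ.

```python
import math
import random


def default_draws_per_sample(num_samples, target=1000):
    """Use 1 draw per posterior sample above `target` samples, else enough to reach it."""
    if num_samples > target:
        return 1
    return math.ceil(target / num_samples)


def posterior_predictive(mean_draws, sigma2_draws, num_draws_per_sample=None, seed=0):
    """Draw y ~ N(mean, sigma2) for one observation, for each posterior sample."""
    rng = random.Random(seed)
    if num_draws_per_sample is None:
        num_draws_per_sample = default_draws_per_sample(len(mean_draws))
    return [
        rng.gauss(mu, math.sqrt(s2))
        for mu, s2 in zip(mean_draws, sigma2_draws)
        for _ in range(num_draws_per_sample)
    ]


# 100 posterior samples of the mean and error variance (stand-in values).
y_rep = posterior_predictive([0.0] * 100, [1.0] * 100)
print(len(y_rep))  # 1000
```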
to_json()
Converts a sampled BART model to a JSON string representation (which can then be saved to a file or
processed using the json library).
Returns:
| Type | Description |
|---|---|
| `str` | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests |
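Because the result is an ordinary string, saving and reloading it needs only the standard library. The sketch below uses a placeholder string in place of a real `to_json()` result; the placeholder's contents are invented for the example and say nothing about the actual schema.

```python
import json
import os
import tempfile

# Stand-in for the string returned by to_json(); a real payload is much larger
# and its keys are not shown here.
model_json = json.dumps({"placeholder": True})

# Write the string to disk.
path = os.path.join(tempfile.mkdtemp(), "bart_model.json")
with open(path, "w") as f:
    f.write(model_json)

# Read it back unchanged.
with open(path) as f:
    restored_json = f.read()

# Optionally inspect it with the json library.
parsed = json.loads(restored_json)
print(type(parsed).__name__)  # dict
```

The restored string is what `from_json(json_string)` expects as input.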
from_json(json_string)
Converts a JSON string to an in-memory BART model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `json_string` | `str` | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests | required |
from_json_string_list(json_string_list)
Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object which can be used for prediction, etc.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `json_string_list` | `list of str` | List of JSON strings which can be parsed to objects of type | required |
is_sampled()
Whether or not a BART model has been sampled.
Returns:
| Type | Description |
|---|---|
| `bool` | |
has_term(term)
Whether or not a model includes a term.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `term` | `str` | Character string specifying the model term to check for. Options for BART models are | required |
Returns:
| Type | Description |
|---|---|
| `bool` | |