# BCF

## stochtree.bcf.BCFModel
Class that handles sampling, storage, and serialization of stochastic forest models for causal effect estimation. The class takes its name from Bayesian Causal Forests, an MCMC sampler originally developed in Hahn, Murray, Carvalho (2020), but supports several sampling algorithms:
- MCMC: The "classic" sampler defined in Hahn, Murray, Carvalho (2020). In order to run the MCMC sampler, set `num_gfr = 0` (explained below) and then define a sampler according to several parameters:
    - `num_burnin`: the number of iterations to run before "retaining" samples for further analysis. These "burned in" samples are helpful for allowing a sampler to converge before retaining samples.
    - `num_chains`: the number of independent sequences of MCMC samples to generate (typically referred to in the literature as "chains").
    - `num_mcmc`: the number of "retained" samples of the posterior distribution.
    - `keep_every`: after a sampler has "burned in", we will run the sampler for `keep_every * num_mcmc` iterations, retaining one of every `keep_every` iterations in a chain.
- GFR (Grow-From-Root): A fast, greedy approximation of the BART MCMC sampling algorithm introduced in Krantsevich, He, and Hahn (2023). GFR sampler iterations are governed by the `num_gfr` parameter, and there are two primary ways to use this sampler:
    - Standalone: setting `num_gfr > 0` and both `num_burnin = 0` and `num_mcmc = 0` will only run and retain GFR samples of the posterior. This is typically referred to as "XBART" (accelerated BART).
    - Initializer for MCMC: setting `num_gfr > 0` and `num_mcmc > 0` will use ensembles from the GFR algorithm to initialize `num_chains` independent MCMC BART samplers, which are run for `num_mcmc` iterations. This is typically referred to as "warm-start BART".
In addition to enabling multiple samplers, we support a broad set of models. First, note that the original BCF model of Hahn, Murray, Carvalho (2020) is

\[
y = \mu(X) + Z \tau(X) + \epsilon, \qquad \epsilon \sim N\left(0, \sigma^2\right),
\]

for continuous outcome \(y\), binary treatment \(Z\), and covariates \(X\).
In words, there are two nonparametric mean functions -- a "prognostic" function \(\mu(X)\) and a "treatment effect" (CATE) function \(\tau(X)\) -- governed by tree ensembles with BART priors, plus an additive (mean-zero) Gaussian error term whose variance \(\sigma^2\) is parameterized with an inverse gamma prior.
The `BCFModel` class supports the following extensions of this model:
- Continuous Treatment: If \(Z\) is continuous rather than binary, we define \(b_z(X) = \tau(X, Z) = Z \tau(X)\), where the "leaf model" for the \(\tau\) forest is essentially a regression on continuous \(Z\).
- Heteroskedasticity: Rather than define \(\epsilon\) parametrically, we can let a forest \(\sigma^2(X)\) model a conditional error variance function. This can be done by setting a positive number of trees for the variance forest (i.e. `num_trees > 0` in the `variance_forest_params` dictionary passed to the `sample` method).
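As a brief illustration of how the sampler settings above fit together, the sketch below fits a BCF model on simulated data once with the warm-start (GFR + MCMC) configuration and once with the MCMC-only configuration. The data-generating process, seed, and iteration counts are illustrative assumptions, and `BCFModel` is assumed to be constructible with no arguments.

```python
# A minimal sketch of the sampler configurations described above,
# using simulated data (the DGP here is purely illustrative).
import numpy as np
from stochtree.bcf import BCFModel

rng = np.random.default_rng(2024)
n, p = 500, 5
X = rng.uniform(size=(n, p))
pi = 0.25 + 0.5 * X[:, 0]                       # true propensity
Z = rng.binomial(1, pi)                         # binary treatment
mu = 2 * X[:, 1] - 1                            # prognostic function
tau = 0.5 + X[:, 2]                             # treatment effect (CATE)
y = mu + tau * Z + rng.normal(scale=0.5, size=n)

# Warm-start configuration: GFR iterations initialize the MCMC sampler.
bcf_warmstart = BCFModel()
bcf_warmstart.sample(X_train=X, Z_train=Z, y_train=y, propensity_train=pi,
                     num_gfr=10, num_burnin=0, num_mcmc=100)

# "Classic" MCMC only: disable GFR and burn in the chain instead.
bcf_mcmc = BCFModel()
bcf_mcmc.sample(X_train=X, Z_train=Z, y_train=y, propensity_train=pi,
                num_gfr=0, num_burnin=1000, num_mcmc=100)
```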
### `sample(X_train, Z_train, y_train, propensity_train=None, rfx_group_ids_train=None, rfx_basis_train=None, X_test=None, Z_test=None, propensity_test=None, rfx_group_ids_test=None, rfx_basis_test=None, num_gfr=5, num_burnin=0, num_mcmc=100, previous_model_json=None, previous_model_warmstart_sample_num=None, general_params=None, prognostic_forest_params=None, treatment_effect_forest_params=None, variance_forest_params=None, random_effects_params=None)`
Runs a BCF sampler on the provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X_train` | array or DataFrame | Covariates used to split trees in the ensemble. Can be passed as either a matrix or dataframe. | required |
| `Z_train` | array | Array of (continuous or binary; univariate or multivariate) treatment assignments. | required |
| `y_train` | array | Outcome to be modeled by the ensemble. | required |
| `propensity_train` | array | Optional vector of propensity scores. If not provided, this will be estimated from the data. | `None` |
| `rfx_group_ids_train` | array | Optional group labels used for an additive random effects model. | `None` |
| `rfx_basis_train` | array | Optional basis for "random-slope" regression in an additive random effects model. | `None` |
| `X_test` | array or DataFrame | Optional test set of covariates used to define "out of sample" evaluation data. | `None` |
| `Z_test` | array | Optional test set of (continuous or binary) treatment assignments. Must be provided if `X_test` is provided. | `None` |
| `propensity_test` | array | Optional test set vector of propensity scores. If not provided (but `X_test` and `Z_test` are), these will be estimated from the data. | `None` |
| `rfx_group_ids_test` | array | Optional test set group labels used for an additive random effects model. We do not currently support test set evaluation for group labels that were not in the training set (but plan to in the near future). | `None` |
| `rfx_basis_test` | array | Optional test set basis for "random-slope" regression in an additive random effects model. | `None` |
| `num_gfr` | int | Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `5`. | `5` |
| `num_burnin` | int | Number of "burn-in" iterations of the MCMC sampler. Defaults to `0`. | `0` |
| `num_mcmc` | int | Number of "retained" iterations of the MCMC sampler. Defaults to `100`. | `100` |
| `general_params` | dict | Dictionary of general model parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `prognostic_forest_params` | dict | Dictionary of prognostic forest model parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `treatment_effect_forest_params` | dict | Dictionary of treatment effect forest model parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `variance_forest_params` | dict | Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `random_effects_params` | dict | Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. | `None` |
| `previous_model_json` | str | JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`. | `None` |
| `previous_model_warmstart_sample_num` | int | Sample number from `previous_model_json` used to warm-start this BCF sampler. Defaults to `None`. | `None` |
Returns:

| Name | Type | Description |
|---|---|---|
| `self` | `BCFModel` | Sampled BCF model. |
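A sketch of a fuller call to `sample`, continuing the simulated data from the first example above: it adds a held-out test set, passes an optional parameter dictionary, and then "continues" the sampler from a serialized model. The `keep_every` key inside `general_params` and the indexing convention for `previous_model_warmstart_sample_num` are assumptions, not documented defaults.

```python
# Continuing the simulated data (X, Z, y, pi) from the first sketch above.
X_test, Z_test, pi_test = X[:100], Z[:100], pi[:100]

bcf = BCFModel()
bcf.sample(
    X_train=X, Z_train=Z, y_train=y, propensity_train=pi,
    X_test=X_test, Z_test=Z_test, propensity_test=pi_test,
    num_gfr=10, num_burnin=0, num_mcmc=200,
    # Assumption: thinning is controlled via the `keep_every` key of general_params.
    general_params={"keep_every": 5},
)

# "Continue" sampling, warm-started from a retained draw of the model above.
bcf_continued = BCFModel()
bcf_continued.sample(
    X_train=X, Z_train=Z, y_train=y, propensity_train=pi,
    num_gfr=0, num_burnin=0, num_mcmc=100,
    previous_model_json=bcf.to_json(),
    previous_model_warmstart_sample_num=0,  # retained draw to restart from (indexing convention assumed)
)
```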
### `predict(X, Z, propensity=None, rfx_group_ids=None, rfx_basis=None, type='posterior', terms='all', scale='linear')`
Predict outcome model components (the CATE function and the prognostic function) as well as the overall outcome for every provided observation.
Predicted outcomes are computed as `yhat = mu_x + Z*tau_x`, where `mu_x` is a sample of the prognostic function and `tau_x` is a sample of the treatment effect (CATE) function.
When random effects are present, they are included in `yhat` additively if `rfx_model_spec == "custom"`, included in `mu_x` if `rfx_model_spec == "intercept_only"`, or partially included in `mu_x` and partially included in `tau_x` if `rfx_model_spec == "intercept_plus_treatment"`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | array or DataFrame | Test set covariates. | required |
| `Z` | array | Test set treatment indicators. | required |
| `propensity` | `np.array` | Optional test set propensities. Must be provided if propensities were provided when the model was sampled. | `None` |
| `rfx_group_ids` | array | Optional group labels used for an additive random effects model. | `None` |
| `rfx_basis` | array | Optional basis for "random-slope" regression in an additive random effects model. | `None` |
| `type` | str | Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". | `'posterior'` |
| `terms` | str | Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model has random effects fit with either the "intercept_only" or "intercept_plus_treatment" model spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept, and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" returns exactly the same result as "prognostic_function" and "tau" returns exactly the same result as "cate". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request is simply ignored. If none of the requested terms are present in the model, this function returns `None`. | `'all'` |
| `scale` | str | Scale on which to return predictions. Options are "linear" (the default), which returns predictions on the original outcome scale, and "probit", which returns predictions on the probit (latent) scale. Only applicable for models fit with a probit (binary) outcome model. | `'linear'` |
Returns:

| Type | Description |
|---|---|
| dict or array | Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. |
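A sketch of prediction from the fitted model in the earlier examples; the `"y_hat"` dictionary key is an assumption, matching the term name of the same spelling.

```python
# Continuing from the fitted `bcf` model above.
# Full posterior draws for every requested term:
preds = bcf.predict(X=X, Z=Z, propensity=pi, type="posterior", terms="all")
yhat_draws = preds["y_hat"]  # assumed key name, matching the "y_hat" term

# Posterior-mean CATE estimates only:
tau_hat = bcf.predict(X=X, Z=Z, propensity=pi, type="mean", terms="cate")
```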
### `compute_contrast(X_0, X_1, Z_0, Z_1, propensity_0=None, propensity_1=None, rfx_group_ids_0=None, rfx_group_ids_1=None, rfx_basis_0=None, rfx_basis_1=None, type='posterior', scale='linear')`
Compute a contrast using a BCF model by making two sets of outcome predictions and taking their
difference. This function provides the flexibility to compute any contrast of interest by specifying
covariates, leaf basis, and random effects bases / IDs for both sides of a two-term contrast.
For simplicity, we refer to the subtrahend of the contrast as the "control" or Y0 term and the minuend
of the contrast as the Y1 term, though the requested contrast need not match the "control vs. treatment"
terminology of a classic two-treatment causal inference problem. We mirror the function calls and
terminology of the `predict` method, labeling each prediction data term with a 1 to denote
its contribution to the treatment prediction of the contrast and a 0 to denote its inclusion in the control prediction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X_0` | array or DataFrame | Covariates used for prediction in the "control" case. Must be a numpy array or dataframe. | required |
| `X_1` | array or DataFrame | Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe. | required |
| `Z_0` | array | Treatments used for prediction in the "control" case. Must be a numpy array or vector. | required |
| `Z_1` | array | Treatments used for prediction in the "treatment" case. Must be a numpy array or vector. | required |
| `propensity_0` | `np.array` | Propensities used for prediction in the "control" case. Must be a numpy array or vector. | `None` |
| `propensity_1` | `np.array` | Propensities used for prediction in the "treatment" case. Must be a numpy array or vector. | `None` |
| `rfx_group_ids_0` | array | Test set group labels used for prediction from an additive random effects model in the "control" case. We do not currently support test set evaluation for group labels that were not in the training set (but plan to in the near future). Must be a numpy array. | `None` |
| `rfx_group_ids_1` | array | Test set group labels used for prediction from an additive random effects model in the "treatment" case. We do not currently support test set evaluation for group labels that were not in the training set (but plan to in the near future). Must be a numpy array. | `None` |
| `rfx_basis_0` | array | Test set basis used for prediction from an additive random effects model in the "control" case. Must be a numpy array. | `None` |
| `rfx_basis_1` | array | Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a numpy array. | `None` |
| `type` | str | Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". | `'posterior'` |
| `scale` | str | Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability that the (binary) outcome equals 1 before differencing. Only applicable for models fit with a probit outcome model. | `'linear'` |
Returns:

| Type | Description |
|---|---|
| array | Array, either 1d or 2d depending on whether `type = "mean"` or `"posterior"`. |
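For example, evaluating the contrast between `Z = 1` and `Z = 0` at the same covariates recovers treatment-effect estimates, which can be averaged into a sample-ATE estimate. The sketch below continues the fitted model from the earlier examples.

```python
import numpy as np

# Continuing from the fitted `bcf` model above: treatment (Z=1) vs. control (Z=0)
# at the training covariates, averaged over posterior draws for each observation.
n = X.shape[0]
mean_contrast = bcf.compute_contrast(
    X_0=X, X_1=X,
    Z_0=np.zeros(n), Z_1=np.ones(n),
    propensity_0=pi, propensity_1=pi,
    type="mean",
)
sample_ate_estimate = float(np.mean(mean_contrast))
```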
### `compute_posterior_interval(X=None, Z=None, propensity=None, rfx_group_ids=None, rfx_basis=None, terms='all', level=0.95, scale='linear')`
Compute posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | array | Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, treatment effect forest, variance forest, or overall predictions). | `None` |
| `Z` | array | Optional array of treatment assignments. Required if the requested term depends on the treatment (e.g., overall outcome predictions). | `None` |
| `propensity` | array | Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. | `None` |
| `rfx_group_ids` | array | Optional vector of group IDs for random effects. Required if the requested term includes random effects. | `None` |
| `rfx_basis` | array | Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. | `None` |
| `terms` | str | Character string specifying the model term(s) for which to compute intervals. Options for BCF models correspond to those accepted by the `terms` argument of `predict` (e.g., "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all"). | `'all'` |
| `scale` | str | Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability that the (binary) outcome equals 1. Only applicable for models fit with a probit outcome model. | `'linear'` |
| `level` | float | A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. | `0.95` |
Returns:

| Type | Description |
|---|---|
| dict | A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. |
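A sketch continuing the fitted model from the earlier examples; the exact key names of the returned dict are not documented here, so they are inspected rather than assumed.

```python
# Continuing from the fitted `bcf` model above: 90% credible interval for the CATE.
cate_interval = bcf.compute_posterior_interval(
    X=X, Z=Z, propensity=pi, terms="cate", level=0.90,
)
print(cate_interval.keys())  # lower/upper bounds per the Returns description
```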
### `sample_posterior_predictive(X, Z, propensity=None, rfx_group_ids=None, rfx_basis=None, num_draws_per_sample=None)`
Sample from the posterior predictive distribution for outcomes modeled by BCF.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | array | An array or data frame of covariates. | required |
| `Z` | array | An array of treatment assignments. | required |
| `propensity` | array | Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. | `None` |
| `rfx_group_ids` | array | Optional vector of group IDs for random effects. Required if the model includes random effects. | `None` |
| `rfx_basis` | array | Optional matrix of basis function evaluations for random effects. Required if the model includes random effects. | `None` |
| `num_draws_per_sample` | int | The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BCF model (i.e., if the BCF model has more than 1000 draws, we use 1 draw from the likelihood per sample; otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). | `None` |
|
Returns:

| Type | Description |
|---|---|
| array | A matrix of posterior predictive samples. |
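A sketch continuing the fitted model from the earlier examples; the orientation of the returned matrix (draws vs. observations) is an assumption and should be checked against its shape.

```python
import numpy as np

# Continuing from the fitted `bcf` model above: posterior predictive draws of the outcome.
ppd = bcf.sample_posterior_predictive(X=X, Z=Z, propensity=pi)
print(ppd.shape)
# Empirical 95% predictive interval per observation, assuming axis 0 indexes draws.
lower, upper = np.quantile(ppd, [0.025, 0.975], axis=0)
```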
### `to_json()`
Converts a sampled BCF model to a JSON string representation (which can then be saved to a file or processed using the `json` library).
Returns:

| Type | Description |
|---|---|
| str | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests. |
### `from_json(json_string)`
Converts a JSON string to an in-memory BCF model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `json_string` | str | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests. | required |
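A sketch of a serialization round trip, assuming `BCFModel` can be constructed with no arguments and then populated via `from_json`.

```python
from stochtree.bcf import BCFModel

# Continuing from the fitted `bcf` model above: save to disk, then reload.
with open("bcf_model.json", "w") as f:
    f.write(bcf.to_json())

bcf_reloaded = BCFModel()
with open("bcf_model.json", "r") as f:
    bcf_reloaded.from_json(f.read())
```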
### `from_json_string_list(json_string_list)`
Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object which can be used for prediction, etc.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `json_string_list` | list of str | List of JSON strings which can be parsed to objects of type `BCFModel`. | required |
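A sketch of combining independently sampled chains (e.g. produced by parallel workers and serialized with `to_json`) into one model; the two JSON strings here simply reuse models from the earlier sketches.

```python
from stochtree.bcf import BCFModel

# JSON strings for two independently sampled BCF models (from the earlier sketches).
chain_jsons = [bcf.to_json(), bcf_continued.to_json()]

bcf_combined = BCFModel()
bcf_combined.from_json_string_list(chain_jsons)
# The combined model can now be used for prediction across all chains' draws.
```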
### `is_sampled()`
Whether or not a BCF model has been sampled.
Returns:

| Type | Description |
|---|---|
| bool | `True` if the BCF model has been sampled, `False` otherwise. |
### `has_term(term)`
Whether or not a model includes a term.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `term` | str | Character string specifying the model term to check for. Options for BCF models correspond to the terms available in `predict` (for example, "rfx" or "variance_forest"). | required |
Returns:

| Type | Description |
|---|---|
| bool | `True` if the model includes the specified term, `False` otherwise. |
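A sketch of guarding downstream code with these checks; treating `"rfx"` and `"variance_forest"` as valid term names is an assumption based on the `terms` options of `predict`.

```python
# Continuing from the fitted `bcf` model above.
if bcf.is_sampled():
    if bcf.has_term("rfx"):  # term name assumed from the `predict` terms options
        print("Model includes an additive random effects term")
    if bcf.has_term("variance_forest"):
        print("Model includes a conditional variance forest")
```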