BCF
stochtree.bcf.BCFModel
Class that handles sampling, storage, and serialization of stochastic forest models for causal effect estimation. The class takes its name from Bayesian Causal Forests, an MCMC sampler originally developed in Hahn, Murray, Carvalho (2020), but supports several sampling algorithms:
- MCMC: The "classic" sampler defined in Hahn, Murray, Carvalho (2020). To run the MCMC sampler, set `num_gfr = 0` (explained below) and then define a sampler according to several parameters:
    - `num_burnin`: the number of iterations to run before "retaining" samples for further analysis. These "burned in" samples are helpful for allowing a sampler to converge before retaining samples.
    - `num_chains`: the number of independent sequences of MCMC samples to generate (typically referred to in the literature as "chains").
    - `num_mcmc`: the number of "retained" samples of the posterior distribution.
    - `keep_every`: after a sampler has "burned in", we will run the sampler for `keep_every * num_mcmc` iterations, retaining one of every `keep_every` iterations in a chain.
- GFR (Grow-From-Root): A fast, greedy approximation of the BART MCMC sampling algorithm introduced in Krantsevich, He, and Hahn (2023). GFR sampler iterations are governed by the `num_gfr` parameter, and there are two primary ways to use this sampler:
    - Standalone: setting `num_gfr > 0` and both `num_burnin = 0` and `num_mcmc = 0` will only run and retain GFR samples of the posterior. This is typically referred to as "XBART" (accelerated BART).
    - Initializer for MCMC: setting `num_gfr > 0` and `num_mcmc > 0` will use ensembles from the GFR algorithm to initialize `num_chains` independent MCMC BART samplers, which are run for `num_mcmc` iterations. This is typically referred to as "warm start BART".
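To make the interaction between these arguments concrete, the sketch below counts which iterations each configuration keeps. The helpers `retained_iterations` and `total_retained_samples` are hypothetical illustrations, not part of `stochtree`. They assume that thinning keeps the last iteration of each block of `keep_every` post-burn-in iterations and that, under warm start, only the MCMC draws (not the GFR draws) are retained.

```python
from __future__ import annotations


def retained_iterations(num_burnin: int, num_mcmc: int, keep_every: int = 1) -> list[int]:
    """Return the 0-based MCMC iteration indices retained in a single chain.

    Hypothetical helper: the first `num_burnin` iterations are discarded, then the
    sampler runs `keep_every * num_mcmc` more iterations and keeps one of every
    `keep_every` (here, by assumption, the last one of each block).
    """
    if keep_every < 1:
        raise ValueError("keep_every must be >= 1")
    total = num_burnin + keep_every * num_mcmc
    return [i for i in range(num_burnin, total) if (i - num_burnin + 1) % keep_every == 0]


def total_retained_samples(num_gfr: int, num_burnin: int, num_mcmc: int,
                           num_chains: int = 1, keep_every: int = 1) -> int:
    """Count retained posterior draws across all chains (illustrative assumption)."""
    if num_gfr > 0 and num_burnin == 0 and num_mcmc == 0:
        # Standalone GFR ("XBART"): every GFR iteration is retained.
        return num_gfr
    # MCMC, cold start or GFR warm start: each chain keeps num_mcmc draws.
    return num_chains * len(retained_iterations(num_burnin, num_mcmc, keep_every))
```

For example, `num_burnin=1000`, `num_mcmc=100`, `keep_every=5` runs 1,500 MCMC iterations per chain, and with two chains keeps 200 draws in total.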
In addition to enabling multiple samplers, we support a broad set of models. First, note that the original BCF model of Hahn, Murray, Carvalho (2020) is

\[
\begin{aligned}
y &= \mu(X) + b_z(X) + \epsilon\\
b_z(X) &= (b_1 Z + b_0 (1 - Z)) \tau(X)\\
b_0, b_1 &\sim N(0, 1/2)\\
\epsilon &\sim N(0, \sigma^2)
\end{aligned}
\]

for continuous outcome \(y\), binary treatment \(Z\), and covariates \(X\).
In words, there are two nonparametric mean functions, a "prognostic" function \(\mu(X)\) and a "treatment effect" function \(\tau(X)\), governed by tree ensembles with BART priors, and an additive (mean-zero) Gaussian error term whose variance \(\sigma^2\) is given an inverse gamma prior.
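When trying the sampler out, it helps to have data with a known prognostic function and treatment effect. The following is a minimal numpy sketch that generates data of the form \(y = \mu(X) + Z \tau(X) + \epsilon\) with binary \(Z\). The specific \(\mu\), \(\tau\), and propensity \(\pi\) below are arbitrary illustrative assumptions, not anything prescribed by `stochtree` or by Hahn, Murray, Carvalho (2020).

```python
import numpy as np


def simulate_bcf_data(n: int = 500, p: int = 5, sigma: float = 1.0, seed: int = 0):
    """Simulate (X, Z, pi, y) plus the true mu and tau from a toy causal model.

    All functional forms here are hypothetical choices made for illustration.
    """
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 1.0, size=(n, p))
    # Prognostic function mu(X) and treatment effect function tau(X).
    mu = 2.0 * X[:, 0] + np.sin(np.pi * X[:, 1])
    tau = 1.0 + X[:, 2]
    # Propensity depends on X[:, 0], which also drives mu, so assignment is confounded.
    pi = 0.2 + 0.6 * X[:, 0]
    Z = rng.binomial(1, pi).astype(float)
    # Additive mean-zero Gaussian error with variance sigma^2.
    eps = rng.normal(0.0, sigma, size=n)
    y = mu + tau * Z + eps
    return X, Z, pi, y, mu, tau
```

The arrays `X`, `Z`, `y` (and optionally `pi`) correspond to the `X_train`, `Z_train`, `y_train`, and `pi_train` arguments of `sample`; the true `mu` and `tau` are only useful for checking estimates.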
The `BCFModel` class supports the following extensions of this model:
- Continuous Treatment: If \(Z\) is continuous rather than binary, we define \(b_z(X) = \tau(X, Z) = Z \tau(X)\), where the "leaf model" for the \(\tau\) forest is essentially a regression on continuous \(Z\).
- Heteroskedasticity: Rather than define the variance of \(\epsilon\) parametrically, we can let a forest \(\sigma^2(X)\) model a conditional error variance function. This can be done by setting `num_trees_variance > 0` in the `params` dictionary passed to the `sample` method.
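One way to see why the \(\tau\) forest's leaves amount to "a regression on continuous \(Z\)": within a single leaf, suppose the partial residuals \(r_i\) (the outcome minus every other part of the fit) satisfy \(r_i = t z_i + \text{noise}\), with noise \(\sim N(0, \sigma^2)\) and a Gaussian prior \(t \sim N(0, s^2)\) on the leaf coefficient. Then \(t\) has a Gaussian posterior with precision \(\sum_i z_i^2/\sigma^2 + 1/s^2\) and mean \((\sum_i z_i r_i/\sigma^2)\) divided by that precision. The sketch below computes this; the function names, the prior, and treating \(\sigma^2\) as known are illustrative assumptions, not a description of `stochtree`'s internals.

```python
import numpy as np


def leaf_posterior(z, r, sigma2, prior_var):
    """Posterior mean and variance of a leaf coefficient t in r = t * z + noise.

    Assumes noise ~ N(0, sigma2) and prior t ~ N(0, prior_var) (illustrative choices).
    """
    z = np.asarray(z, dtype=float)
    r = np.asarray(r, dtype=float)
    precision = np.sum(z * z) / sigma2 + 1.0 / prior_var
    post_var = 1.0 / precision
    post_mean = post_var * np.sum(z * r) / sigma2
    return post_mean, post_var


def sample_leaf(z, r, sigma2, prior_var, rng):
    """Draw one leaf coefficient from the conjugate Gaussian posterior."""
    mean, var = leaf_posterior(z, r, sigma2, prior_var)
    return rng.normal(mean, np.sqrt(var))
```

Setting every \(z_i = 1\) recovers the familiar constant-mean leaf update, which is why, in this sketch, the same conjugate machinery covers a continuous-treatment leaf.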
sample(X_train, Z_train, y_train, pi_train=None, X_test=None, Z_test=None, pi_test=None, num_gfr=5, num_burnin=0, num_mcmc=100, params=None)
Runs a BCF sampler on the provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`X_train` | array or DataFrame | Covariates used to split trees in the ensemble. Can be passed as either a matrix or dataframe. | required |
`Z_train` | array | Array of (continuous or binary; univariate or multivariate) treatment assignments. | required |
`y_train` | array | Outcome to be modeled by the ensemble. | required |
`pi_train` | array | Optional vector of propensity scores. If not provided, this will be estimated from the data. | `None` |
`X_test` | array | Optional test set of covariates used to define "out of sample" evaluation data. | `None` |
`Z_test` | array | Optional test set of (continuous or binary) treatment assignments. Must be provided if `X_test` is provided. | `None` |
`pi_test` | array | Optional test set vector of propensity scores. If not provided (but `X_test` and `Z_test` are), this will be estimated from the data. | `None` |
`num_gfr` | int | Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `5`. | `5` |
`num_burnin` | int | Number of "burn-in" iterations of the MCMC sampler. Defaults to `0`. | `0` |
`num_mcmc` | int | Number of "retained" iterations of the MCMC sampler. Defaults to `100`. | `100` |
`params` | dict | Dictionary of model parameters, each of which has a default value. | `None` |
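Because several arguments of `sample` depend on one another (row counts must agree, and test-set predictions need treatment values as well as covariates), it can help to check inputs before a long sampling run. The helper below is a hypothetical sketch, not a `stochtree` function. The commented call at the end uses only the `sample` signature documented above; the `num_trees_variance` value of 50 is an arbitrary example.

```python
import numpy as np


def check_bcf_inputs(X_train, Z_train, y_train, pi_train=None,
                     X_test=None, Z_test=None, pi_test=None):
    """Hypothetical pre-flight shape checks for BCF inputs (not part of stochtree)."""
    n = np.shape(X_train)[0]
    for name, arr in (("Z_train", Z_train), ("y_train", y_train), ("pi_train", pi_train)):
        if arr is not None and np.shape(arr)[0] != n:
            raise ValueError(f"{name} must have {n} rows to match X_train")
    if X_test is not None:
        # Test-set outcome predictions need treatment values as well as covariates.
        if Z_test is None:
            raise ValueError("Z_test must be provided when X_test is provided")
        n_test = np.shape(X_test)[0]
        for name, arr in (("Z_test", Z_test), ("pi_test", pi_test)):
            if arr is not None and np.shape(arr)[0] != n_test:
                raise ValueError(f"{name} must have {n_test} rows to match X_test")


# With stochtree installed, validated inputs would then be passed as documented:
# from stochtree.bcf import BCFModel
# bcf_model = BCFModel()
# bcf_model.sample(X_train, Z_train, y_train, pi_train=pi_train,
#                  X_test=X_test, Z_test=Z_test, num_gfr=5, num_burnin=0,
#                  num_mcmc=100, params={"num_trees_variance": 50})
```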
predict_tau(X, Z, propensity=None)
Predict CATE function for every provided observation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`X` | array or DataFrame | Test set covariates. | required |
`Z` | array | Test set treatment indicators. | required |
`propensity` | array | Optional test set propensities. Must be provided if propensities were provided when the model was sampled. | `None` |
Returns:
Type | Description |
---|---|
array | Array with as many rows as in `X` and as many columns as retained samples of the algorithm. |
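A common next step is to reduce the CATE draws to per-observation summaries. The sketch below assumes a univariate treatment, so that the draws form an `(n_obs, n_samples)` array with one row per observation and one column per retained sample; `summarize_cate` is a hypothetical helper, not a `stochtree` function.

```python
import numpy as np


def summarize_cate(tau_samples, level=0.95):
    """Posterior summaries from an (n_obs, n_samples) array of CATE draws."""
    tau_samples = np.asarray(tau_samples, dtype=float)
    alpha = (1.0 - level) / 2.0
    return {
        # Posterior mean CATE for each observation.
        "mean": tau_samples.mean(axis=1),
        # Equal-tailed credible interval for each observation.
        "lower": np.quantile(tau_samples, alpha, axis=1),
        "upper": np.quantile(tau_samples, 1.0 - alpha, axis=1),
        # Averaging over observations within each draw gives posterior draws
        # of the average effect over the provided observations.
        "ate_draws": tau_samples.mean(axis=0),
    }


# Intended use (assumes a fitted model and test arrays already exist):
# summary = summarize_cate(bcf_model.predict_tau(X_test, Z_test, pi_test))
```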
predict_variance(covariates, propensity=None)
Predict expected conditional variance from a BCF model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`covariates` | array | Test set covariates. | required |
`propensity` | array | Test set propensity scores. Optional (not currently used in variance forests). | `None` |
Returns:
Type | Description |
---|---|
array | Array of predictions corresponding to the variance forest. Each array will contain as many rows as in `covariates` and as many columns as retained samples of the algorithm. |
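When reporting a conditional standard deviation rather than a variance, transform each draw before averaging, since the posterior mean of \(\sqrt{\sigma^2(x)}\) is not the square root of the posterior mean of \(\sigma^2(x)\). This short sketch assumes an `(n_obs, n_samples)` array of variance draws; `summarize_conditional_sd` is a hypothetical helper.

```python
import numpy as np


def summarize_conditional_sd(sigma2_samples):
    """Posterior mean conditional standard deviation for each observation.

    Assumes an (n_obs, n_samples) array of sigma^2(x) draws.
    """
    sigma2_samples = np.asarray(sigma2_samples, dtype=float)
    # Square-root each draw first: E[sqrt(s2)] differs from sqrt(E[s2]).
    return np.sqrt(sigma2_samples).mean(axis=1)
```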
predict(X, Z, propensity=None)
Predict outcome model components (CATE function and prognostic function) as well as the overall outcome for every provided observation.
Predicted outcomes are computed as `yhat = mu_x + Z*tau_x`, where `mu_x` is a sample of the prognostic function and `tau_x` is a sample of the treatment effect (CATE) function.
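For a univariate treatment, the formula above is a broadcast over posterior draws: each observation's treatment value multiplies every draw in that observation's row. The sketch below assumes `(n_obs, n_samples)` arrays for `mu_x` and `tau_x`; `combine_outcome` is a hypothetical helper for illustration.

```python
import numpy as np


def combine_outcome(mu_x, tau_x, Z):
    """Recompose outcome draws as yhat = mu_x + Z * tau_x (univariate treatment).

    mu_x, tau_x: (n_obs, n_samples) arrays of draws; Z: length-n_obs treatment vector.
    """
    mu_x = np.asarray(mu_x, dtype=float)
    tau_x = np.asarray(tau_x, dtype=float)
    # Reshape Z to a column so it broadcasts across the sample columns.
    Z = np.asarray(Z, dtype=float).reshape(-1, 1)
    return mu_x + Z * tau_x
```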
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`X` | array or DataFrame | Test set covariates. | required |
`Z` | array | Test set treatment indicators. | required |
`propensity` | array | Optional test set propensities. Must be provided if propensities were provided when the model was sampled. | `None` |
Returns:
Name | Type | Description |
---|---|---|
`tau_x` | array | Conditional average treatment effect (CATE) samples for every observation provided. |
`mu_x` | array | Prognostic effect samples for every observation provided. |
`yhat_x` | array | Outcome prediction samples for every observation provided. |
`sigma2_x` | array, optional | Variance forest samples for every observation provided. Only returned if the model includes a heteroskedasticity forest. |
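When a variance forest is present, `yhat_x` and `sigma2_x` can be combined into posterior predictive draws by adding Gaussian noise with the observation-specific variance. This is a minimal sketch assuming both arrays share the shape `(n_obs, n_samples)` (a scalar variance is also accepted for illustration); `posterior_predictive_draws` is a hypothetical helper, not a `stochtree` function.

```python
import numpy as np


def posterior_predictive_draws(yhat_x, sigma2_x, rng=None):
    """Draw y_rep ~ N(yhat_x, sigma2_x) elementwise.

    yhat_x: (n_obs, n_samples) outcome-mean draws.
    sigma2_x: array broadcastable to yhat_x's shape (e.g. the same shape, or a scalar).
    """
    rng = np.random.default_rng() if rng is None else rng
    yhat_x = np.asarray(yhat_x, dtype=float)
    # Convert variances to standard deviations, broadcasting to yhat_x's shape.
    scale = np.sqrt(np.broadcast_to(np.asarray(sigma2_x, dtype=float), yhat_x.shape))
    return rng.normal(yhat_x, scale)
```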
to_json()
Converts a sampled BCF model to a JSON string representation (which can then be saved to a file or processed using the `json` library).
Returns:
Type | Description |
---|---|
str | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests |
from_json(json_string)
Converts a JSON string to an in-memory BCF model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`json_string` | str | JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests | required |
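Since `to_json` returns a plain string, persisting a model needs only the standard library. The helpers `save_model_json` and `load_model_json` below are hypothetical sketches. The commented lines assume that `from_json` is called on a fresh `BCFModel` instance and populates it in place; check this against your installed `stochtree` version.

```python
import json
from pathlib import Path


def save_model_json(json_string: str, path) -> None:
    """Check that json_string parses as JSON, then write it to disk as UTF-8 text."""
    json.loads(json_string)  # raises json.JSONDecodeError on malformed input
    Path(path).write_text(json_string, encoding="utf-8")


def load_model_json(path) -> str:
    """Read a previously saved model string back from disk."""
    return Path(path).read_text(encoding="utf-8")


# Intended round trip with stochtree (assumption: from_json populates the instance):
# from stochtree.bcf import BCFModel
# save_model_json(bcf_model.to_json(), "bcf_model.json")
# bcf_reloaded = BCFModel()
# bcf_reloaded.from_json(load_model_json("bcf_model.json"))
```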
is_sampled()
Whether or not a BCF model has been sampled.
Returns:
Type | Description |
---|---|
bool | `True` if the model has been sampled, `False` otherwise. |