Causal Inference in StochTree

Introduction

The bcf R function and the BCFModel Python class are the primary interface for causal effect estimation in stochtree. The primary references for these models are the original BCF paper (Hahn et al. (2020)) and the XBCF paper (Krantsevich et al. (2023)), which adapts the BCF Metropolis-Hastings sampler to use the “grow-from-root” algorithm.

Model

Original BCF

The simplest version of the BCF model of Hahn et al. (2020) takes outcome (\(y_i\)), binary treatment (\(Z_i\)) and covariate (\(X_i\)) data and models the outcome, conditional on treatment and covariates, as

\[ \begin{aligned} y_i \mid X_i = x_i, Z_i = z_i &\sim \mathcal{N}\left(f_0(x_i) + \tau(x_i) z_i, \sigma^2\right)\\ f_0 &\sim \text{BART}(\alpha_0, \beta_0, m_0)\\ \tau &\sim \text{BART}(\alpha_{\tau}, \beta_{\tau}, m_{\tau})\\ \sigma^2 &\sim \text{IG}(a,b).\\ \end{aligned} \]

(See the BART page for a detailed overview of the BART prior.) The \(f_0\) forest is typically referred to as the “prognostic” forest and the \(\tau\) forest is typically referred to as the “treatment effect” forest.
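To make the roles of the two forests concrete, the following sketch simulates data from the outcome model above using only the Python standard library. The particular choices of \(f_0\), \(\tau\) and the treatment assignment probability are hypothetical and chosen purely for illustration; they are not part of stochtree.

```python
import math
import random


def simulate_bcf_data(n, sigma=1.0, seed=0):
    """Draw (X, Z, y) from y ~ N(f_0(x) + tau(x) * z, sigma^2)."""
    rng = random.Random(seed)
    X, Z, y = [], [], []
    for _ in range(n):
        x = [rng.uniform(0.0, 1.0) for _ in range(2)]
        # Hypothetical prognostic function f_0 and treatment effect function tau
        f0 = math.sin(math.pi * x[0])
        tau = 1.0 + 2.0 * x[1]
        # Hypothetical treatment probability that depends on the first covariate
        p = 0.25 + 0.5 * x[0]
        z = 1 if rng.random() < p else 0
        X.append(x)
        Z.append(z)
        y.append(f0 + tau * z + rng.gauss(0.0, sigma))
    return X, Z, y


X, Z, y = simulate_bcf_data(500)
```

Because the treatment probability depends on the same covariate as \(f_0\), data simulated this way are confounded, which is the setting BCF is designed for.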

This is the model stochtree fits by default when the following data are passed to stochtree::bcf() in R or BCFModel.sample() in Python:

| Data | R | Python |
| --- | --- | --- |
| \(X\) | X_train | X_train |
| \(Z\) | Z_train | Z_train |
| \(y\) | y_train | y_train |
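A minimal Python sketch of such a default fit is shown below. It assumes that stochtree is installed and that the three inputs are numpy arrays; the wrapper function name is hypothetical, and the import is placed inside the function so that the definition itself does not require stochtree.

```python
def fit_default_bcf(X_train, Z_train, y_train):
    """Fit BCF with default settings; inputs are assumed to be numpy arrays."""
    # Imported lazily so this sketch can be defined without stochtree installed
    from stochtree import BCFModel

    bcf_model = BCFModel()
    # Only the three data arguments from the table are supplied
    bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train)
    return bcf_model
```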

Propensities as Covariates

Hahn et al. (2020) recommend including a propensity score (i.e. the conditional probability of being treated given covariates, \(p(Z=1 \mid X)\)) as a covariate in the \(f_0\) model to guard against regularization-induced confounding. This is the default in stochtree, though it can be changed by setting propensity_covariate in the general_params list / dictionary to any of

  1. "none": Propensity score used in neither \(f_0\) nor \(\tau\)
  2. "prognostic": Propensity score used in \(f_0\) but not \(\tau\) (this is the default in stochtree)
  3. "treatment_effect": Propensity score used in \(\tau\) but not \(f_0\)
  4. "both": Propensity score used in both \(f_0\) and \(\tau\)

Unless propensity_covariate is set to "none", stochtree needs a vector / matrix of propensities to pass through to the relevant forests. Users can pass propensities estimated beforehand (using any model, it doesn’t need to be BART-based or Bayesian) via the propensity_train argument to stochtree::bcf() / BCFModel.sample(). In the absence of user-provided propensities, stochtree will internally fit a BART model of \(Z\) given \(X\) to estimate propensities for the prognostic / treatment effect forests.
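As an illustration of estimating propensities beforehand, the sketch below fits a plain logistic regression by gradient ascent using only the Python standard library. The logistic model is just a stand-in (any propensity model can be used), the function name and simulated data are hypothetical, and in practice the resulting values would be converted to a numpy array before being passed as propensity_train.

```python
import math
import random


def fit_logistic_propensity(X, Z, n_iter=500, lr=0.5):
    """Estimate p(Z = 1 | X) with a logistic regression fit by gradient ascent."""
    n, p = len(X), len(X[0])
    w = [0.0] * (p + 1)  # intercept followed by one weight per covariate

    def predict(x):
        eta = w[0] + sum(wj * xj for wj, xj in zip(w[1:], x))
        return 1.0 / (1.0 + math.exp(-eta))

    for _ in range(n_iter):
        grad = [0.0] * (p + 1)
        for x, z in zip(X, Z):
            resid = z - predict(x)
            grad[0] += resid
            for j in range(p):
                grad[j + 1] += resid * x[j]
        # Step along the average log-likelihood gradient
        w = [wj + lr * gj / n for wj, gj in zip(w, grad)]
    return [predict(x) for x in X]


# Hypothetical data: treatment becomes more likely as the covariate increases
rng = random.Random(1)
X_train = [[rng.uniform(-1.0, 1.0)] for _ in range(400)]
Z_train = [1 if rng.random() < 1.0 / (1.0 + math.exp(-2.0 * x[0])) else 0 for x in X_train]
propensity_train = fit_logistic_propensity(X_train, Z_train)
```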

Adaptive Coding

Hahn et al. (2020) note that the binary treatment \(Z \in \left\{0,1\right\}\) can be “re-coded” as \(Z \in \left\{b_0,b_1\right\}\), which gives the model

\[ \begin{aligned} y_i \mid X_i = x_i, Z_i = z_i &\sim \mathcal{N}\left(f_0(x_i) + \tau(x_i) \left[b_1 z_i + b_0 (1-z_i)\right], \sigma^2\right)\\ f_0 &\sim \text{BART}(\alpha_0, \beta_0, m_0)\\ \tau &\sim \text{BART}(\alpha_{\tau}, \beta_{\tau}, m_{\tau})\\ \sigma^2 &\sim \text{IG}(a,b).\\ \end{aligned} \]

Placing independent \(\mathcal{N}(0,1/2)\) priors on \(b_0\) and \(b_1\) implies \(b_1 - b_0 \sim \mathcal{N}(0,1)\), so the CATE function \((b_1 - b_0) \tau(X)\) is the \(\tau\) forest scaled by a standard normal parameter. \(b_0\) and \(b_1\) can be updated, conditional on \(f_0\), \(\tau\) and \(\sigma^2\), via a simple linear regression of \(y - f_0\) on \(\left[z \tau, (1-z)\tau\right]\), whose coefficients are \(b_1\) and \(b_0\) respectively.
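The regression update for \(b_0\) and \(b_1\) can be sketched as follows. This is the textbook conjugate normal update, assuming independent priors with variance 1/2; it is not necessarily how stochtree implements the step internally, and the function name and simulated inputs are hypothetical.

```python
import random


def sample_coding_params(resid, tau, z, sigma2, rng, prior_var=0.5):
    """Draw (b_0, b_1) given residuals y - f_0, tau(x) values and treatment z.

    Because z * (1 - z) = 0, the regressors z * tau and (1 - z) * tau are
    orthogonal, so b_0 and b_1 have independent normal conditional posteriors.
    """
    draws = []
    for arm in (0, 1):
        # Sufficient statistics from the observations in this treatment arm only
        sxx = sum(t * t for t, zi in zip(tau, z) if zi == arm)
        sxy = sum(t * r for t, r, zi in zip(tau, resid, z) if zi == arm)
        post_var = 1.0 / (1.0 / prior_var + sxx / sigma2)
        post_mean = post_var * sxy / sigma2
        draws.append(rng.gauss(post_mean, post_var ** 0.5))
    b0, b1 = draws
    return b0, b1


# Hypothetical inputs generated with known coding parameters
rng = random.Random(0)
n = 2000
z = [i % 2 for i in range(n)]
tau = [rng.uniform(0.5, 1.5) for _ in range(n)]
true_b0, true_b1 = -0.5, 0.5
resid = [
    t * (true_b1 * zi + true_b0 * (1 - zi)) + rng.gauss(0.0, 0.1)
    for t, zi in zip(tau, z)
]
b0, b1 = sample_coding_params(resid, tau, z, sigma2=0.01, rng=rng)
```

With this much data and so little noise, the draws of b0 and b1 land close to the values used to generate the residuals.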

This option is only available for binary treatments. It is switched off by default and can be enabled by setting adaptive_coding = TRUE in the general_params list in R, or "adaptive_coding": True in the general_params dictionary in Python.

Parametric Treatment Effect Term

In the standard BCF model,

\[ \begin{aligned} y_i \mid X_i = x_i, Z_i = z_i &\sim \mathcal{N}\left(f_0(x_i) + \tau(x_i) z_i, \sigma^2\right)\\ f_0 &\sim \text{BART}(\alpha_0, \beta_0, m_0)\\ \tau &\sim \text{BART}(\alpha_{\tau}, \beta_{\tau}, m_{\tau})\\ \sigma^2 &\sim \text{IG}(a,b),\\ \end{aligned} \]

we can decompose the treatment effect function, \(\tau\), into a parametric homogeneous treatment effect \(\tau_0\) and a covariate-driven heterogeneous treatment effect \(t\), so that \(\tau(x) = \tau_0 + t(x)\):

\[ \begin{aligned} y_i \mid X_i = x_i, Z_i = z_i &\sim \mathcal{N}\left(f_0(x_i) + \left[\tau_0 + t(x_i)\right] z_i, \sigma^2\right)\\ f_0 &\sim \text{BART}(\alpha_0, \beta_0, m_0)\\ t &\sim \text{BART}(\alpha_{\tau}, \beta_{\tau}, m_{\tau})\\ \tau_0 &\sim \mathcal{N}(0, \sigma^2_{\tau})\\ \sigma^2 &\sim \text{IG}(a,b).\\ \end{aligned} \]

\(\tau_0\) is sampled via a straightforward Gibbs update, regressing \(y - f_0 - t(X) Z\) on Z. This “treatment intercept” option is enabled by default in stochtree and is controlled by the sample_intercept option in treatment_effect_forest_params. \(\sigma^2_{\tau}\) is automatically calibrated internally based on the outcome variance, but can be directly specified via the tau_0_prior_var parameter in treatment_effect_forest_params.
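As a worked version of this Gibbs update, hold \(f_0\), \(t\) and \(\sigma^2\) fixed and write \(r_i = y_i - f_0(x_i) - t(x_i) z_i\). Standard normal-normal conjugacy then gives the conditional posterior below. This is the textbook result under the model as written; stochtree's internal implementation may differ in details such as outcome standardization.

\[ \begin{aligned} \tau_0 \mid r, z, \sigma^2 &\sim \mathcal{N}\left(m, v\right),\\ v &= \left(\frac{1}{\sigma^2_{\tau}} + \frac{\sum_i z_i^2}{\sigma^2}\right)^{-1},\\ m &= v \, \frac{\sum_i z_i r_i}{\sigma^2}.\\ \end{aligned} \]

For a binary treatment, \(\sum_i z_i^2\) is the number of treated observations and \(\sum_i z_i r_i\) is the sum of their residuals, so \(m\) is the treated-group mean residual shrunk toward zero by the prior.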

Forest-based Heteroskedasticity

All of the models above are homoskedastic: each observation is assumed to have the same error variance \(\sigma^2\). We can relax this assumption by modeling the variance function with a log-linear forest

\[ \begin{aligned} y_i \mid X_i = x_i &\sim \mathcal{N}\left(f_0(x_i) + z_i \tau(x_i), \sigma^2_0 \exp{h(x_i)}\right)\\ f_0 &\sim \text{BART}(\alpha_0, \beta_0, m_0) \\ \tau &\sim \text{BART}(\alpha_{\tau}, \beta_{\tau}, m_{\tau}) \\ h &\sim \text{logBART}(\alpha_h, \beta_h, m_h) \end{aligned} \]

The \(\text{logBART}\) prior for \(h\) employs the same \(p(split)\) prior as \(\text{BART}\), but places a log inverse-gamma prior on the leaf parameters \(\lambda_{s,j}\) \[ \exp\left(\lambda_{s,j}\right) \sim \text{IG}\left(a,b\right) \]
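To illustrate the leaf prior, the sketch below draws leaf parameters \(\lambda\) such that \(\exp(\lambda)\) is inverse-gamma distributed. It assumes the shape / scale parameterization of \(\text{IG}(a,b)\), and the values \(a = 5\), \(b = 4\) are hypothetical, chosen so that the prior mean of \(\exp(\lambda)\) is \(b / (a - 1) = 1\).

```python
import math
import random


def sample_log_ig_leaf(a, b, rng):
    """Draw lambda such that exp(lambda) ~ IG(shape=a, scale=b)."""
    # If G ~ Gamma(shape=a, rate=b) then 1 / G ~ IG(shape=a, scale=b);
    # random.gammavariate takes (shape, scale), so the scale passed is 1 / b.
    g = rng.gammavariate(a, 1.0 / b)
    return -math.log(g)


rng = random.Random(0)
leaf_values = [sample_log_ig_leaf(5.0, 4.0, rng) for _ in range(100_000)]
# Monte Carlo estimate of the prior mean of the variance multiplier exp(lambda)
mean_multiplier = sum(math.exp(v) for v in leaf_values) / len(leaf_values)
```

Each leaf therefore contributes a positive multiplicative factor to the variance \(\sigma^2_0 \exp{h(x_i)}\), which keeps the variance function positive by construction.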

stochtree exposes a variance_forest_params list / dict in both R and Python, which defaults to num_trees = 0 / 'num_trees': 0 and therefore omits the forest-based heteroskedasticity term. Enabling forest-based heteroskedasticity is as straightforward as setting num_trees in the variance forest parameter list to an integer greater than zero. Other prior terms, such as \(\alpha_h\) and \(\beta_h\), can also be set via the variance_forest_params list / dict.

For more details on this model, see Murray (2021) or Pratola et al. (2020).

Additive Random Effects

Random effects models are a massive topic with lots of field-specific notation and terminology (Wikipedia is a reasonable starting point). We allow for an additive random effects term to be specified directly through the BCF interface in R and Python. Notationally, if we think of our data as having a hierarchical group structure, so that each observation \(i\) belongs to a group \(j\), where group \(j\) will generally contain multiple observations, then we model

\[ \begin{aligned} y_{i,j} \mid X_{i,j} = x_{i,j} &\sim \mathcal{N}\left(W_{i,j} \vec{\gamma}_j + f_0(x_{i,j}) + \tau(x_{i,j}) z_{i,j}, \sigma^2\right)\\ f_0 &\sim \text{BART}(\alpha_0, \beta_0, m_0) \\ \tau &\sim \text{BART}(\alpha_{\tau}, \beta_{\tau}, m_{\tau}) \\ \sigma^2 &\sim \text{IG}(a,b)\\ \vec{\gamma}_j &\sim \mathcal{N}(\vec{0}, \sigma^2_{\gamma} I),\\ \end{aligned} \]

where \(\vec{\gamma}_j\) is group \(j\)’s regression parameters on basis \(W\). The prior on variance components \(\sigma^2_{\gamma}\) is specified via the “redundant parameterization” of Gelman et al. (2013), which splits out a “working parameter” \(\alpha\) from a “group parameter” \(\xi\) for better convergence properties: \[ \begin{aligned} \gamma_j &= \alpha \xi_j,\\ \alpha &\sim \mathcal{N}(\mu_{\alpha}, \sigma^2_{\alpha}),\\ \xi_j &\sim \mathcal{N}(\mu_{\xi,j}, \sigma^2_{\xi}),\\ \sigma^2_{\alpha} &= 1,\\ \sigma^2_{\xi} &\sim \text{IG}(a_{\xi},b_{\xi}).\\ \end{aligned} \]

Fitting a model with additive random effects requires passing group labels (\(j\) above) and a basis (\(W\) above) to stochtree::bcf() in R or BCFModel.sample() in Python as below:

| Data | R | Python |
| --- | --- | --- |
| \(X\) | X_train | X_train |
| \(Z\) | Z_train | Z_train |
| \(j\) | rfx_group_ids_train | rfx_group_ids_train |
| \(W\) | rfx_basis_train | rfx_basis_train |
| \(y\) | y_train | y_train |

Each of \(\mu_{\alpha}\), \(\sigma^2_{\alpha}\), \(\mu_{\xi,j}\), \(\sigma^2_{\xi}\), \(a_{\xi}\), and \(b_{\xi}\) can be set in the random_effects_params list / dictionary with the following mapping

| Prior Term | R | Python |
| --- | --- | --- |
| \(\mu_{\alpha}\) | working_parameter_prior_mean | working_parameter_prior_mean |
| \(\sigma^2_{\alpha}\) | working_parameter_prior_cov | working_parameter_prior_cov |
| \(\mu_{\xi,j}\) | group_parameters_prior_mean | group_parameter_prior_mean |
| \(\sigma^2_{\xi}\) | group_parameter_prior_cov | group_parameter_prior_cov |
| \(a_{\xi}\) | variance_prior_shape | variance_prior_shape |
| \(b_{\xi}\) | variance_prior_scale | variance_prior_scale |

Finally, two common use cases for additive random effects in BCF are “random intercepts”, where \(\vec{\gamma}_j = \gamma_j\) is a single group-specific intercept (and \(W = \vec{1}\)), and “random intercept plus random slope on the treatment”. Users can skip manually constructing a basis in either of these cases by setting model_spec to "intercept_only" or "intercept_plus_treatment" in the random_effects_params list / dictionary. This setting propagates through to the predict.bcfmodel() / BCFModel.predict() methods, so users do not have to pass a basis when predicting from a model fit this way (group IDs are still required, of course).
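For reference, the bases that would otherwise have to be constructed by hand in these two cases can be sketched as below. The helper function is hypothetical and plain Python lists are used for clarity (numpy arrays would be passed to stochtree in practice). It assumes that the second specification corresponds to the basis \(W = [1, z]\), as the description of a random slope on the treatment suggests.

```python
def build_rfx_basis(Z, model_spec):
    """Build the basis W, one row per observation, for two common specifications."""
    if model_spec == "intercept_only":
        # W is a column of ones: one group-specific intercept
        return [[1.0] for _ in Z]
    if model_spec == "intercept_plus_treatment":
        # W = [1, z]: a group-specific intercept plus a group-specific slope on z
        return [[1.0, float(z)] for z in Z]
    raise ValueError(f"unknown model_spec: {model_spec}")


# Hypothetical data: six observations in three groups
Z_train = [0, 1, 1, 0, 1, 0]
rfx_group_ids_train = [0, 0, 1, 1, 2, 2]
rfx_basis_intercept = build_rfx_basis(Z_train, "intercept_only")
rfx_basis_slope = build_rfx_basis(Z_train, "intercept_plus_treatment")
```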

Sampler

MCMC

The “classic” BART model of Chipman et al. (2010) was sampled via MCMC: Gibbs draws for parametric terms like \(\sigma^2\) and a Metropolis-Hastings algorithm for the tree structure that accepts / rejects proposed modifications to the trees, one at a time.

We control how long to run this algorithm and how its draws should be retained through two core arguments:

  1. num_mcmc: the number of MCMC samples to run and retain for later analysis (defaults to 100)
  2. num_burnin: the number of MCMC samples to immediately discard after drawing (i.e. “burning in” a sampler, this defaults to 0)

“Burn-in” refers to discarding the earliest MCMC draws so that the sampler can move toward higher-probability regions of the posterior before any draws are retained. Two further arguments in general_params govern the behavior of the post-burnin MCMC sampler:

  1. num_chains: the number of independent runs of the MCMC sampler (defaults to 1)
  2. keep_every: how frequently an MCMC draw should be retained (defaults to 1)

Taken together, these parameters tell us that a sampler will run for num_burnin + num_mcmc * num_chains * keep_every iterations.

Grow-From-Root

“Burn-in” is a common practice with complex MCMC samplers (such as BART) to allow for convergence. He and Hahn (2023) developed the “grow-from-root” (GFR) algorithm as an alternative to BART’s Metropolis-Hastings sampler. On its own, the GFR algorithm applied to the classic BART model defines a new method called “XBART”, but He and Hahn (2023) note that XBART samples can also be used to initialize an MCMC chain, since GFR samples tend to reach high-probability regions of the posterior much faster than an MCMC chain does.

stochtree is built to support this “warm-start” BART procedure out of the box, and the number of GFR iterations to run before the MCMC chains is specified by the num_gfr argument. If a user specifies num_chains > 1, then a different GFR sample will be used to initialize each MCMC chain.
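A warm-started, multi-chain fit might be configured as in the sketch below. The helper names are hypothetical, the argument and option names are the ones described on this page, and the check that num_gfr is at least num_chains is an assumption based on each chain being initialized from its own GFR sample. The stochtree import is placed inside the fitting function so that the settings helper can be used without stochtree installed.

```python
def warm_start_settings(num_gfr=10, num_mcmc=100, num_chains=4):
    """Collect sampler settings for GFR-initialized ("warm-start") MCMC chains."""
    if num_chains > num_gfr:
        # Assumption: each chain needs its own GFR sample to start from
        raise ValueError("num_gfr must be at least num_chains")
    return {
        "num_gfr": num_gfr,
        "num_burnin": 0,  # the GFR iterations stand in for a burn-in period here
        "num_mcmc": num_mcmc,
        "general_params": {"num_chains": num_chains, "keep_every": 1},
    }


def fit_warm_start_bcf(X_train, Z_train, y_train, **kwargs):
    """Fit BCF with warm-started chains; inputs are assumed to be numpy arrays."""
    # Imported lazily; calling this function requires stochtree to be installed
    from stochtree import BCFModel

    bcf_model = BCFModel()
    bcf_model.sample(
        X_train=X_train,
        Z_train=Z_train,
        y_train=y_train,
        **warm_start_settings(**kwargs),
    )
    return bcf_model


settings = warm_start_settings(num_gfr=10, num_mcmc=100, num_chains=4)
```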

References

Chipman, Hugh A., Edward I. George, and Robert E. McCulloch. 2010. “BART: Bayesian Additive Regression Trees.” The Annals of Applied Statistics 4 (1): 266–98. https://doi.org/10.1214/09-AOAS285.
Gelman, Andrew, John B Carlin, Hal S Stern, David B Dunson, Aki Vehtari, and Donald B Rubin. 2013. Bayesian Data Analysis. 3rd ed. Chapman and Hall/CRC.
Hahn, P Richard, Jared S Murray, and Carlos M Carvalho. 2020. “Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Effects (with Discussion).” Bayesian Analysis 15 (3): 965–1056.
He, Jingyu, and P Richard Hahn. 2023. “Stochastic Tree Ensembles for Regularized Nonlinear Regression.” Journal of the American Statistical Association 118 (541): 551–70.
Krantsevich, Nikolay, Jingyu He, and P Richard Hahn. 2023. “Stochastic Tree Ensembles for Estimating Heterogeneous Effects.” International Conference on Artificial Intelligence and Statistics, 6120–31.
Murray, Jared S. 2021. “Log-Linear Bayesian Additive Regression Trees for Multinomial Logistic and Count Regression Models.” Journal of the American Statistical Association 116 (534): 756–69.
Pratola, Matthew T, Hugh A Chipman, Edward I George, and Robert E McCulloch. 2020. “Heteroscedastic BART via Multiplicative Regression Trees.” Journal of Computational and Graphical Statistics 29 (2): 405–17.