Supervised Learning in StochTree

Introduction

The bart R function and the BARTModel Python class are the primary interfaces for supervised prediction tasks in stochtree. The primary references for these models are the original BART paper (Chipman et al. (2010)) and the XBART paper (He and Hahn (2023)), which replaces the original BART Metropolis-Hastings sampler with a greedy “grow-from-root” sampler for the same underlying model.

Model

Classic BART

The “classic” BART model of Chipman et al. (2010) models each observation (indexed by \(i\)) as independently normally distributed with a covariate-dependent mean \(f(x_i)\) and a common variance \(\sigma^2\), which is given an inverse-gamma prior.

\[ \begin{aligned} y_i \mid X_i = x_i &\sim \mathcal{N}\left(f(x_i), \sigma^2\right)\\ f(x_i) &= \sum_{s = 1}^m g_s(x_i)\\ \sigma^2 &\sim \text{IG}(a,b) \end{aligned} \]

In the equation above, \(f\) is the sum of \(m\) decision trees, and each function \(g_s\) denotes a single decision tree. The decision tree function partitions the covariate space into \(k_s\) mutually exclusive regions, \(\mathcal{A}_{s,j}\), indexed by \(j\). The prediction rule \(g_s\) can be represented mathematically as

\[ g_s(x) = \sum_{j=1}^{k_s} \mu_{s,j} \mathbb{I}\left(x \in \mathcal{A}_{s,j}\right). \]
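To make the indicator-sum representation concrete, here is a minimal plain-Python sketch. It assumes each region \(\mathcal{A}_{s,j}\) is an axis-aligned box stored as per-covariate \((\text{lower}, \text{upper}]\) intervals. The `Tree` class and its fields are purely illustrative and are not how stochtree stores trees internally.

```python
import math
from dataclasses import dataclass


@dataclass
class Tree:
    """A decision tree stored as its leaves (illustrative, not stochtree's layout).

    boxes[j] is the region A_{s,j}: a list of (lower, upper] intervals, one per covariate.
    values[j] is the leaf parameter mu_{s,j} for that region.
    """
    boxes: list
    values: list

    def predict(self, x):
        # g_s(x) = sum_j mu_{s,j} * I(x in A_{s,j}); the regions are disjoint,
        # so at most one indicator equals 1.
        total = 0.0
        for box, mu in zip(self.boxes, self.values):
            inside = all(lo < xi <= hi for xi, (lo, hi) in zip(x, box))
            total += mu * float(inside)
        return total


def forest_predict(trees, x):
    # f(x) = sum_{s=1}^m g_s(x)
    return sum(t.predict(x) for t in trees)


INF = math.inf
# Tree 1 splits covariate 0 at 0.5; tree 2 splits covariate 1 at 0.0.
t1 = Tree(boxes=[[(-INF, 0.5), (-INF, INF)], [(0.5, INF), (-INF, INF)]], values=[-1.0, 2.0])
t2 = Tree(boxes=[[(-INF, INF), (-INF, 0.0)], [(-INF, INF), (0.0, INF)]], values=[0.25, 0.75])

print(forest_predict([t1, t2], [0.3, 1.0]))  # -1.0 + 0.75 = -0.25
```

Each observation lands in exactly one leaf per tree, so \(f(x)\) is a sum of \(m\) leaf values, one from each tree.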

The leaf node parameters \(\mu_{s,j}\) take independent, conjugate normal priors

\[ \mu_{s,j} \sim \mathcal{N}\left(0, \sigma^2_{\mu}\right). \]
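Conjugacy means that, given the residuals \(r_i\) of the observations falling in a leaf (after subtracting the other trees' fits), the full conditional of \(\mu_{s,j}\) is again normal, with precision \(n/\sigma^2 + 1/\sigma^2_\mu\). The sketch below computes that posterior under these standard assumptions; the function names are hypothetical.

```python
import random


def leaf_posterior(residuals, sigma2, sigma2_mu):
    """Posterior mean and variance of a constant leaf parameter mu_{s,j}.

    Model: r_i ~ N(mu, sigma2) for the residuals r_i in the leaf,
    prior: mu ~ N(0, sigma2_mu).
    """
    n = len(residuals)
    precision = n / sigma2 + 1.0 / sigma2_mu
    mean = (sum(residuals) / sigma2) / precision
    return mean, 1.0 / precision


def draw_leaf(residuals, sigma2, sigma2_mu, rng=random):
    # One Gibbs draw of mu_{s,j} from its normal full conditional.
    mean, var = leaf_posterior(residuals, sigma2, sigma2_mu)
    return rng.gauss(mean, var ** 0.5)


mean, var = leaf_posterior([1.0, 2.0, 3.0], sigma2=1.0, sigma2_mu=1.0)
print(mean, var)  # 1.5 0.25
```

The prior shrinks the leaf mean toward zero: the sample mean of the residuals here is 2.0, but the posterior mean is 1.5, and an empty leaf simply returns the prior.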

The structure of the decision tree (i.e. the number and placement of disjoint regions in \(g_s\)) is governed by a prior on the probability of splitting, where a node \(\eta\) of depth \(d_{\eta}\) splits with probability \[ p(\text{split } \eta) = \alpha (1+d_{\eta})^{-\beta} \]
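The split prior strongly favors shallow trees. As a sketch, the code below evaluates \(p(\text{split } \eta)\) and the expected number of leaves it implies, using the values \(\alpha = 0.95\) and \(\beta = 2\) recommended by Chipman et al. (2010). It assumes every node is free to split, ignoring the practical requirement that a node contain enough data and valid cutpoints.

```python
def split_prob(depth, alpha=0.95, beta=2.0):
    """Prior probability that a node at the given depth splits: alpha * (1 + depth)^(-beta)."""
    return alpha * (1.0 + depth) ** (-beta)


def expected_leaves(depth=0, alpha=0.95, beta=2.0, max_depth=30):
    """Expected number of leaves of a subtree rooted at `depth` under the split prior.

    A node either stays a leaf (1 leaf) or splits into two independent child
    subtrees one level deeper. Nodes at max_depth are forced to be leaves.
    """
    if depth >= max_depth:
        return 1.0
    p = split_prob(depth, alpha, beta)
    return (1.0 - p) + p * 2.0 * expected_leaves(depth + 1, alpha, beta, max_depth)


for d in range(4):
    print(d, split_prob(d))
print(expected_leaves())
```

Under these defaults the root splits with probability 0.95, but a depth-1 node splits with probability only 0.2375, and the expected tree has about 2.5 leaves.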

Taken together, we refer to the prior on the mean function \(f\) as \[ f \sim \text{BART}\left(\alpha, \beta, m\right) \]

stochtree fits this model by default; it only requires passing the following data to stochtree::bart() in R or BARTModel.sample() in Python:

Data R Python
\(X\) X_train X_train
\(y\) y_train y_train

Leaf Regression

In addition to the standard BART / XBART models, in which each tree’s leaves return a constant prediction, stochtree also supports arbitrary leaf regression on a user-provided basis (i.e. an expanded version of the models in Chipman et al. (2002) and Gramacy and Lee (2012)).

For a fixed basis \(\Psi\), we adapt the decision tree function to \[ g_s(x,\Psi) = \sum_{j = 1}^{k_s} \vec{\beta}_{s,j}^t\Psi(x) \mathbb{I}\left(x \in \mathcal{A}_{s,j}\right), \] where the leaf parameter vectors, \(\vec{\beta}_{s,j}\), are given independent multivariate normal priors: \[ \vec{\beta}_{s,j} \sim \mathcal{N}\left(\vec{0}, \Sigma_0\right). \]
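With leaf regression, the leaf that \(x\) falls into no longer returns a constant; instead it returns the inner product of its coefficient vector with the basis. A minimal sketch, with regions represented as hypothetical membership tests and a two-column basis (an intercept plus one user-supplied variable) chosen only for illustration:

```python
def leaf_regression_predict(regions, betas, x, psi):
    """g_s(x, Psi) for one tree with a linear model in each leaf.

    regions: membership tests for the leaf regions A_{s,j} (callables x -> bool)
    betas:   leaf coefficient vectors beta_{s,j}, aligned with regions
    psi:     basis vector Psi(x) for this observation, supplied by the user
    """
    for in_region, beta in zip(regions, betas):
        if in_region(x):
            # Regions are disjoint, so only the matching leaf contributes beta^t Psi(x).
            return sum(b * p for b, p in zip(beta, psi))
    raise ValueError("x is not covered by any leaf region")


# One split on covariate 0 at 0.5; each leaf has its own intercept and slope.
regions = [lambda x: x[0] <= 0.5, lambda x: x[0] > 0.5]
betas = [[1.0, -2.0], [0.0, 3.0]]
psi = [1.0, 0.5]  # Psi(x) = (1, z) with z = 0.5

print(leaf_regression_predict(regions, betas, [0.8, 10.0], psi))  # 0.0*1.0 + 3.0*0.5 = 1.5
```

The tree still decides which leaf applies using \(x\), while \(\Psi\) only enters through the leaf-level linear model.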

Fitting a model with leaf regression requires passing a basis (\(\Psi\) above) to stochtree::bart() in R or BARTModel.sample() in Python as below:

Data R Python
\(X\) X_train X_train
\(\Psi\) leaf_basis_train leaf_basis_train
\(y\) y_train y_train

The prior covariance \(\Sigma_0\) defaults to a diagonal matrix with \(\hat{\sigma}^2 / m\) on the diagonal, where \(\hat{\sigma}^2\) is the empirical variance of the outcome and \(m\) is the number of trees in the forest. Users can override this default by specifying sigma2_leaf_init in the mean_forest_params list / dictionary in both R and Python.
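As a sketch of the default arithmetic only: the code below builds \(\Sigma_0\) as a diagonal matrix with \(\hat{\sigma}^2 / m\) on the diagonal. It assumes \(\hat{\sigma}^2\) is the sample variance of \(y\); stochtree may use a different variance estimator or standardize the outcome first, so treat the exact numbers as illustrative.

```python
import statistics


def default_leaf_prior_cov(y, num_trees, basis_dim):
    """Diagonal Sigma_0 with sigma_hat^2 / m on the diagonal.

    Assumption: sigma_hat^2 is the sample variance of y (illustrative choice).
    """
    s = statistics.variance(y) / num_trees
    return [[s if i == j else 0.0 for j in range(basis_dim)] for i in range(basis_dim)]


y = [1.0, 2.0, 3.0, 4.0, 5.0]
print(default_leaf_prior_cov(y, num_trees=200, basis_dim=2))  # [[0.0125, 0.0], [0.0, 0.0125]]
```

Dividing by \(m\) keeps the prior variance of the sum of \(m\) independent leaf contributions on the scale of the outcome variance, which is why larger forests get tighter per-tree priors.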

Forest-based Heteroskedasticity

Both of the above models are homoskedastic: each observation is assumed to have the same error variance \(\sigma^2\). We can relax this assumption by modeling the variance function with a log-linear forest:

\[ \begin{aligned} y_i \mid X_i = x_i &\sim \mathcal{N}\left(f(x_i), \sigma^2_0 \exp\left(h(x_i)\right)\right)\\ f &\sim \text{BART}(\alpha_f, \beta_f, m_f) \\ h &\sim \text{logBART}(\alpha_h, \beta_h, m_h) \end{aligned} \]

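The \(\text{logBART}\) prior for \(h\) employs the same \(p(\text{split})\) prior as \(\text{BART}\), but places a log inverse-gamma prior on the leaf parameters \(\lambda_{s,j}\): \[ \exp\left(\lambda_{s,j}\right) \sim \text{IG}\left(a,b\right). \] Because \(h\) is a sum of trees, \(\exp(h(x))\) is a product of per-tree factors \(\exp(\lambda_{s,j})\), so each tree scales the baseline variance \(\sigma^2_0\) multiplicatively.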
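The sketch below shows how a variance-forest leaf parameter can be drawn from this prior and how the leaves an observation falls into combine into its variance. It uses the fact that if \(G \sim \text{Gamma}(a, \text{rate}=b)\) then \(1/G \sim \text{IG}(a, b)\). The values \(a = b = 10\) are arbitrary illustrations, not stochtree defaults, and the helper names are hypothetical.

```python
import math
import random


def draw_log_ig_leaf(a, b, rng=random):
    """Draw lambda such that exp(lambda) ~ InverseGamma(shape=a, scale=b).

    random.gammavariate takes (shape, scale), so the rate b becomes scale 1/b.
    """
    g = rng.gammavariate(a, 1.0 / b)  # G ~ Gamma(a, rate=b)
    return -math.log(g)               # lambda = log(1 / G)


def observation_variance(sigma2_0, leaf_lambdas):
    """sigma2_0 * exp(h(x)), where h(x) sums the lambdas of the leaves x falls in."""
    return sigma2_0 * math.exp(sum(leaf_lambdas))


rng = random.Random(1)
# One leaf per tree for a given x, in a hypothetical 3-tree variance forest.
lams = [draw_log_ig_leaf(a=10.0, b=10.0, rng=rng) for _ in range(3)]
v = observation_variance(1.0, lams)
# exp of a sum equals the product of per-tree multiplicative factors.
print(v, math.prod(math.exp(l) for l in lams))
```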

stochtree exposes a variance_forest_params list / dict in both R and Python, which defaults to num_trees = 0 / 'num_trees': 0 and thus omits the forest-based heteroskedasticity term. Enabling forest-based heteroskedasticity is as straightforward as setting num_trees in the variance forest parameter list to an integer greater than zero. Other prior terms, such as \(\alpha_h\) and \(\beta_h\), can also be set via the variance_forest_params list / dict.

For more details on this model, see Murray (2021) or Pratola et al. (2020).

Additive Random Effects

Random effects models are a massive topic with lots of field-specific notation and terminology (see Wikipedia as a starting point). We allow for an additive random effects term to be specified directly through the BART interface in R and Python. Suppose the data have a hierarchical group structure, so that each observation \(i\) belongs to a group \(j\), and each group generally contains multiple observations. We then model

\[ \begin{aligned} y_{i,j} \mid X_{i,j} = x_{i,j} &\sim \mathcal{N}\left(W_{i,j} \vec{\gamma}_j + f(x_{i,j}), \sigma^2\right)\\ f(x_{i,j}) &= \sum_{s = 1}^m g_s(x_{i,j})\\ \sigma^2 &\sim \text{IG}(a,b)\\ \vec{\gamma}_j &\sim \mathcal{N}(\vec{0}, \sigma^2_{\gamma} I). \end{aligned} \]

where \(\vec{\gamma}_j\) is the vector of group \(j\)’s regression coefficients on the basis \(W\). The prior on the variance component \(\sigma^2_{\gamma}\) is specified via the “redundant parameterization” of Gelman et al. (2013), which splits each group effect into a “working parameter,” \(\alpha\), shared across groups, and a “group parameter,” \(\xi_j\), for better convergence properties: \[ \begin{aligned} \gamma_j &= \alpha \xi_j,\\ \alpha &\sim \mathcal{N}(\mu_{\alpha}, \sigma^2_{\alpha}),\\ \xi_j &\sim \mathcal{N}(\mu_{\xi,j}, \sigma^2_{\xi}),\\ \sigma^2_{\alpha} &= 1,\\ \sigma^2_{\xi} &\sim \text{IG}(a_{\xi},b_{\xi}).\\ \end{aligned} \]
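A minimal sketch of drawing random intercepts from this prior in the scalar case. It uses the same \(\mu_{\xi}\) for every group and illustrative prior values; the function name is hypothetical. It also shows why \(\alpha\) is called a working parameter: \(\alpha\) and \(\xi_j\) are not separately identified, since rescaling one and inversely rescaling the other leaves every \(\gamma_j\) unchanged, and samplers can exploit this extra freedom to mix faster.

```python
import random


def draw_random_effects(num_groups, mu_alpha=1.0, sigma2_alpha=1.0,
                        mu_xi=0.0, sigma2_xi=1.0, rng=random):
    """Draw gamma_j = alpha * xi_j under the redundant parameterization (scalar case).

    alpha is one working parameter shared by all groups; xi_j are group parameters.
    """
    alpha = rng.gauss(mu_alpha, sigma2_alpha ** 0.5)
    xi = [rng.gauss(mu_xi, sigma2_xi ** 0.5) for _ in range(num_groups)]
    return alpha, xi, [alpha * x for x in xi]


rng = random.Random(42)
alpha, xi, gamma = draw_random_effects(num_groups=4, rng=rng)

# Non-identifiability: (c * alpha, xi / c) yields the same group effects.
c = 2.5
rescaled = [(alpha * c) * (x / c) for x in xi]
print(all(abs(r - g) < 1e-12 for r, g in zip(rescaled, gamma)))  # True
```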

Fitting a model with additive random effects requires passing group labels (\(j\) above) and a basis (\(W\) above) to stochtree::bart() in R or BARTModel.sample() in Python as below:

Data R Python
\(X\) X_train X_train
\(j\) rfx_group_ids_train rfx_group_ids_train
\(W\) rfx_basis_train rfx_basis_train
\(y\) y_train y_train

Each of \(\mu_{\alpha}\), \(\sigma^2_{\alpha}\), \(\mu_{\xi,j}\), \(\sigma^2_{\xi}\), \(a_{\xi}\), and \(b_{\xi}\) can be set in the random_effects_params list / dictionary with the following mapping

Prior Term R Python
\(\mu_{\alpha}\) working_parameter_prior_mean working_parameter_prior_mean
\(\sigma^2_{\alpha}\) working_parameter_prior_cov working_parameter_prior_cov
\(\mu_{\xi,j}\) group_parameter_prior_mean group_parameter_prior_mean
\(\sigma^2_{\xi}\) group_parameter_prior_cov group_parameter_prior_cov
\(a_{\xi}\) variance_prior_shape variance_prior_shape
\(b_{\xi}\) variance_prior_scale variance_prior_scale

Finally, a common use case for additive random effects in BART is a “random intercepts” model, where \(\vec{\gamma}_j = \gamma_j\) is a single group-specific intercept (and \(W = \vec{1}\)). Rather than manually constructing a basis of ones, users can set model_spec to intercept_only in the random_effects_params list / dictionary. This setting propagates through to the predict.bartmodel() / BARTModel.predict() methods, so users do not have to pass a basis when predicting from a model fit this way (group IDs are still required, of course).
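The arithmetic behind a random-intercepts model is simple, and the sketch below spells out what the intercept_only setting spares the user from building. It is not stochtree code: the helper names, group labels and values are hypothetical.

```python
def intercept_only_basis(n):
    # W = 1: a single column of ones, so W_{i,j} * gamma_j reduces to gamma_j.
    return [[1.0] for _ in range(n)]


def predict_with_random_intercepts(f_hat, group_ids, gamma):
    """y_hat_i = f(x_i) + gamma_{j(i)} in a random-intercepts model.

    f_hat:     forest predictions f(x_i)
    group_ids: group label j for each observation (still needed at prediction time)
    gamma:     dict mapping group label -> intercept gamma_j
    """
    return [f + gamma[g] for f, g in zip(f_hat, group_ids)]


print(intercept_only_basis(3))  # [[1.0], [1.0], [1.0]]
print(predict_with_random_intercepts([0.5, 1.0, -0.25], ["a", "b", "a"], {"a": 0.25, "b": -1.0}))
# [0.75, 0.0, 0.0]
```

Because the basis is constant, only the group labels carry information at prediction time, which is why group IDs remain required even though the basis can be omitted.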

Sampler

MCMC

The “classic” BART model of Chipman et al. (2010) was sampled via MCMC: Gibbs draws for parametric terms like \(\sigma^2\), and a Metropolis-Hastings algorithm that updates the trees one at a time, accepting or rejecting a proposed modification to each tree’s structure.
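Updating trees one at a time relies on Bayesian backfitting: tree \(s\) is refit to the partial residuals left after subtracting every other tree's fit. The sketch below covers that bookkeeping and the conjugate Gibbs draw \(\sigma^2 \mid f, y \sim \text{IG}(a + n/2,\, b + \text{SSR}/2)\) implied by the \(\text{IG}(a, b)\) prior. It omits the Metropolis-Hastings tree moves themselves, and the helper names are hypothetical.

```python
import random


def partial_residuals(y, tree_preds, s):
    """Residuals y_i - sum_{t != s} g_t(x_i), the target used to update tree s.

    tree_preds[t][i] holds g_t(x_i) for tree t and observation i.
    """
    m = len(tree_preds)
    return [y[i] - sum(tree_preds[t][i] for t in range(m) if t != s) for i in range(len(y))]


def draw_sigma2(y, f_hat, a, b, rng=random):
    """Gibbs draw of sigma^2 given the current fit f_hat, under an IG(a, b) prior."""
    n = len(y)
    ssr = sum((yi - fi) ** 2 for yi, fi in zip(y, f_hat))
    shape, rate = a + n / 2.0, b + ssr / 2.0
    # If G ~ Gamma(shape, rate) then 1/G ~ IG(shape, rate); gammavariate takes a scale.
    return 1.0 / rng.gammavariate(shape, 1.0 / rate)


y = [1.0, 2.0, 3.0]
tree_preds = [[0.5, 0.5, 0.5], [0.25, 1.0, 2.0]]  # two trees, three observations
print(partial_residuals(y, tree_preds, 0))  # [0.75, 1.0, 1.0]
f_hat = [sum(col) for col in zip(*tree_preds)]
print(draw_sigma2(y, f_hat, a=1.0, b=1.0, rng=random.Random(0)))
```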

We control how long to run this algorithm and how its draws should be retained through two core arguments:

  1. num_mcmc: the number of MCMC samples to run and retain for later analysis (defaults to 100)
  2. num_burnin: the number of initial MCMC samples to discard immediately after drawing (i.e. “burning in” the sampler; defaults to 0)

Burn-in gives the MCMC sampler time to move into higher-probability regions of the posterior before any draws are retained. Two further arguments in general_params govern how the MCMC sampler is run:

  1. num_chains: the number of independent runs of the MCMC sampler (defaults to 1)
  2. keep_every: how frequently an MCMC draw should be retained (defaults to 1)

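Taken together, each chain runs for num_burnin + num_mcmc * keep_every iterations and retains num_mcmc draws, so the sampler runs num_chains * (num_burnin + num_mcmc * keep_every) MCMC iterations in total and returns num_chains * num_mcmc retained draws.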
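A small budgeting helper, offered as a sketch. It assumes, as we read stochtree's documentation of num_chains, that every chain performs its own num_burnin iterations followed by num_mcmc * keep_every iterations, of which num_mcmc are retained. Which iteration within each thinning block is kept (the last one, below) is an illustrative assumption.

```python
def mcmc_budget(num_mcmc=100, num_burnin=0, num_chains=1, keep_every=1):
    """Iterations run and draws retained, assuming each chain has its own burn-in."""
    per_chain = num_burnin + num_mcmc * keep_every
    return {
        "per_chain_iterations": per_chain,
        "total_iterations": num_chains * per_chain,
        "retained_draws": num_chains * num_mcmc,
    }


def retained_iterations(num_mcmc, num_burnin, keep_every):
    # 1-based iteration indices kept within one chain: after burn-in, every
    # keep_every-th iteration (assumed convention: the last of each block).
    return [num_burnin + k * keep_every for k in range(1, num_mcmc + 1)]


print(mcmc_budget(num_mcmc=100, num_burnin=50, num_chains=4, keep_every=5))
# {'per_chain_iterations': 550, 'total_iterations': 2200, 'retained_draws': 400}
print(retained_iterations(num_mcmc=3, num_burnin=2, keep_every=2))  # [4, 6, 8]
```

Thinning with keep_every multiplies run time without increasing the number of stored draws, so it trades compute for lower autocorrelation between retained samples.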

Grow-From-Root

“Burn-in” is a common practice with complex MCMC samplers (such as BART) to allow for convergence. He and Hahn (2023) developed the “grow-from-root” (GFR) algorithm as an alternative to BART’s Metropolis-Hastings sampler. On its own, the GFR algorithm applied to the classic BART model defines a method called “XBART”, but He and Hahn (2023) note that XBART samples can also be used to initialize an MCMC chain, since samples taken via GFR tend to reach high-probability regions much faster than an MCMC chain does.

stochtree is built to support this “warm-start” BART procedure out of the box, and the number of GFR iterations to run before the MCMC chains is specified by the num_gfr argument. If a user specifies num_chains > 1, then a different GFR sample is used to initialize each MCMC chain.
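A sketch of a warm-start plan that pairs each MCMC chain with a distinct GFR draw. The requirement num_gfr >= num_chains follows from needing one GFR sample per chain. The choice of seeding chains from the last num_chains GFR draws, newest first, is an illustrative assumption rather than a description of stochtree's internals, and the function name is hypothetical.

```python
def warm_start_plan(num_gfr, num_chains):
    """Map each MCMC chain index to the index of the GFR draw it starts from.

    GFR draws are indexed 0 .. num_gfr - 1 in the order they were produced.
    Assumption: chains are seeded from the most recent GFR draws, newest first.
    """
    if num_gfr < num_chains:
        raise ValueError("need at least one GFR draw per chain (num_gfr >= num_chains)")
    return {chain: num_gfr - 1 - chain for chain in range(num_chains)}


print(warm_start_plan(num_gfr=10, num_chains=3))  # {0: 9, 1: 8, 2: 7}
```

Starting each chain from a different GFR draw keeps the chains distinct while still giving all of them a head start over initializing from root.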

References

Albert, James H, and Siddhartha Chib. 1993. “Bayesian Analysis of Binary and Polychotomous Response Data.” Journal of the American Statistical Association 88 (422): 669–79.
Chipman, Hugh A., Edward I. George, and Robert E. McCulloch. 2010. “BART: Bayesian Additive Regression Trees.” The Annals of Applied Statistics 4 (1): 266–98. https://doi.org/10.1214/09-AOAS285.
Gelman, Andrew, John B Carlin, Hal S Stern, David B Dunson, Aki Vehtari, and Donald B Rubin. 2013. Bayesian Data Analysis. 3rd ed. Chapman & Hall/CRC.
He, Jingyu, and P Richard Hahn. 2023. “Stochastic Tree Ensembles for Regularized Nonlinear Regression.” Journal of the American Statistical Association 118 (541): 551–70.
Murray, Jared S. 2021. “Log-Linear Bayesian Additive Regression Trees for Multinomial Logistic and Count Regression Models.” Journal of the American Statistical Association 116 (534): 756–69.
Pratola, Matthew T, Hugh A Chipman, Edward I George, and Robert E McCulloch. 2020. “Heteroscedastic BART via Multiplicative Regression Trees.” Journal of Computational and Graphical Statistics 29 (2): 405–17.