
# Overview of Stochastic Tree Models

Stochastic tree models are a powerful addition to your modeling toolkit. As with many machine learning methods, understanding these models in depth is an involved task.

There are many excellent published papers on stochastic tree models (to name a few, the original BART paper, the XBART paper, and the BCF paper). Here, we aim to build up an abbreviated intuition for these models from their conceptually-simple building blocks.

## Notation

We're going to introduce some notation to make these concepts precise. In a traditional supervised learning setting, we hope to predict some outcome from features in a training dataset. We'll call the outcome $y$ and the features $X$. Our goal is to come up with a function $f$ that predicts the outcome $y$ as well as possible from $X$ alone.
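To make this concrete, here is a minimal synthetic example of that setup in Python (the particular data-generating function is purely illustrative):

```python
import numpy as np

# Simulate a small supervised learning dataset: features X and outcome y.
# The "true" f here (a sine curve in the first feature) is just for illustration.
rng = np.random.default_rng(0)
n, p = 100, 5
X = rng.normal(size=(n, p))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=n)
```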

## Decision Trees

Decision tree learning is a simple machine learning method that constructs a function $f$ from a series of conditional statements. Consider the tree below.

We evaluate two conditional statements (X[,1] > 1 and X[,2] > -2), arranged in a tree-like sequence of branches, which determine whether the model predicts a, b, or c. We can express this tree in math notation as

$$f(X_i) = \begin{cases} a & \text{if } X_{i,1} \leq 1 \text{ and } X_{i,2} \leq -2 \\ b & \text{if } X_{i,1} \leq 1 \text{ and } X_{i,2} > -2 \\ c & \text{if } X_{i,1} > 1 \end{cases}$$
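As code, this tree is just nested conditionals. A minimal sketch (note that the R-style X[,1] and X[,2] become the 0-indexed x[0] and x[1] in Python, and a, b, c stand in for whatever leaf values a fitted tree would carry):

```python
def tree_predict(x):
    """Evaluate the two-split decision tree described above."""
    if x[0] > 1:          # X[,1] > 1
        return "c"
    elif x[1] > -2:       # X[,1] <= 1 and X[,2] > -2
        return "b"
    else:                 # X[,1] <= 1 and X[,2] <= -2
        return "a"

print(tree_predict([2.0, 0.0]))   # "c"
print(tree_predict([0.5, 0.0]))   # "b"
print(tree_predict([0.5, -3.0]))  # "a"
```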

We won't belabor the discussion of trees, as there are many good textbooks and online articles on the topic, but we'll close by noting that training decision trees involves a delicate balance between overfitting and underfitting. Simple trees like the one above do not capture much complexity in a dataset and may be underfit, while deep, complex trees are vulnerable to overfitting and tend to have high variance.

## Boosted Decision Tree Ensembles

One way to address the overfitting-underfitting tradeoff of decision trees is to build an "ensemble" of decision trees, so that the function $f$ is defined as a sum of $k$ individual decision trees $g_1, \dots, g_k$:

$$f(X_i) = g_1(X_i) + \dots + g_k(X_i)$$
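In code, evaluating such an ensemble is just a sum over the individual trees. Here is a toy sketch with three hand-written "stumps" standing in for fitted trees:

```python
def ensemble_predict(trees, x):
    """f(X_i) = g_1(X_i) + ... + g_k(X_i): sum the individual tree predictions."""
    return sum(g(x) for g in trees)

# Hand-written stand-ins for k = 3 fitted trees.
trees = [
    lambda x: 0.5 if x[0] > 1 else -0.5,
    lambda x: 0.25 if x[1] > -2 else -0.25,
    lambda x: 0.125,
]
print(ensemble_predict(trees, [2.0, 0.0]))  # 0.5 + 0.25 + 0.125 = 0.875
```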

There are several ways to train an ensemble of decision trees (sometimes called "forests"), the most popular of which are random forests and gradient boosting. Their main difference is that random forests train all $k$ trees independently of one another, while boosting trains trees sequentially, so that tree $j$ depends on the result of training trees $1$ through $j-1$. Libraries like xgboost and LightGBM are popular implementations of boosted tree ensembles.
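For a quick sense of the two training styles, here is a sketch using scikit-learn (chosen purely for brevity; xgboost and LightGBM expose analogous fit/predict interfaces):

```python
from sklearn.datasets import make_friedman1
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor

X, y = make_friedman1(n_samples=500, random_state=0)

# Random forest: each tree is trained independently on a bootstrap sample.
rf = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)

# Gradient boosting: tree j is fit to the residuals left by trees 1 .. j-1.
gb = GradientBoostingRegressor(n_estimators=100, random_state=0).fit(X, y)
```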

Tree ensembles often outperform neural networks and other machine learning methods on tabular datasets, but classic tree ensemble methods return a single estimated function $f$, without expressing uncertainty around its estimates.

## Stochastic Tree Ensembles

Stochastic tree ensembles differ from their classical counterparts in their use of randomness while learning a function. Rather than returning a single "best" tree ensemble, stochastic tree ensembles return a range of tree ensembles that fit the data well. Mechanically, it's useful to think of "sampling" a stochastic tree ensemble model rather than "fitting" one.

Why is this useful? Suppose we've sampled $m$ forests. For each observation $i$, we obtain $m$ predictions: $[f^{(1)}(X_i), \dots, f^{(m)}(X_i)]$. From this "dataset" of predictions, we can compute summary statistics: a mean or median gives a point prediction akin to that of an xgboost or lightgbm model, while the $\alpha$ and $1-\alpha$ quantiles give a credible interval (for example, the 0.025 and 0.975 quantiles give a 95% interval).
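Computing these summaries is straightforward once the predictions are arranged as an $m \times n$ array, with one row per sampled forest. A sketch with simulated draws standing in for real samples:

```python
import numpy as np

# Simulated stand-in for m = 1000 posterior draws over n = 50 observations.
rng = np.random.default_rng(0)
preds = rng.normal(loc=2.0, scale=0.5, size=(1000, 50))

posterior_mean = preds.mean(axis=0)            # point prediction per observation
alpha = 0.05
lower = np.quantile(preds, alpha, axis=0)      # alpha quantile
upper = np.quantile(preds, 1 - alpha, axis=0)  # 1 - alpha quantile
```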

Rather than explain each of the models that stochtree supports in depth here, we provide a high-level overview, with pointers to the relevant literature.

### Supervised Learning

The `bart` R function and the `BARTModel` Python class are the primary interfaces for supervised prediction tasks in stochtree. The primary references for these models are BART (Chipman, George, McCulloch 2010) and XBART (He and Hahn 2021).
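A minimal sketch of the Python interface follows. The `BARTModel` class is stochtree's, but the specific argument names (`num_gfr`, `num_mcmc`, etc.) and the `y_hat_test` attribute are assumptions based on one version of the package; consult the API reference for the version you have installed.

```python
import numpy as np
from stochtree import BARTModel

# Illustrative synthetic data.
rng = np.random.default_rng(0)
X_train = rng.uniform(size=(500, 5))
y_train = np.sin(4 * X_train[:, 0]) + 0.1 * rng.normal(size=500)
X_test = rng.uniform(size=(100, 5))

model = BARTModel()
# Argument names are assumptions: a handful of "grow-from-root" (XBART)
# iterations to warm-start, followed by standard BART MCMC sampling.
model.sample(X_train=X_train, y_train=y_train, X_test=X_test,
             num_gfr=10, num_mcmc=100)

# Assumed attribute: one column of test-set predictions per retained draw.
y_hat_test_mean = model.y_hat_test.mean(axis=1)
```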

In addition to the standard BART / XBART models, in which each tree's leaves return a constant prediction, stochtree also supports arbitrary leaf regression on a user-provided basis (i.e. an expanded version of Chipman et al. 2002 and Gramacy and Lee 2012).

### Causal Inference

The `bcf` R function and the `BCFModel` Python class are the primary interfaces for causal effect estimation in stochtree. The primary references for these models are BCF (Hahn, Murray, Carvalho 2020) and XBCF (Krantsevich, He, Hahn 2022).
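As with BART above, here is a hedged sketch of the Python interface; the argument names (`Z_train`, `pi_train`, etc.) and the `tau_hat_train` attribute are assumptions based on one version of the package.

```python
import numpy as np
from stochtree import BCFModel

# Illustrative synthetic data with a heterogeneous treatment effect.
rng = np.random.default_rng(0)
n = 500
X = rng.uniform(size=(n, 5))
pi = 0.3 + 0.4 * X[:, 0]             # true propensity score
Z = rng.binomial(1, pi)              # binary treatment assignment
tau = 1.0 + X[:, 1]                  # treatment effect varies with X[:, 1]
y = np.sin(4 * X[:, 0]) + tau * Z + 0.1 * rng.normal(size=n)

model = BCFModel()
# Argument names are assumptions; BCF separates a prognostic ("mu") forest
# from a treatment-effect ("tau") forest and uses the propensity as a covariate.
model.sample(X_train=X, Z_train=Z, y_train=y, pi_train=pi,
             num_gfr=10, num_mcmc=100)

# Assumed attribute: posterior draws of the conditional treatment effects.
tau_hat_mean = model.tau_hat_train.mean(axis=1)
```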

### Additional Modeling Features

Both the BART and BCF interfaces in stochtree support the following extensions: