Overview of Stochastic Tree Models
Stochastic tree models are a powerful addition to your modeling toolkit. As with many machine learning methods, understanding these models in depth is an involved task.
There are many excellent published papers on stochastic tree models (to name a few, the original BART paper, the XBART paper, and the BCF paper). Here, we aim to build up an abbreviated intuition for these models from their conceptually simple building blocks.
Notation
We're going to introduce some notation to make these concepts precise.
In a traditional supervised learning setting, we hope to predict some outcome from features in a training dataset.
We'll call the outcome $y$ and the features $X$, so that the goal of a model is to learn a function $f$ for which $f(X)$ is a good prediction of $y$.
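In symbols, and under the standard additive-error setup that the rest of this overview assumes, the training data consist of $n$ pairs of features and outcomes:

$$
\{(X_i, y_i)\}_{i=1}^{n}, \qquad y_i = f(X_i) + \epsilon_i ,
$$

where $\epsilon_i$ is random noise.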
Decision Trees
Decision tree learning is a simple machine learning method that constructs a prediction function $f(X)$ out of a small number of binary splits on the features. Consider, for example, a tree with two splits and three leaves: we evaluate two conditional statements (`X[,1] > 1` and `X[,2] > -2`), arranged in a tree-like sequence of branches, which determine whether the model predicts $a$, $b$, or $c$. We could similarly express this tree in math notation as a piecewise-constant function of $X$.
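For concreteness, one way to write that tree out is below; the assignment of the leaf values $a$, $b$, and $c$ to particular branches is an assumption here, since it depends on how the tree is drawn:

$$
f(X) =
\begin{cases}
a, & X_1 > 1 \\
b, & X_1 \leq 1 \ \text{and} \ X_2 > -2 \\
c, & X_1 \leq 1 \ \text{and} \ X_2 \leq -2
\end{cases}
$$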
We won't belabor the discussion of trees, as there are many good textbooks and online articles on the topic, but we'll close by noting that training decision trees involves a delicate balance between overfitting and underfitting. Simple trees like the one above do not capture much complexity in a dataset and may be underfit, while deep, complex trees are vulnerable to overfitting and tend to have high variance.
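A quick way to see that tradeoff is to fit the same data with a very shallow tree and a fully grown tree. The sketch below uses scikit-learn purely as a familiar reference implementation; the simulated data and the chosen depths are arbitrary.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Simulate a simple nonlinear regression problem
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(500, 2))
y = np.sin(X[:, 0]) + 0.5 * X[:, 1] + rng.normal(scale=0.3, size=500)

# A very shallow tree (prone to underfitting) vs. a fully grown tree
# (prone to overfitting)
shallow = DecisionTreeRegressor(max_depth=2).fit(X, y)
deep = DecisionTreeRegressor(max_depth=None, min_samples_leaf=1).fit(X, y)

# The deep tree reproduces the training data almost perfectly, which is
# usually a symptom of high variance rather than genuinely better predictions
print("shallow tree, training R^2:", shallow.score(X, y))
print("deep tree, training R^2:   ", deep.score(X, y))
```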
Boosted Decision Tree Ensembles
One way to address the overfitting-underfitting tradeoff of decision trees is to build an "ensemble" of decision trees, so that the prediction function $f(X)$ is a combination (typically a sum or average) of many small trees rather than one large tree. There are several ways to train an ensemble of decision trees (sometimes called "forests"), the most popular of which are random forests and gradient boosting. Their main difference is that random forests train all of the trees independently, each on a resampled version of the training data, while gradient boosting trains the trees sequentially, with each new tree fit to the residuals of the ensemble built so far.
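To make that contrast concrete, here is a short scikit-learn sketch; scikit-learn simply stands in for any classical ensemble implementation, and the simulated data are arbitrary.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

# Simulated regression data
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(1000, 2))
y = np.sin(X[:, 0]) + 0.5 * X[:, 1] + rng.normal(scale=0.3, size=1000)

# Random forest: deep trees grown independently on bootstrap samples,
# then averaged to reduce variance
rf = RandomForestRegressor(n_estimators=200).fit(X, y)

# Gradient boosting: shallow trees grown sequentially, each one fit to the
# residuals of the ensemble built so far
gbm = GradientBoostingRegressor(n_estimators=200, max_depth=3).fit(X, y)

# Either way, the result is a single fitted function returning one point
# prediction per observation
print(rf.predict(X[:3]))
print(gbm.predict(X[:3]))
```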
Tree ensembles often outperform neural networks and other machine learning methods on tabular datasets, but classic tree ensemble methods return a single estimated function $\hat{f}$, with no accompanying sense of how much that estimate might reasonably vary.
Stochastic Tree Ensembles
Stochastic tree ensembles differ from their classical counterparts in their use of randomness in learning a function. Rather than returning a single "best" tree ensemble, stochastic tree ensembles return a range of tree ensembles that fit the data well. Mechanically, it's useful to think of "sampling" -- rather than "fitting" -- a stochastic tree ensemble model.
Why is this useful? Suppose we've sampled many tree ensembles, each of which fits the training data well. For any new observation, those ensembles produce a range of predictions rather than a single number, and the spread of that range is a natural measure of how uncertain the model is.
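Schematically, uncertainty summaries fall out of simple array operations on the sampled predictions. The array below is simulated stand-in data, not the output of a real sampler.

```python
import numpy as np

# Pretend we retained 1000 sampled ensembles, each producing a prediction
# for 5 new observations: a (1000, 5) array of sampled predictions
rng = np.random.default_rng(0)
pred_samples = rng.normal(loc=2.0, scale=0.4, size=(1000, 5))

# Averaging over samples gives a point prediction...
point_estimate = pred_samples.mean(axis=0)

# ...and the spread of the samples gives a 95% uncertainty interval
lower, upper = np.percentile(pred_samples, [2.5, 97.5], axis=0)
print(point_estimate)
print(np.column_stack([lower, upper]))
```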
Rather than explain each of the models that `stochtree` supports in depth here, we provide a high-level overview, with pointers to the relevant literature.
Supervised Learning
The `bart` R function and the `BARTModel` Python class are the primary interfaces for supervised prediction tasks in `stochtree`. The primary references for these models are BART (Chipman, George, McCulloch 2010) and XBART (He and Hahn 2021).
In addition to the standard BART / XBART models, in which each tree's leaves return a constant prediction, `stochtree` also supports arbitrary leaf regression on a user-provided basis (an expanded version of the approaches in Chipman et al 2002 and Gramacy and Lee 2012).
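A minimal Python sketch of the supervised workflow is below. It assumes the `BARTModel.sample()` and `predict()` interface with `num_gfr`, `num_burnin`, and `num_mcmc` arguments; the exact argument names and the shape of the returned predictions can vary across stochtree versions, so treat this as illustrative rather than canonical.

```python
import numpy as np
from stochtree import BARTModel

# Simulated training data
rng = np.random.default_rng(0)
X = rng.uniform(size=(500, 5))
y = np.sin(4 * X[:, 0]) + X[:, 1] ** 2 + rng.normal(scale=0.2, size=500)

# Sample a BART ensemble: a few grow-from-root (XBART-style) sweeps to
# warm-start the sampler, followed by standard MCMC draws
# (argument names assumed from the stochtree docs; check your version)
bart_model = BARTModel()
bart_model.sample(X_train=X, y_train=y, num_gfr=10, num_burnin=0, num_mcmc=100)

# Each retained posterior draw yields its own prediction for every new
# observation, which is what makes interval estimates possible
X_new = rng.uniform(size=(10, 5))
y_hat_samples = bart_model.predict(X_new)
print(np.asarray(y_hat_samples).shape)
```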
Causal Inference
The `bcf` R function and the `BCFModel` Python class are the primary interfaces for causal effect estimation in `stochtree`. The primary references for these models are BCF (Hahn, Murray, Carvalho 2021) and XBCF (Krantsevich, He, Hahn 2022).
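A rough Python sketch of the causal workflow follows. The `BCFModel.sample()` arguments (`X_train`, `Z_train`, `y_train`, `pi_train`) and the `tau_hat_train` attribute are recalled from the stochtree documentation and should be treated as assumptions; consult the package reference for the exact names in your version.

```python
import numpy as np
from stochtree import BCFModel

# Simulated observational data: covariates X, binary treatment Z, outcome y
rng = np.random.default_rng(0)
n = 1000
X = rng.uniform(size=(n, 5))
propensity = 1 / (1 + np.exp(-(X[:, 0] - 0.5)))  # treatment depends on X
Z = rng.binomial(1, propensity)
tau = 0.5 + X[:, 1]                               # heterogeneous treatment effect
y = np.sin(4 * X[:, 0]) + tau * Z + rng.normal(scale=0.3, size=n)

# Sample a BCF model (argument names assumed; see the stochtree reference)
bcf_model = BCFModel()
bcf_model.sample(X_train=X, Z_train=Z, y_train=y, pi_train=propensity,
                 num_gfr=10, num_burnin=0, num_mcmc=100)

# Posterior draws of the conditional average treatment effect (CATE);
# the attribute name is an assumption, with one column per retained draw
tau_samples = np.asarray(bcf_model.tau_hat_train)
print(tau_samples.mean(axis=1)[:5])  # posterior mean CATE for the first 5 units
```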
Additional Modeling Features
Both the BART and BCF interfaces in `stochtree` support the following extensions (a brief usage sketch follows the list):
- Accelerated / "warm-start" sampling of forests (see He and Hahn 2021)
- Forest-based heteroskedasticity (see Murray 2021)
- Additive random effects (see Gelman et al 2008)
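The sketch below gives a flavor of how these extensions surface in the Python interface. The only arguments used in code are `num_gfr`, `num_burnin`, and `num_mcmc` (the warm-start controls); heteroskedasticity and random effects are enabled through additional `sample()` arguments whose exact names vary by version, so they are only described in comments here.

```python
import numpy as np
from stochtree import BARTModel

# Simulated data with non-constant noise (heteroskedasticity)
rng = np.random.default_rng(0)
X = rng.uniform(size=(500, 5))
y = np.sin(4 * X[:, 0]) + rng.normal(scale=0.2 + 0.3 * X[:, 1], size=500)

bart_model = BARTModel()

# "Warm-start" sampling: num_gfr grow-from-root (XBART-style) sweeps
# initialize the forest before the usual MCMC draws are retained
bart_model.sample(X_train=X, y_train=y, num_gfr=10, num_burnin=0, num_mcmc=100)

# Forest-based heteroskedasticity (a second forest modeling the error
# variance) and additive random effects (group indices plus a basis) are
# requested through further sample() arguments; see the stochtree
# reference documentation for the exact parameter names in your version.
```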