Scikit-Learn Estimator Wrappers for BART¶
stochtree.BARTModel is fundamentally a Bayesian interface: users specify a prior, provide data, sample from the posterior, and manage and inspect the resulting posterior samples.
However, the basic BART model
$$y_i \sim \mathcal{N}\left(f(X_i), \sigma^2\right)$$
involves posterior samples of a nonparametric function $f$ that estimates the expected value of $y$ given $X$. For some supervised learning use cases, the posterior mean $\bar{f}$, obtained by averaging over these draws, is all that is needed. To serve this use case straightforwardly, we offer scikit-learn-compatible estimator wrappers around BARTModel that implement the familiar sklearn estimator API.
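Concretely, if the sampler retains $M$ posterior draws $f^{(1)}, \dots, f^{(M)}$, the wrappers' predictions are the posterior mean
$$\bar{f}(X) = \frac{1}{M} \sum_{m=1}^{M} f^{(m)}(X)$$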
For continuous outcomes, the stochtree.StochTreeBARTRegressor class provides fit, predict, and score methods.
For binary outcomes (modeled via probit BART), the stochtree.StochTreeBARTBinaryClassifier class provides fit, predict, predict_proba, decision_function, and score methods.
Users can fit multi-class classifiers by wrapping a OneVsRestClassifier around StochTreeBARTBinaryClassifier.
We begin by loading necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import load_wine, load_breast_cancer
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from stochtree import (
    StochTreeBARTRegressor,
    StochTreeBARTBinaryClassifier,
)
Next, we seed a random number generator
random_seed = 1234
rng = np.random.default_rng(random_seed)
BART Regression via sklearn Estimator¶
We simulate some simple regression data to demonstrate the continuous outcome use case
n = 100
p = 10
X = rng.normal(size=(n, p))
y = X[:, 0] * 3 + rng.normal(size=n)
We fit a BART regression model by initializing a StochTreeBARTRegressor and calling its fit() method.
Since stochtree.BARTModel is configured primarily through parameter dictionaries, any model parameters we wish to set are passed through as dictionaries. In this case, we only specify the random seed.
reg = StochTreeBARTRegressor(general_params={"random_seed": random_seed})
reg.fit(X, y)
StochTreeBARTRegressor(general_params={'random_seed': 1234})

| Parameter | Value |
| --- | --- |
| num_gfr | 10 |
| num_burnin | 0 |
| num_mcmc | 100 |
| general_params | {'random_seed': 1234} |
| mean_forest_params | None |
| variance_forest_params | None |
| rfx_params | None |
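Forest-level settings flow through the same mechanism. As a hedged sketch (num_trees, alpha, and beta are the mean-forest keys also used in the grid search below), a more customized configuration might look like:
# Sketch: pass forest settings through the mean_forest_params dictionary
reg_custom = StochTreeBARTRegressor(
    num_mcmc=100,
    general_params={"random_seed": random_seed},
    mean_forest_params={"num_trees": 100, "alpha": 0.95, "beta": 2.0},
)
reg_custom.fit(X, y)
For the remainder of this section we continue with the default-configured reg fit above.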
Now, we can predict from this model and compare the (posterior mean) predictions to the true outcome
pred = reg.predict(X)
plt.scatter(pred, y)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()
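The wrapper also exposes the standard sklearn score method, which for a regressor should follow sklearn's convention of returning the R² of the (posterior mean) predictions:
# In-sample R^2 of the posterior mean predictions (sklearn scoring convention)
print(reg.score(X, y))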
We can also test the determinism of the model by running it again with the same seed and comparing its predictions to those of the first model
reg2 = StochTreeBARTRegressor(general_params={"random_seed": random_seed})
reg2.fit(X, y)
pred2 = reg2.predict(X)
plt.scatter(pred, pred2)
plt.xlabel("First model")
plt.ylabel("Second model")
plt.show()
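Rather than eyeballing the scatter plot, we can check that the two runs agree numerically:
# With identical seeds, the two models' predictions should match
print(np.allclose(pred, pred2))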
Cross-Validating a BART Model¶
While the default hyperparameters of stochtree.BARTModel are designed to work well "out of the box," we can use posterior mean prediction error to cross-validate the model's parameters.
Below we use grid search to consider the effect of several BART parameters:
- Number of GFR iterations (num_gfr)
- Number of MCMC iterations (num_mcmc)
- num_trees, alpha, and beta for the mean forest
param_grid = {
    "num_gfr": [10, 40],
    "num_mcmc": [0, 1000],
    "mean_forest_params": [
        {"num_trees": 50, "alpha": 0.95, "beta": 2.0},
        {"num_trees": 100, "alpha": 0.90, "beta": 1.5},
        {"num_trees": 200, "alpha": 0.85, "beta": 1.0},
    ],
}
grid_search = GridSearchCV(
    estimator=StochTreeBARTRegressor(),
    param_grid=param_grid,
    cv=5,
    scoring="r2",
    n_jobs=-1,
)
grid_search.fit(X, y)
GridSearchCV(cv=5, estimator=StochTreeBARTRegressor(), n_jobs=-1,
             param_grid={'mean_forest_params': [{'alpha': 0.95, 'beta': 2.0, 'num_trees': 50},
                                                {'alpha': 0.9, 'beta': 1.5, 'num_trees': 100},
                                                {'alpha': 0.85, 'beta': 1.0, 'num_trees': 200}],
                         'num_gfr': [10, 40], 'num_mcmc': [0, 1000]},
             scoring='r2')
Best estimator found by the search:

StochTreeBARTRegressor(mean_forest_params={'alpha': 0.95, 'beta': 2.0, 'num_trees': 50},
                       num_mcmc=1000)

| Parameter | Value |
| --- | --- |
| num_gfr | 10 |
| num_burnin | 0 |
| num_mcmc | 1000 |
| general_params | None |
| mean_forest_params | {'alpha': 0.95, 'beta': 2.0, 'num_trees': 50} |
| variance_forest_params | None |
| rfx_params | None |
# Locate the top-ranked parameter combination in cv_results_
cv_best_ind = np.argwhere(grid_search.cv_results_['rank_test_score'] == 1).item(0)
best_num_gfr = grid_search.cv_results_['param_num_gfr'][cv_best_ind].item(0)
best_num_mcmc = grid_search.cv_results_['param_num_mcmc'][cv_best_ind].item(0)
best_mean_forest_params = grid_search.cv_results_['param_mean_forest_params'][cv_best_ind]
best_num_trees = best_mean_forest_params['num_trees']
best_alpha = best_mean_forest_params['alpha']
best_beta = best_mean_forest_params['beta']
print_message = f"""
Hyperparameters chosen by grid search:
num_gfr: {best_num_gfr}
num_mcmc: {best_num_mcmc}
num_trees: {best_num_trees}
alpha: {best_alpha}
beta: {best_beta}
"""
print(print_message)
Hyperparameters chosen by grid search:
num_gfr: 10
num_mcmc: 1000
num_trees: 50
alpha: 0.95
beta: 2.0
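Because GridSearchCV refits the best configuration on the full dataset by default (refit=True), the tuned model is also available directly as best_estimator_, and best_params_ reports the same winning combination:
# The refit best model behaves like any fitted StochTreeBARTRegressor
best_model = grid_search.best_estimator_
pred_best = best_model.predict(X)
print(grid_search.best_params_)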
BART Classification via sklearn Estimator¶
Now, we demonstrate the same functionality with binary and multi-class outcomes, which requires the StochTreeBARTBinaryClassifier class (plus a wrapper for multi-class problems).
First, we load a dataset from sklearn with a binary outcome.
dataset = load_breast_cancer()
X = dataset.data
y = dataset.target
And we fit a binary classification model as follows
clf = StochTreeBARTBinaryClassifier(general_params={"random_seed": random_seed})
clf.fit(X=X, y=y)
StochTreeBARTBinaryClassifier(general_params={'random_seed': 1234})

| Parameter | Value |
| --- | --- |
| num_gfr | 10 |
| num_burnin | 0 |
| num_mcmc | 100 |
| general_params | {'random_seed': 1234} |
| mean_forest_params | None |
| variance_forest_params | None |
| rfx_params | None |
In addition to class predictions, we can compute and visualize the predicted probability of each class via predict_proba().
probs = clf.predict_proba(X)
plt.hist(probs[:, 1], bins=30)
plt.show()
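Hard class labels and an accuracy check come from the usual sklearn methods; for a classifier, score should follow sklearn's mean-accuracy convention:
# Class predictions and in-sample accuracy
yhat = clf.predict(X)
print(clf.score(X, y))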
Now, we load a multi-class classification dataset from sklearn.
dataset = load_wine()
X = dataset.data
y = dataset.target
And fit a multi-class classification model by wrapping a OneVsRestClassifier around StochTreeBARTBinaryClassifier.
clf = OneVsRestClassifier(
    StochTreeBARTBinaryClassifier(general_params={"random_seed": random_seed})
)
clf.fit(X=X, y=y)
OneVsRestClassifier(estimator=StochTreeBARTBinaryClassifier(general_params={'random_seed': 1234}))

Base estimator parameters:

| Parameter | Value |
| --- | --- |
| num_gfr | 10 |
| num_burnin | 0 |
| num_mcmc | 100 |
| general_params | {'random_seed': 1234} |
| mean_forest_params | None |
| variance_forest_params | None |
| rfx_params | None |
And we visualize the histogram of predicted probabilities for each outcome category.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
fig.tight_layout(pad=3.0)
probs = clf.predict_proba(X)
ax1.hist(probs[y == 0, 0], bins=30)
ax1.set_title("Predicted Probabilities for Class 0")
ax1.set_xlim(0, 1)
ax2.hist(probs[y == 1, 1], bins=30)
ax2.set_title("Predicted Probabilities for Class 1")
ax2.set_xlim(0, 1)
ax3.hist(probs[y == 2, 2], bins=30)
ax3.set_title("Predicted Probabilities for Class 2")
ax3.set_xlim(0, 1)
plt.show()
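As in the binary case, hard class labels are available via predict; a quick in-sample accuracy check:
# Multi-class predictions from the one-vs-rest wrapper
yhat_multi = clf.predict(X)
print((yhat_multi == y).mean())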