Ordinal Regression in StochTree
This notebook demonstrates how to use BART to model ordinal outcomes with a complementary log-log (cloglog) link function (Alam and Linero 2025).
We begin by loading the requisite libraries.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from stochtree import BARTModel, OutcomeModel
Introduction to Ordinal BART with Cloglog Link
Ordinal data refers to outcomes that have a natural ordering but undefined distances between categories. Examples include survey responses (strongly disagree, disagree, neutral, agree, strongly agree), severity ratings (mild, moderate, severe), or educational levels (elementary, high school, college, graduate).
The cloglog link function is: $$\text{cloglog}(p) = \log(-\log(1-p))$$
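To invert the link, solve $\log(-\log(1-p)) = \eta$ for $p$. Exponentiating gives $-\log(1-p) = e^{\eta}$, so $1 - p = \exp(-e^{\eta})$ and $$p = 1 - \exp\left(-e^{\eta}\right).$$ This is the form that appears in the model below, with $\eta = \gamma_k + \lambda(x)$.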
In the BART framework with cloglog ordinal regression, we model: $$P(Y = k \mid Y \geq k, X = x) = 1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right)$$
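Here $\lambda(x)$ is represented by a stochastic tree ensemble, and $\gamma_k$ are the cutpoints for the ordinal categories. For an outcome with $K$ categories, $k = 1, \ldots, K-1$, and the last category satisfies $P(Y = K \mid Y \geq K, X = x) = 1$. Each term is therefore the conditional probability of stopping at category $k$ given that the outcome has reached it, which is analogous to a discrete-time hazard. The probit and logit links commonly used in ordinal regression are symmetric, but the cloglog link is not: $1 - \exp(-e^{\eta})$ approaches 1 much faster as $\eta \to \infty$ than it approaches 0 as $\eta \to -\infty$.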
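Chaining the conditional probabilities gives the unconditional class probabilities. For $k = 1, \ldots, K-1$, $$P(Y = k \mid X = x) = \left[\prod_{j<k} \exp\left(-e^{\gamma_j + \lambda(x)}\right)\right]\left(1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right)\right),$$ and the last category receives the remaining probability.

The NumPy sketch below computes these probabilities for a grid of $\lambda(x)$ values. It is a standalone illustration, not part of the stochtree API; inv_cloglog and ordinal_cloglog_probs are hypothetical helper names. It also shows two properties of the model: a larger $\lambda(x)$ shifts mass toward lower categories, and the inverse cloglog is asymmetric.

```python
import numpy as np

def inv_cloglog(eta):
    """Inverse cloglog link: maps a linear predictor eta to 1 - exp(-exp(eta))."""
    return 1.0 - np.exp(-np.exp(eta))

def ordinal_cloglog_probs(gamma, lam):
    """Return an (n, K) array of P(Y = k | x) under the sequential cloglog model.

    gamma: shape (K - 1,), cutpoints gamma_1, ..., gamma_{K-1}
    lam:   shape (n,), values of lambda(x)
    """
    gamma = np.asarray(gamma, dtype=float)
    lam = np.asarray(lam, dtype=float)
    # Conditional probabilities h_k = P(Y = k | Y >= k, x), shape (n, K - 1)
    hazards = inv_cloglog(gamma[None, :] + lam[:, None])
    # S_k = P(Y >= k | x) = prod_{j < k} (1 - h_j), with S_1 = 1; shape (n, K)
    surv = np.cumprod(np.hstack([np.ones((lam.size, 1)), 1.0 - hazards]), axis=1)
    # P(Y = k | x) = S_k * h_k for k < K; the last category gets S_K
    return np.hstack([surv[:, :-1] * hazards, surv[:, -1:]])

gamma = np.array([-2.0, 1.0])       # the cutpoints used in the simulation below
lam = np.linspace(-3.0, 3.0, 7)     # a grid of lambda(x) values
probs = ordinal_cloglog_probs(gamma, lam)
print(probs.round(3))
print(probs.sum(axis=1))            # each row sums to 1

# Asymmetry: a logistic link gives F(eta) + F(-eta) = 1, but cloglog does not
print(inv_cloglog(1.0) + inv_cloglog(-1.0))  # roughly 1.24 rather than 1
```

For $K = 3$, the first two columns reproduce the explicit formulas used to build true_probs in the simulation that follows.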
Data Simulation
We simulate a dataset whose ordinal outcome takes three categories, $y_i \in \left\{1,2,3\right\}$. The category probabilities depend on the covariates $X$ through $\lambda(x) = x^\top \beta$.
# RNG
random_seed = 2026
rng = np.random.default_rng(random_seed)
# Sample size and number of predictors
n = 2000
p = 5
# Design matrix and true lambda function
X = rng.standard_normal((n, p))
beta = np.ones(p) / np.sqrt(p)
true_lambda = X @ beta
# Set cutpoints for ordinal categories (3 categories: 1, 2, 3)
n_categories = 3
gamma_true = np.array([-2.0, 1.0])
# True ordinal class probabilities
true_probs = np.zeros((n, n_categories))
true_probs[:, 0] = 1 - np.exp(-np.exp(gamma_true[0] + true_lambda))
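for j in range(1, n_categories - 1):
    # P(Y > j) is the product of exp(-exp(gamma_i + lambda)) over the first j cutpoints
    surv = np.exp(-np.exp(gamma_true[:j, None] + true_lambda[None, :])).prod(axis=0)
    # P(Y = j + 1) = P(Y > j) * P(Y = j + 1 | Y > j)
    true_probs[:, j] = surv * (1 - np.exp(-np.exp(gamma_true[j] + true_lambda)))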
true_probs[:, n_categories - 1] = 1 - true_probs[:, :-1].sum(axis=1)
# Generate ordinal outcomes (1-indexed integers)
y = np.array(
    [rng.choice(np.arange(1, n_categories + 1), p=true_probs[i]) for i in range(n)],
    dtype=float,
)
# Print outcome distribution
unique, counts = np.unique(y, return_counts=True)
print("Outcome distribution:", dict(zip(unique.astype(int).tolist(), counts.tolist())))
# Train-test split
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.2, random_state=random_seed)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]
Outcome distribution: {1: 354, 2: 1332, 3: 314}
Model Fitting
To specify the cloglog link for an ordinal outcome, we include "outcome_model": OutcomeModel(outcome="ordinal", link="cloglog") in the general_params dictionary. Ordinal outcomes are incompatible with the Gaussian global error variance model, so we also set "sample_sigma2_global": False.
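We also reduce the number of trees in the mean forest from the default of 200 to 50, in favor of greater regularization for the ordinal model. In addition, we set sample_sigma2_leaf=False, so the leaf scale is held fixed rather than sampled. The sampler runs 1000 burn-in and 1000 retained MCMC iterations with no grow-from-root warm-start iterations (num_gfr=0). Note that cutpoint_grid_size controls the grid of candidate split points for the trees, not the ordinal cutpoints $\gamma_k$.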
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    num_gfr=0,
    num_burnin=1000,
    num_mcmc=1000,
    general_params={
        "cutpoint_grid_size": 100,
        "sample_sigma2_global": False,
        "keep_every": 1,
        "num_chains": 1,
        "random_seed": random_seed,
        # Ordinal outcome with a cloglog link
        "outcome_model": OutcomeModel(outcome="ordinal", link="cloglog"),
    },
    # 50 trees instead of the default 200; leaf scale not sampled
    mean_forest_params={"num_trees": 50, "sample_sigma2_leaf": False},
)
Prediction
As with any other BART model in stochtree, we can call the predict method on the fitted ordinal model. Specifying scale="linear" and terms="y_hat" returns predictions of the latent function $\lambda(x)$. To estimate class probabilities, specify scale="probability". By default this returns an array of shape (num_observations, num_categories, num_samples), where:

- num_observations = X.shape[0];
- num_categories is the number of distinct ordinal labels in the outcome;
- num_samples is the number of retained posterior draws.

Specifying type="mean" collapses this output to a num_observations x num_categories matrix containing the posterior mean class probability for each observation. Users can also specify type="class" to obtain the maximum a posteriori (MAP) class label for each observation under each draw.
Below we compute the posterior class probabilities for the train and test sets.
est_probs_train = bart_model.predict(
    X_train, scale="probability", terms="y_hat", type="mean"
)
est_probs_test = bart_model.predict(
    X_test, scale="probability", terms="y_hat", type="mean"
)
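The NumPy sketch below illustrates what the type="mean" and type="class" reductions compute. It assumes the scale="probability" output has the (num_observations, num_categories, num_samples) layout described above, and that these reductions correspond to averaging over draws and taking a per-draw argmax, respectively. Random Dirichlet draws stand in for model output, so the sketch runs without fitting a model and does not call stochtree.

```python
import numpy as np

# Hypothetical stand-in for a (num_observations, num_categories, num_samples) array
# of posterior class probabilities; each draw is a random probability vector
rng = np.random.default_rng(0)
num_observations, num_categories, num_samples = 4, 3, 500
draws = rng.dirichlet(np.ones(num_categories), size=(num_observations, num_samples))
# dirichlet returns (num_observations, num_samples, num_categories); move categories to axis 1
probs = np.transpose(draws, (0, 2, 1))

# Posterior mean class probabilities: average over draws -> (num_observations, num_categories)
mean_probs = probs.mean(axis=2)

# Most probable class under each draw, as 1-indexed labels -> (num_observations, num_samples)
class_draws = probs.argmax(axis=1) + 1

print(mean_probs.round(3))
print(class_draws[:, :10])
```

Averaging over the last axis keeps each row a valid probability vector, because every draw already sums to 1 across categories.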
Model Results and Interpretation
Since one of the "cutpoints" is fixed for identifiability, we plot the posterior distributions of the other two cutpoints and compare them to their true simulated values (blue dotted lines).
The cutpoint samples are accessed via bart_model.extract_parameter("cloglog_cutpoints") (shape: (n_categories - 1, num_samples)). The intercept is not identifiable, so we shift each draw of the cutpoints by that draw's mean forest prediction over the training set before comparing against the true values.
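To see why the shift is needed, suppose we add a constant $c$ to $\lambda(x)$ and subtract it from every cutpoint. Each $\gamma_k + \lambda(x)$ is unchanged, and so is every class probability. The standalone NumPy sketch below checks this numerically. It is not stochtree code; ordinal_cloglog_probs is a hypothetical helper, and the "fitted" values are an idealized assumption in which the forest equals the true $\lambda(x)$ up to $c$.

Under that assumption, adding the mean forest prediction to a cutpoint recovers $\gamma_k + \bar{\lambda}$, where $\bar{\lambda}$ is the training-set average of the true $\lambda(x)$. In this simulation $\lambda(x) = x^\top \beta$ with standard normal covariates, so $\bar{\lambda}$ is close to zero.

```python
import numpy as np

def ordinal_cloglog_probs(gamma, lam):
    """Class probabilities under the sequential cloglog ordinal model (hypothetical helper)."""
    # Conditional probabilities P(Y = k | Y >= k, x) for k = 1, ..., K - 1
    hazards = 1.0 - np.exp(-np.exp(gamma[None, :] + lam[:, None]))
    # P(Y >= k | x), starting from P(Y >= 1 | x) = 1
    surv = np.cumprod(np.hstack([np.ones((lam.size, 1)), 1.0 - hazards]), axis=1)
    return np.hstack([surv[:, :-1] * hazards, surv[:, -1:]])

rng = np.random.default_rng(1)
gamma = np.array([-2.0, 1.0])   # true cutpoints
lam = rng.standard_normal(5)    # true lambda(x) at five points
c = 0.7                         # arbitrary constant absorbed by the forest

# Idealized fit (assumption): forest = lambda + c, cutpoints = gamma - c
lam_hat = lam + c
gamma_hat = gamma - c

# The likelihood cannot distinguish the two parameterizations
print(np.allclose(ordinal_cloglog_probs(gamma, lam), ordinal_cloglog_probs(gamma_hat, lam_hat)))  # True

# Adding back the mean forest prediction recovers gamma + mean(lambda)
recovered = gamma_hat + lam_hat.mean()
print(np.allclose(recovered, gamma + lam.mean()))  # True
```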
cutpoint_samples = bart_model.extract_parameter("cloglog_cutpoints")
y_hat_train_post = bart_model.predict(X=X_train, scale="linear", terms="y_hat", type="posterior")
gamma1 = cutpoint_samples[0, :] + y_hat_train_post.mean(axis=0)
gamma2 = cutpoint_samples[1, :] + y_hat_train_post.mean(axis=0)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.hist(gamma1, density=True, bins=40)
ax1.axvline(gamma_true[0], color="blue", linestyle="dotted", linewidth=2)
ax1.set_title("Posterior Distribution of Cutpoint 1")
ax1.set_xlabel("Cutpoint 1")
ax1.set_ylabel("Density")
ax2.hist(gamma2, density=True, bins=40)
ax2.axvline(gamma_true[1], color="blue", linestyle="dotted", linewidth=2)
ax2.set_title("Posterior Distribution of Cutpoint 2")
ax2.set_xlabel("Cutpoint 2")
ax2.set_ylabel("Density")
plt.tight_layout()
plt.show()
Similarly, we can compare the true latent "utility function" $\lambda(x)$ to the BART forest predictions. The forest absorbs an arbitrary intercept, so we center its predictions at zero before making the comparison.
y_hat_train = bart_model.predict(X=X_train, scale="linear", terms="y_hat", type="mean")
y_hat_test = bart_model.predict(X=X_test, scale="linear", terms="y_hat", type="mean")
lambda_pred_train = y_hat_train - y_hat_train.mean()
lambda_pred_test = y_hat_test - y_hat_test.mean()
corr_train = np.corrcoef(true_lambda[train_inds], lambda_pred_train)[0, 1]
corr_test = np.corrcoef(true_lambda[test_inds], lambda_pred_test)[0, 1]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.scatter(lambda_pred_train, true_lambda[train_inds], alpha=0.3, s=10)
ax1.axline((0, 0), slope=1, color="blue", linewidth=2)
ax1.set_title("Train Set: Predicted vs Actual")
ax1.set_xlabel("Predicted")
ax1.set_ylabel("Actual")
ax1.text(0.05, 0.95, f"Correlation: {corr_train:.3f}", transform=ax1.transAxes,
         color="red", verticalalignment="top")
ax2.scatter(lambda_pred_test, true_lambda[test_inds], alpha=0.3, s=10)
ax2.axline((0, 0), slope=1, color="blue", linewidth=2)
ax2.set_title("Test Set: Predicted vs Actual")
ax2.set_xlabel("Predicted")
ax2.set_ylabel("Actual")
ax2.text(0.05, 0.95, f"Correlation: {corr_test:.3f}", transform=ax2.transAxes,
         color="red", verticalalignment="top")
plt.tight_layout()
plt.show()
Finally, we compare the estimated posterior mean class probabilities with the true simulated values for each class on the training set.
fig, axes = plt.subplots(1, n_categories, figsize=(15, 5))
for j in range(n_categories):
    corr = np.corrcoef(true_probs[train_inds, j], est_probs_train[:, j])[0, 1]
    axes[j].scatter(true_probs[train_inds, j], est_probs_train[:, j], alpha=0.3, s=10)
    axes[j].axline((0, 0), slope=1, color="blue", linewidth=2)
    axes[j].set_title(f"Training Set: True vs Estimated Probability, Class {j + 1}")
    axes[j].set_xlabel("True Class Probability")
    axes[j].set_ylabel("Estimated Class Probability")
    axes[j].text(0.05, 0.95, f"Correlation: {corr:.3f}", transform=axes[j].transAxes,
                 color="red", verticalalignment="top")
plt.tight_layout()
plt.show()
And we run the same comparison on the test set.
fig, axes = plt.subplots(1, n_categories, figsize=(15, 5))
for j in range(n_categories):
    corr = np.corrcoef(true_probs[test_inds, j], est_probs_test[:, j])[0, 1]
    axes[j].scatter(true_probs[test_inds, j], est_probs_test[:, j], alpha=0.3, s=10)
    axes[j].axline((0, 0), slope=1, color="blue", linewidth=2)
    axes[j].set_title(f"Test Set: True vs Estimated Probability, Class {j + 1}")
    axes[j].set_xlabel("True Class Probability")
    axes[j].set_ylabel("Estimated Class Probability")
    axes[j].text(0.05, 0.95, f"Correlation: {corr:.3f}", transform=axes[j].transAxes,
                 color="red", verticalalignment="top")
plt.tight_layout()
plt.show()
References
Alam, Entejar, and Antonio R Linero. 2025. “A Unified Bayesian Nonparametric Framework for Ordinal, Survival, and Density Regression Using the Complementary Log-Log Link.” arXiv Preprint arXiv:2502.00606.