Posterior Summary and Visualization Utilities

This vignette demonstrates the summary and plotting utilities available for stochtree models.

Setup

Load necessary packages

library(stochtree)

import numpy as np
import matplotlib.pyplot as plt
from stochtree import BARTModel, BCFModel, plot_parameter_trace

Set a seed for reproducibility
random_seed <- 1234
set.seed(random_seed)

random_seed = 1234
rng = np.random.default_rng(random_seed)

Supervised Learning
We begin with the supervised learning use case served by the bart() function in R and the BARTModel class in Python.
Below we simulate a simple regression dataset.
n <- 1000
p_x <- 10
p_w <- 1
X <- matrix(runif(n * p_x), ncol = p_x)
W <- matrix(runif(n * p_w), ncol = p_w)
f_XW <- (
    ((0 <= X[, 10]) & (0.25 > X[, 10])) * (-7.5 * W[, 1]) +
    ((0.25 <= X[, 10]) & (0.5 > X[, 10])) * (-2.5 * W[, 1]) +
    ((0.5 <= X[, 10]) & (0.75 > X[, 10])) * (2.5 * W[, 1]) +
    ((0.75 <= X[, 10]) & (1 > X[, 10])) * (7.5 * W[, 1])
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, 1) * noise_sd

n = 1000
p_x = 10
p_w = 1
X = rng.uniform(size=(n, p_x))
W = rng.uniform(size=(n, p_w))
# R uses X[,10] (1-indexed) = Python X[:,9]
f_XW = (
    ((X[:, 9] >= 0) & (X[:, 9] < 0.25)) * (-7.5 * W[:, 0]) +
    ((X[:, 9] >= 0.25) & (X[:, 9] < 0.5)) * (-2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.5) & (X[:, 9] < 0.75)) * (2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.75) & (X[:, 9] < 1.0)) * (7.5 * W[:, 0])
)
noise_sd = 1.0
y = f_XW + rng.standard_normal(n) * noise_sd

Now we fit a simple BART model to the data.
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
num_threads = 1,
num_chains = 3
)
bart_model <- stochtree::bart(
X_train = X,
y_train = y,
leaf_basis_train = W,
num_gfr = num_gfr,
num_burnin = num_burnin,
num_mcmc = num_mcmc,
general_params = general_params
)

bart_model = BARTModel()
bart_model.sample(
X_train=X,
y_train=y,
leaf_basis_train=W,
num_gfr=10,
num_burnin=0,
num_mcmc=1000,
general_params={
"num_threads": 1,
"num_chains": 3
},
)

We obtain a high-level summary of the BART model by running print().
print(bart_model)

stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
print(bart_model)

BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
For a more detailed summary (including the information above), we use the summary() function.
summary(bart_model)

stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior:
3000 samples, mean = 0.900, standard deviation = 0.047, quantiles:
2.5% 10% 25% 50% 75% 90% 97.5%
0.8107628 0.8407881 0.8675898 0.8980279 0.9307229 0.9611398 0.9942541
Summary of leaf scale posterior:
3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
2.5% 10% 25% 50% 75% 90%
0.004485701 0.005126561 0.005641411 0.006285762 0.007238720 0.008307585
97.5%
0.010510942
Summary of in-sample posterior mean predictions:
1000 observations, mean = -0.063, standard deviation = 3.283, quantiles:
2.5% 10% 25% 50% 75% 90% 97.5%
-6.8552153 -4.9346264 -1.8170496 -0.1200449 2.0538402 4.2706207 6.4520534
print(bart_model.summary())

BART Model Summary:
-------------------
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior: 3000 samples, mean = 0.937, standard deviation = 0.047, quantiles:
2.5%: 0.849
10.0%: 0.879
25.0%: 0.906
50.0%: 0.935
75.0%: 0.968
90.0%: 0.998
97.5%: 1.031
Summary of leaf scale posterior: 3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
2.5%: 0.005
10.0%: 0.006
25.0%: 0.006
50.0%: 0.007
75.0%: 0.008
90.0%: 0.009
97.5%: 0.010
Summary of in-sample posterior mean predictions:
1000 observations, mean = 0.107, standard deviation = 3.286, quantiles:
2.5%: -6.716
10.0%: -4.329
25.0%: -1.935
50.0%: 0.170
75.0%: 2.087
90.0%: 4.432
97.5%: 6.600
None
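The quantile summaries printed above can be reproduced by hand from extracted posterior draws with np.quantile. A minimal sketch, using a simulated stand-in array in place of the real \(\sigma^2\) draws (with a fitted model, the draws would come from extract_parameter, e.g. with the "global_error_scale" term):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for 3000 posterior draws of sigma^2 (e.g. from
# bart_model.extract_parameter("global_error_scale")).
sigma2_samples = rng.gamma(shape=200.0, scale=1.0 / 200.0, size=3000)

# Same quantile levels reported by summary().
quantile_levels = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975]
quantiles = np.quantile(sigma2_samples, quantile_levels)

print(f"mean = {sigma2_samples.mean():.3f}, "
      f"sd = {sigma2_samples.std(ddof=1):.3f}")
for q, v in zip(quantile_levels, quantiles):
    print(f"{q * 100:.1f}%: {v:.3f}")
```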
In R, we can use the plot() function to produce a traceplot of model terms like the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions. In Python, the plot_parameter_trace() function produces a traceplot of a specific model parameter.
plot(bart_model)
ax = plot_parameter_trace(bart_model, term="global_error_scale")
plt.show()
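Because the sampler above was run with 3 chains, a trace can also be checked numerically for mixing. Below is a rough numpy sketch of the split-R-hat diagnostic, not a stochtree API; it assumes the 3,000 retained draws are ordered chain-by-chain (an assumption about how chains are concatenated), and uses a simulated series in place of the real trace:

```python
import numpy as np

def split_rhat(samples, num_chains):
    """Split-R-hat: split each chain in half and compare within- and
    between-chain variances; values near 1 suggest good mixing."""
    chains = np.asarray(samples, dtype=float).reshape(num_chains, -1)
    half = chains.shape[1] // 2
    halves = np.concatenate([chains[:, :half], chains[:, half:2 * half]])
    n = halves.shape[1]
    chain_means = halves.mean(axis=1)
    within = halves.var(axis=1, ddof=1).mean()
    between = n * chain_means.var(ddof=1)
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(0)
# Stand-in for 3 chains x 1000 draws of a well-mixed sigma^2 trace.
samples = rng.normal(loc=0.9, scale=0.05, size=3000)
print(f"split R-hat: {split_rhat(samples, num_chains=3):.3f}")
```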
For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to pull the posterior distribution of any valid model term (e.g., global error scale \(\sigma^2\), leaf scale \(\sigma^2_{\ell}\), in-sample mean function predictions y_hat_train) and then plot any subset or transformation of these values.
y_hat_train_samples <- extractParameter(bart_model, "y_hat_train")
obs_index <- 1
plot(
y_hat_train_samples[obs_index, ],
type = "l",
main = paste0("In-Sample Predictions Traceplot, Observation ", obs_index),
xlab = "Index",
ylab = "Parameter Values"
)
y_hat_train_samples = bart_model.extract_parameter("y_hat_train")
obs_index = 0
fig, ax = plt.subplots()
ax.plot(y_hat_train_samples[obs_index, :])
ax.set_title(f"In-Sample Predictions Traceplot, Observation {obs_index}")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
plt.show()
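Rather than tracing one observation at a time, the extracted draws can also be summarized pointwise. A sketch with numpy, using a simulated stand-in for the real draws; the (n_obs, n_samples) shape is an assumption consistent with the row indexing used above:

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_samples = 1000, 3000
# Stand-in for bart_model.extract_parameter("y_hat_train"):
# rows index observations, columns index posterior draws.
true_mean = rng.normal(size=n_obs)
y_hat_train_samples = true_mean[:, None] + rng.normal(
    scale=0.1, size=(n_obs, n_samples)
)

# Pointwise posterior means and central 95% credible intervals.
posterior_mean = y_hat_train_samples.mean(axis=1)
lower, upper = np.quantile(y_hat_train_samples, [0.025, 0.975], axis=1)
print(f"mean interval width: {(upper - lower).mean():.3f}")
```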
Causal Inference
We now run the same demo for the causal inference use case served by the bcf() function in R and the BCFModel Python class.
Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome.
# Generate covariates and treatment
n <- 1000
p_X <- 5
X <- matrix(runif(n * p_X), ncol = p_X)
pi_X <- 0.25 + 0.5 * X[, 1]
Z <- rbinom(n, 1, pi_X)
# Define the outcome mean functions (prognostic and treatment effects)
mu_X <- pi_X * 5 + 2 * X[, 3]
tau_X <- X[, 2] * 2 - 1
# Generate outcome
epsilon <- rnorm(n, 0, 1)
y <- mu_X + tau_X * Z + epsilon

# Generate covariates and treatment
n = 1000
p_X = 5
X = rng.uniform(size=(n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1
# Generate outcome
epsilon = rng.standard_normal(n)
y = mu_X + tau_X * Z + epsilon

Now we fit a simple BCF model to the data.
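First, though, it is worth checking how confounded this design is: treatment assignment depends on X[:, 0] through pi_X, which also enters the prognostic function, so the naive difference in means is biased for the average treatment effect. A quick self-contained numpy check that re-simulates the data-generating process above:

```python
import numpy as np

rng = np.random.default_rng(1234)
n, p_X = 1000, 5
X = rng.uniform(size=(n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
mu_X = pi_X * 5 + 2 * X[:, 2]       # prognostic effect shares X[:, 0] with pi_X
tau_X = X[:, 1] * 2 - 1             # true CATE, mean zero by construction
y = mu_X + tau_X * Z + rng.standard_normal(n)

true_ate = tau_X.mean()
naive_diff = y[Z == 1].mean() - y[Z == 0].mean()
print(f"true ATE: {true_ate:.3f}, naive difference: {naive_diff:.3f}")
```

The naive contrast is pulled upward because treated units tend to have larger prognostic values, which is exactly the bias BCF's separate prognostic and treatment forests are designed to absorb.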
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
num_threads = 1,
num_chains = 3,
adaptive_coding = TRUE
)
bcf_model <- stochtree::bcf(
X_train = X,
y_train = y,
Z_train = Z,
num_gfr = num_gfr,
num_burnin = num_burnin,
num_mcmc = num_mcmc,
general_params = general_params
)

bcf_model = BCFModel()
bcf_model.sample(
X_train=X,
Z_train=Z,
y_train=y,
propensity_train=pi_X,
num_gfr=10,
num_burnin=0,
num_mcmc=1000,
general_params={
"num_threads": 1,
"num_chains": 3,
"adaptive_coding": True
},
)

We obtain a high-level summary of the BCF model by running print().
print(bcf_model)

stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
print(bcf_model)

BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
For a more detailed summary (including the information above), we use the summary() function in R or the summary() method in Python.
summary(bcf_model)

stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior:
3000 samples, mean = 0.946, standard deviation = 0.046, quantiles:
2.5% 10% 25% 50% 75% 90% 97.5%
0.8602718 0.8884035 0.9141399 0.9435992 0.9753794 1.0056469 1.0410952
Summary of prognostic forest leaf scale posterior:
3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
2.5% 10% 25% 50% 75% 90%
0.0009160522 0.0010774573 0.0012350772 0.0014498265 0.0016907946 0.0019845796
97.5%
0.0025799952
Summary of adaptive coding parameters:
3000 samples, mean (control) = -0.461, mean (treated) = 1.019, standard deviation (control) = 0.353, standard deviation (treated) = 0.328
quantiles (control):
2.5% 10% 25% 50% 75% 90%
-1.23065170 -0.91900031 -0.68260202 -0.44082441 -0.21038925 -0.02613343
97.5%
0.15527754
quantiles (treated):
2.5% 10% 25% 50% 75% 90% 97.5%
0.4182956 0.5986667 0.7902708 1.0229387 1.2282032 1.4285370 1.7057750
Summary of treatment effect intercept (tau_0) posterior:
3000 samples, mean = 0.198, standard deviation = 0.820, quantiles:
2.5% 10% 25% 50% 75% 90% 97.5%
-1.5270194 -1.0527016 -0.2388395 0.1543970 0.8088434 1.2127939 1.7210237
Summary of in-sample posterior mean predictions:
1000 observations, mean = 3.488, standard deviation = 0.944, quantiles:
2.5% 10% 25% 50% 75% 90% 97.5%
1.696618 2.269925 2.806166 3.483249 4.140743 4.736517 5.435912
Summary of in-sample posterior mean CATEs:
1000 observations, mean = -0.216, standard deviation = 0.523, quantiles:
2.5% 10% 25% 50% 75% 90% 97.5%
-1.0219221 -0.8474964 -0.6778336 -0.2933795 0.2888800 0.4871224 0.6127401
print(bcf_model.summary())

BCF Model Summary:
------------------
BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior: 3000 samples, mean = 0.873, standard deviation = 0.041, quantiles:
2.5%: 0.797
10.0%: 0.821
25.0%: 0.844
50.0%: 0.871
75.0%: 0.901
90.0%: 0.926
97.5%: 0.956
Summary of prognostic forest leaf scale posterior: 3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
2.5%: 0.001
10.0%: 0.001
25.0%: 0.001
50.0%: 0.002
75.0%: 0.002
90.0%: 0.002
97.5%: 0.003
Summary of adaptive coding parameters:
3000 samples, mean (control) = -0.540, mean (treated) = 1.151, standard deviation (control) = 0.324, standard deviation (treated) = 0.339
quantiles (control):
2.5%: -1.230
10.0%: -0.964
25.0%: -0.743
50.0%: -0.514
75.0%: -0.324
90.0%: -0.156
97.5%: 0.058
quantiles (treated):
2.5%: 0.530
10.0%: 0.732
25.0%: 0.907
50.0%: 1.140
75.0%: 1.370
90.0%: 1.592
97.5%: 1.881
Summary of treatment effect intercept (tau_0) posterior: 3000 samples, mean = -0.243, standard deviation = 0.592, quantiles:
2.5%: -1.592
10.0%: -1.207
25.0%: -0.509
50.0%: -0.166
75.0%: 0.197
90.0%: 0.443
97.5%: 0.671
Summary of in-sample posterior mean predictions:
1000 observations, mean = 3.482, standard deviation = 0.961, quantiles:
2.5%: 1.847
10.0%: 2.199
25.0%: 2.781
50.0%: 3.468
75.0%: 4.163
90.0%: 4.771
97.5%: 5.313
Summary of in-sample posterior mean CATEs:
1000 observations, mean = 0.359, standard deviation = 0.625, quantiles:
2.5%: -0.693
10.0%: -0.530
25.0%: -0.186
50.0%: 0.447
75.0%: 0.926
90.0%: 1.133
97.5%: 1.298
None
In R, we have a plot() that produces a traceplot of model terms like the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions.
In Python, we provide a plot_parameter_trace() function for requesting a traceplot of a specific model parameter.
plot(bcf_model)
ax = plot_parameter_trace(bcf_model, term="global_error_scale")
plt.show()
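A traceplot like this can be complemented by an effective sample size estimate, which quantifies how much the draws are worth after accounting for autocorrelation. A crude numpy sketch, not a stochtree API, run on a simulated AR(1) stand-in for the real trace:

```python
import numpy as np

def effective_sample_size(x, max_lag=200):
    """Crude ESS: n / (1 + 2 * sum of leading positive autocorrelations)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    x = x - x.mean()
    denom = np.dot(x, x)
    rho_sum = 0.0
    for lag in range(1, min(max_lag, n - 1)):
        rho = np.dot(x[:-lag], x[lag:]) / denom
        if rho < 0.0:  # truncate at the first negative estimate
            break
        rho_sum += rho
    return n / (1.0 + 2.0 * rho_sum)

rng = np.random.default_rng(0)
# Stand-in trace: AR(1) series with moderate autocorrelation.
phi = 0.5
noise = rng.standard_normal(3000)
trace = np.empty(3000)
trace[0] = noise[0]
for t in range(1, 3000):
    trace[t] = phi * trace[t - 1] + noise[t]

print(f"ESS: {effective_sample_size(trace):.0f} of {trace.size} draws")
```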
For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to query the posterior distribution of any valid model term (e.g., global error scale \(\sigma^2\), prognostic forest leaf scale \(\sigma^2_{\mu}\), CATE forest leaf scale \(\sigma^2_{\tau}\), adaptive coding parameters \(b_0\) and \(b_1\) for binary treatment, in-sample mean function predictions y_hat_train, in-sample CATE function predictions tau_hat_train) and then plot any subset or transformation of these values.
adaptive_coding_samples <- extractParameter(bcf_model, "adaptive_coding")
plot(
adaptive_coding_samples[1, ],
type = "l",
main = "Adaptive Coding Parameter Traceplot",
xlab = "Index",
ylab = "Parameter Values",
ylim = range(adaptive_coding_samples),
col = "blue"
)
lines(adaptive_coding_samples[2, ], col = "orange")
legend(
"topright",
legend = c("Control", "Treated"),
lty = 1,
col = c("blue", "orange")
)
adaptive_coding_samples = bcf_model.extract_parameter("adaptive_coding")
fig, ax = plt.subplots()
ax.plot(adaptive_coding_samples[0, :], color="blue", label="Control")
ax.plot(adaptive_coding_samples[1, :], color="orange", label="Treated")
ax.set_title("Adaptive Coding Parameter Traceplot")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
ax.legend(loc="upper right")
plt.show()
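Extracted draws can also be combined into derived quantities. For instance, averaging the in-sample CATE draws over observations within each posterior draw yields a posterior for the average treatment effect. A sketch using a simulated stand-in for the tau_hat_train draws; the (n_obs, n_samples) shape is an assumption, mirroring the observation-by-draw layout used for y_hat_train above:

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_samples = 1000, 3000
# Stand-in for bcf_model.extract_parameter("tau_hat_train"):
# rows index observations, columns index posterior draws.
tau_hat_train_samples = rng.normal(loc=0.0, scale=0.5, size=(n_obs, n_samples))

# Average over observations within each draw -> posterior ATE draws.
ate_samples = tau_hat_train_samples.mean(axis=0)
ate_interval = np.quantile(ate_samples, [0.025, 0.975])

print(f"posterior ATE mean: {ate_samples.mean():.3f}")
print(f"95% credible interval: [{ate_interval[0]:.3f}, {ate_interval[1]:.3f}]")
```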