Posterior Summary and Visualization Utilities

This vignette demonstrates the summary and plotting utilities available for stochtree models.

Setup

Load necessary packages

library(stochtree)
import numpy as np
import matplotlib.pyplot as plt
from stochtree import BARTModel, BCFModel, plot_parameter_trace

Set a seed for reproducibility

random_seed = 1234
set.seed(random_seed)
random_seed = 1234
rng = np.random.default_rng(random_seed)

Supervised Learning

We begin with the supervised learning use case served by the bart() function in R and the BARTModel class in Python.

Below we simulate a simple regression dataset.

n <- 1000
p_x <- 10
p_w <- 1
X <- matrix(runif(n * p_x), ncol = p_x)
W <- matrix(runif(n * p_w), ncol = p_w)
f_XW <- (((0 <= X[, 10]) & (0.25 > X[, 10])) *
  (-7.5 * W[, 1]) +
  ((0.25 <= X[, 10]) & (0.5 > X[, 10])) * (-2.5 * W[, 1]) +
  ((0.5 <= X[, 10]) & (0.75 > X[, 10])) * (2.5 * W[, 1]) +
  ((0.75 <= X[, 10]) & (1 > X[, 10])) * (7.5 * W[, 1]))
noise_sd <- 1
y <- f_XW + rnorm(n, 0, 1) * noise_sd
n = 1000
p_x = 10
p_w = 1
X = rng.uniform(size=(n, p_x))
W = rng.uniform(size=(n, p_w))
# R uses X[,10] (1-indexed) = Python X[:,9]
f_XW = (
    ((X[:, 9] >= 0)    & (X[:, 9] < 0.25)) * (-7.5 * W[:, 0]) +
    ((X[:, 9] >= 0.25) & (X[:, 9] < 0.5))  * (-2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.5)  & (X[:, 9] < 0.75)) * ( 2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.75) & (X[:, 9] < 1.0))  * ( 7.5 * W[:, 0])
)
noise_sd = 1.0
y = f_XW + rng.standard_normal(n) * noise_sd
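
The mean function above is a step function in the tenth covariate: the slope on W takes one of four values depending on which quarter of [0, 1) that covariate falls in. As a rough sketch of the same idea (not part of stochtree), the four indicator terms can be replaced by a bin lookup. The names x10, w, slopes, and bin_index are illustrative stand-ins introduced here.

```python
import numpy as np

rng = np.random.default_rng(1234)
n = 1000
x10 = rng.uniform(size=n)  # stand-in for the tenth covariate, X[:, 9]
w = rng.uniform(size=n)    # stand-in for the single basis column, W[:, 0]

# Explicit form: one indicator term per quarter of [0, 1), as in the simulation
f_explicit = (
    ((x10 >= 0) & (x10 < 0.25)) * (-7.5 * w)
    + ((x10 >= 0.25) & (x10 < 0.5)) * (-2.5 * w)
    + ((x10 >= 0.5) & (x10 < 0.75)) * (2.5 * w)
    + ((x10 >= 0.75) & (x10 < 1.0)) * (7.5 * w)
)

# Lookup form: find each observation's bin (0 to 3), then index a slope table
slopes = np.array([-7.5, -2.5, 2.5, 7.5])
bin_index = np.digitize(x10, bins=[0.25, 0.5, 0.75])
f_lookup = slopes[bin_index] * w

# Both constructions give the same mean function
assert np.allclose(f_explicit, f_lookup)
```

Because the slope changes only with the covariate and the outcome is linear in W within each bin, this is a natural test case for a leaf regression prior with a single basis.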

Now we fit a simple BART model to the data.

num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
  num_threads = 1, 
  num_chains = 3
)
bart_model <- stochtree::bart(
  X_train = X,
  y_train = y,
  leaf_basis_train = W,
  num_gfr = num_gfr,
  num_burnin = num_burnin,
  num_mcmc = num_mcmc,
  general_params = general_params
)
bart_model = BARTModel()
bart_model.sample(
    X_train=X,
    y_train=y,
    leaf_basis_train=W,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
    general_params={
      "num_threads": 1, 
      "num_chains": 3
    },
)

We obtain a high-level summary of the BART model by running print().

print(bart_model)
stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
print(bart_model)
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

For a more detailed summary (including the information above), we use the summary() function in R or the summary() method in Python.

summary(bart_model)
stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
Summary of sigma^2 posterior: 
3000 samples, mean = 0.895, standard deviation = 0.047, quantiles:
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.8084663 0.8354853 0.8619096 0.8927078 0.9272022 0.9552168 0.9908970 
Summary of leaf scale posterior: 
3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
       2.5%         10%         25%         50%         75%         90% 
0.005022884 0.005492911 0.005980975 0.006564506 0.007374873 0.008304079 
      97.5% 
0.009749566 
Summary of in-sample posterior mean predictions: 
1000 observations, mean = -0.063, standard deviation = 3.284, quantiles:
      2.5%        10%        25%        50%        75%        90%      97.5% 
-6.8439105 -4.9413282 -1.8213713 -0.1199945  2.0077018  4.2607277  6.4900123 
print(bart_model.summary())
BART Model Summary:
-------------------
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

Summary of sigma^2 posterior: 3000 samples, mean = 0.938, standard deviation = 0.049, quantiles:
    2.5%: 0.847
   10.0%: 0.877
   25.0%: 0.905
   50.0%: 0.937
   75.0%: 0.970
   90.0%: 1.002
   97.5%: 1.038
Summary of leaf scale posterior: 3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
    2.5%: 0.005
   10.0%: 0.005
   25.0%: 0.006
   50.0%: 0.007
   75.0%: 0.008
   90.0%: 0.008
   97.5%: 0.009
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 0.106, standard deviation = 3.286, quantiles:
    2.5%: -6.741
   10.0%: -4.345
   25.0%: -1.943
   50.0%: 0.169
   75.0%: 2.113
   90.0%: 4.431
   97.5%: 6.589

None
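
Each block in the summary reports a sample count, a mean, a standard deviation, and quantiles at 2.5%, 10%, 25%, 50%, 75%, 90%, and 97.5%. The sketch below shows one way to compute the same kind of summary from a plain array of posterior draws with NumPy. The helper summarize_draws is hypothetical, the draws are synthetic stand-ins rather than stochtree output, and the use of the sample standard deviation (ddof=1) is an assumption about how the package computes it.

```python
import numpy as np

def summarize_draws(draws, probs=(2.5, 10, 25, 50, 75, 90, 97.5)):
    """Return the count, mean, standard deviation, and percentiles of 1-D draws."""
    draws = np.asarray(draws, dtype=float)
    return {
        "n": draws.size,
        "mean": draws.mean(),
        "sd": draws.std(ddof=1),  # ASSUMPTION: sample (n - 1) standard deviation
        "quantiles": dict(zip(probs, np.percentile(draws, probs))),
    }

# Synthetic stand-in for 3000 posterior draws of sigma^2 (not real model output)
rng = np.random.default_rng(1234)
fake_sigma2_draws = rng.gamma(shape=400.0, scale=0.9 / 400.0, size=3000)

summary = summarize_draws(fake_sigma2_draws)
print(f"{summary['n']} samples, mean = {summary['mean']:.3f}, "
      f"standard deviation = {summary['sd']:.3f}, quantiles:")
for prob, value in summary["quantiles"].items():
    print(f"{prob:>7.1f}%: {value:.3f}")
```

The same helper could be applied to any one-dimensional array returned by extract_parameter(), or to one row of a two-dimensional array of draws.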

In R, the plot() function produces a traceplot of model terms such as the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions. In Python, the plot_parameter_trace() function produces a traceplot of a model term requested by name.

plot(bart_model)

ax = plot_parameter_trace(bart_model, term="global_error_scale")
plt.show()

For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to pull the posterior distribution of any valid model term (e.g., global error scale \(\sigma^2\), leaf scale \(\sigma^2_{\ell}\), in-sample mean function predictions y_hat_train) and then plot any subset or transformation of these values.

y_hat_train_samples <- extractParameter(bart_model, "y_hat_train")
obs_index <- 1
plot(
  y_hat_train_samples[obs_index, ],
  type = "l",
  main = paste0("In-Sample Predictions Traceplot, Observation ", obs_index),
  xlab = "Index",
  ylab = "Parameter Values"
)

y_hat_train_samples = bart_model.extract_parameter("y_hat_train")
obs_index = 0
fig, ax = plt.subplots()
ax.plot(y_hat_train_samples[obs_index, :])
ax.set_title(f"In-Sample Predictions Traceplot, Observation {obs_index}")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
plt.show()
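
With 3 chains of 1000 retained MCMC iterations, each parameter has 3 × 1000 = 3000 draws, which matches the sample counts in the summaries above. A traceplot of all 3000 draws therefore concatenates the chains. The sketch below splits a vector of draws back into chains and compares per-chain means and standard deviations. It assumes the draws are stored chain after chain (all of chain 1, then chain 2, and so on). That ordering is an assumption, not something this vignette confirms, and the draws here are synthetic.

```python
import numpy as np

num_chains, num_mcmc = 3, 1000

# Synthetic stand-in for one row of y_hat_train_samples (not real model output)
rng = np.random.default_rng(1234)
draws = rng.normal(loc=2.0, scale=0.3, size=num_chains * num_mcmc)

# ASSUMPTION: draws are laid out chain after chain, so a row-major reshape
# puts each chain in its own row
by_chain = draws.reshape(num_chains, num_mcmc)

chain_means = by_chain.mean(axis=1)
chain_sds = by_chain.std(axis=1, ddof=1)
for c in range(num_chains):
    print(f"chain {c}: mean = {chain_means[c]:.3f}, sd = {chain_sds[c]:.3f}")
```

Under the same assumption, each row of by_chain could be passed to ax.plot() separately to overlay the chains in one traceplot. Chains that agree on location and spread suggest the sampler is mixing well.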

Causal Inference

We now run the same demo for the causal inference use case served by the bcf() function in R and the BCFModel Python class.

Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome.

# Generate covariates and treatment
n <- 1000
p_X <- 5
X <- matrix(runif(n * p_X), ncol = p_X)
pi_X <- 0.25 + 0.5 * X[, 1]
Z <- rbinom(n, 1, pi_X)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X <- pi_X * 5 + 2 * X[, 3]
tau_X <- X[, 2] * 2 - 1

# Generate outcome
epsilon <- rnorm(n, 0, 1)
y <- mu_X + tau_X * Z + epsilon
# Generate covariates and treatment
n = 1000
p_X = 5
X = rng.uniform(size=(n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1

# Generate outcome
epsilon = rng.standard_normal(n)
y = mu_X + tau_X * Z + epsilon
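
In this design the treatment effect tau_X = 2 * X2 - 1 averages to zero over uniform covariates, while the prognostic function mu_X depends on the propensity pi_X, so treated units tend to have higher baseline outcomes. The following sketch, which uses a larger n than the vignette purely to reduce Monte Carlo noise, illustrates why a naive difference in group means is a poor estimate of the average treatment effect here. The names true_ate and naive_diff are introduced for illustration only.

```python
import numpy as np

rng = np.random.default_rng(1234)
n = 100_000  # larger than the vignette's n = 1000, to reduce Monte Carlo noise

# Same data-generating process as the simulation above
X = rng.uniform(size=(n, 5))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1
y = mu_X + tau_X * Z + rng.standard_normal(n)

# Sample average of the true treatment effect (close to zero by construction)
true_ate = tau_X.mean()

# Naive comparison of treated and control means, which ignores confounding
naive_diff = y[Z == 1].mean() - y[Z == 0].mean()

print(f"true ATE = {true_ate:.3f}, naive difference in means = {naive_diff:.3f}")
```

The gap between the two numbers comes from mu_X: treated units have a higher average propensity than controls, and mu_X increases with pi_X. BCF is designed to handle this kind of confounding by including the propensity score as a covariate.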

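Now we fit a simple BCF model to the data. The Python call passes the true propensity scores pi_X through propensity_train, while the R call omits them, so bcf() fits an internal propensity model instead. The printed summaries report this difference.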

num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
  num_threads = 1, 
  num_chains = 3, 
  adaptive_coding = TRUE
)
bcf_model <- stochtree::bcf(
  X_train = X,
  y_train = y,
  Z_train = Z,
  num_gfr = num_gfr,
  num_burnin = num_burnin,
  num_mcmc = num_mcmc,
  general_params = general_params
)
bcf_model = BCFModel()
bcf_model.sample(
    X_train=X,
    Z_train=Z,
    y_train=y,
    propensity_train=pi_X,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
    general_params={
      "num_threads": 1, 
      "num_chains": 3, 
      "adaptive_coding": True
    },
)

We obtain a high-level summary of the BCF model by running print().

print(bcf_model)
stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
print(bcf_model)
BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

For a more detailed summary (including the information above), we use the summary() function in R or the summary() method in Python.

summary(bcf_model)
stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
Summary of sigma^2 posterior: 
3000 samples, mean = 0.948, standard deviation = 0.046, quantiles:
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.8612806 0.8911718 0.9157341 0.9462258 0.9779726 1.0074856 1.0385681 
Summary of prognostic forest leaf scale posterior: 
3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
       2.5%         10%         25%         50%         75%         90% 
0.000908978 0.001080862 0.001262013 0.001498533 0.001784155 0.002043084 
      97.5% 
0.002475996 
Summary of adaptive coding parameters: 
3000 samples, mean (control) = -0.454, mean (treated) = 1.002, standard deviation (control) = 0.349, standard deviation (treated) = 0.354
quantiles (control):
       2.5%         10%         25%         50%         75%         90% 
-1.19410739 -0.90348454 -0.68372289 -0.44269060 -0.21349489 -0.02423392 
      97.5% 
 0.21857728 
quantiles (treated):
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.3798293 0.5736983 0.7543187 0.9740120 1.2132759 1.4570656 1.8109981 
Summary of treatment effect intercept (tau_0) posterior: 
3000 samples, mean = 0.136, standard deviation = 0.782, quantiles:
      2.5%        10%        25%        50%        75%        90%      97.5% 
-1.3717617 -1.0891631 -0.3088470  0.1693315  0.7446541  1.0738739  1.5073765 
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 3.487, standard deviation = 0.945, quantiles:
    2.5%      10%      25%      50%      75%      90%    97.5% 
1.666454 2.278896 2.818287 3.483078 4.138062 4.746473 5.434899 
Summary of in-sample posterior mean CATEs: 
1000 observations, mean = 0.064, standard deviation = 0.520, quantiles:
      2.5%        10%        25%        50%        75%        90%      97.5% 
-0.7253229 -0.5625624 -0.3996405 -0.0266764  0.5708963  0.7679216  0.8995149 
print(bcf_model.summary())
BCF Model Summary:
------------------
BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

Summary of sigma^2 posterior: 3000 samples, mean = 0.873, standard deviation = 0.042, quantiles:
    2.5%: 0.795
   10.0%: 0.820
   25.0%: 0.844
   50.0%: 0.870
   75.0%: 0.899
   90.0%: 0.928
   97.5%: 0.959
Summary of prognostic forest leaf scale posterior: 3000 samples, mean = 0.002, standard deviation = 0.001, quantiles:
    2.5%: 0.001
   10.0%: 0.001
   25.0%: 0.001
   50.0%: 0.002
   75.0%: 0.002
   90.0%: 0.003
   97.5%: 0.003
Summary of adaptive coding parameters: 
3000 samples, mean (control) = -0.508, mean (treated) = 1.139, standard deviation (control) = 0.337, standard deviation (treated) = 0.328
quantiles (control):
    2.5%: -1.226
   10.0%: -0.970
   25.0%: -0.723
   50.0%: -0.477
   75.0%: -0.269
   90.0%: -0.096
   97.5%: 0.067

quantiles (treated):
    2.5%: 0.560
   10.0%: 0.724
   25.0%: 0.914
   50.0%: 1.133
   75.0%: 1.346
   90.0%: 1.558
   97.5%: 1.846
Summary of treatment effect intercept (tau_0) posterior: 3000 samples, mean = -0.369, standard deviation = 0.558, quantiles:
    2.5%: -1.458
   10.0%: -1.140
   25.0%: -0.811
   50.0%: -0.312
   75.0%: 0.077
   90.0%: 0.308
   97.5%: 0.631
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 3.482, standard deviation = 0.962, quantiles:
    2.5%: 1.849
   10.0%: 2.200
   25.0%: 2.781
   50.0%: 3.476
   75.0%: 4.166
   90.0%: 4.760
   97.5%: 5.331
Summary of in-sample posterior mean CATEs: 
1000 observations, mean = -0.030, standard deviation = 0.623, quantiles:
    2.5%: -1.078
   10.0%: -0.920
   25.0%: -0.578
   50.0%: 0.048
   75.0%: 0.539
   90.0%: 0.737
   97.5%: 0.915

None

In R, the plot() function produces a traceplot of model terms such as the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions.

In Python, we provide a plot_parameter_trace() function for requesting a traceplot of a specific model parameter.

plot(bcf_model)

ax = plot_parameter_trace(bcf_model, term="global_error_scale")
plt.show()

For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to query the posterior distribution of any valid model term (e.g., global error scale \(\sigma^2\), prognostic forest leaf scale \(\sigma^2_{\mu}\), CATE forest leaf scale \(\sigma^2_{\tau}\), adaptive coding parameters \(b_0\) and \(b_1\) for binary treatment, in-sample mean function predictions y_hat_train, in-sample CATE function predictions tau_hat_train) and then plot any subset or transformation of these values.

adaptive_coding_samples <- extractParameter(bcf_model, "adaptive_coding")
plot(
  adaptive_coding_samples[1, ],
  type = "l",
  main = "Adaptive Coding Parameter Traceplot",
  xlab = "Index",
  ylab = "Parameter Values",
  ylim = range(adaptive_coding_samples),
  col = "blue"
)
lines(adaptive_coding_samples[2, ], col = "orange")
legend(
  "topright",
  legend = c("Control", "Treated"),
  lty = 1,
  col = c("blue", "orange")
)

adaptive_coding_samples = bcf_model.extract_parameter("adaptive_coding")
fig, ax = plt.subplots()
ax.plot(adaptive_coding_samples[0, :], color="blue", label="Control")
ax.plot(adaptive_coding_samples[1, :], color="orange", label="Treated")
ax.set_title("Adaptive Coding Parameter Traceplot")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
ax.legend(loc="upper right")
plt.show()
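
As an example of a transformation of extracted values, the sketch below computes the per-draw difference between the treated and control coding parameters, \(b_1 - b_0\), and summarizes it with a posterior median and a 95% interval. The array is a synthetic stand-in with the same layout as the plotting code above (row 0 control, row 1 treated), not real model output. How stochtree combines \(b_0\), \(b_1\), and the treatment effect intercept into the final CATE is not covered here, so the difference should be read only as an illustrative derived quantity.

```python
import numpy as np

rng = np.random.default_rng(1234)
num_samples = 3000

# Synthetic stand-in for adaptive_coding_samples, shape (2, num_samples)
fake_coding_draws = np.vstack([
    rng.normal(-0.5, 0.35, num_samples),  # row 0: control parameter b_0
    rng.normal(1.0, 0.35, num_samples),   # row 1: treated parameter b_1
])

# Per-draw difference between the treated and control coding parameters
coding_gap = fake_coding_draws[1, :] - fake_coding_draws[0, :]

# Posterior median and central 95% interval of the difference
lower, median, upper = np.percentile(coding_gap, [2.5, 50, 97.5])
print(f"b_1 - b_0: median = {median:.3f}, 95% interval = ({lower:.3f}, {upper:.3f})")
```

With a fitted model, replacing fake_coding_draws with the array returned by bcf_model.extract_parameter("adaptive_coding") would apply the same transformation to real posterior draws. The resulting vector can be plotted as a traceplot in the same way as the individual parameters.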