Posterior Summary and Visualization Utilities

This vignette demonstrates the summary and plotting utilities available for stochtree models.

Setup

Load necessary packages

library(stochtree)
import numpy as np
import matplotlib.pyplot as plt
from stochtree import BARTModel, BCFModel, plot_parameter_trace

Set a seed for reproducibility

random_seed <- 1234
set.seed(random_seed)
random_seed = 1234
rng = np.random.default_rng(random_seed)

Supervised Learning

We begin with the supervised learning use case served by the bart() function in R and the BARTModel class in Python.

Below we simulate a simple regression dataset.

n <- 1000
p_x <- 10
p_w <- 1
X <- matrix(runif(n * p_x), ncol = p_x)
W <- matrix(runif(n * p_w), ncol = p_w)
f_XW <- (
  ((X[, 10] >= 0)    & (X[, 10] < 0.25)) * (-7.5 * W[, 1]) +
  ((X[, 10] >= 0.25) & (X[, 10] < 0.5))  * (-2.5 * W[, 1]) +
  ((X[, 10] >= 0.5)  & (X[, 10] < 0.75)) * ( 2.5 * W[, 1]) +
  ((X[, 10] >= 0.75) & (X[, 10] < 1))    * ( 7.5 * W[, 1])
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, 1) * noise_sd
n = 1000
p_x = 10
p_w = 1
X = rng.uniform(size=(n, p_x))
W = rng.uniform(size=(n, p_w))
# R uses X[,10] (1-indexed) = Python X[:,9]
f_XW = (
    ((X[:, 9] >= 0)    & (X[:, 9] < 0.25)) * (-7.5 * W[:, 0]) +
    ((X[:, 9] >= 0.25) & (X[:, 9] < 0.5))  * (-2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.5)  & (X[:, 9] < 0.75)) * ( 2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.75) & (X[:, 9] < 1.0))  * ( 7.5 * W[:, 0])
)
noise_sd = 1.0
y = f_XW + rng.standard_normal(n) * noise_sd
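
As a quick, optional sanity check on the simulated data, we can compare the variance of the signal f_XW to the noise variance. The sketch below uses only the variables defined above and plain numpy.

# Optional sanity check: how strong is the signal relative to the noise?
signal_var = np.var(f_XW)
print(f"Signal variance: {signal_var:.2f}, noise variance: {noise_sd**2:.2f}")
print(f"Approximate R^2 of the true function: {signal_var / (signal_var + noise_sd**2):.2f}")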

Now we fit a simple BART model to the data.

num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
  num_threads = 1, 
  num_chains = 3
)
bart_model <- stochtree::bart(
  X_train = X,
  y_train = y,
  leaf_basis_train = W,
  num_gfr = num_gfr,
  num_burnin = num_burnin,
  num_mcmc = num_mcmc,
  general_params = general_params
)
bart_model = BARTModel()
bart_model.sample(
    X_train=X,
    y_train=y,
    leaf_basis_train=W,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
    general_params={
      "num_threads": 1, 
      "num_chains": 3
    },
)
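
Note that, consistent with the "3000 samples" reported in the summaries below, the number of retained posterior draws is num_chains * num_mcmc, which suggests the GFR and burn-in iterations are not included in that count. A quick check of the arithmetic:

# 3 chains, each retaining 1000 MCMC iterations, yields 3000 posterior draws
num_chains = 3
num_mcmc = 1000
print(num_chains * num_mcmc)  # 3000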

We obtain a high-level summary of the BART model by running print().

print(bart_model)
stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
print(bart_model)
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

For a more detailed summary (including the information above), we use the summary() function in R or the summary() method in Python.

summary(bart_model)
stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
Summary of sigma^2 posterior: 
3000 samples, mean = 0.895, standard deviation = 0.049, quantiles:
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.8021322 0.8345788 0.8608693 0.8925349 0.9267356 0.9582750 0.9957126 
Summary of leaf scale posterior: 
3000 samples, mean = 0.006, standard deviation = 0.001, quantiles:
       2.5%         10%         25%         50%         75%         90% 
0.004581890 0.005038048 0.005544240 0.006148174 0.006933297 0.008029237 
      97.5% 
0.010108081 
Summary of in-sample posterior mean predictions: 
1000 observations, mean = -0.064, standard deviation = 3.284, quantiles:
     2.5%       10%       25%       50%       75%       90%     97.5% 
-6.864170 -4.917854 -1.813222 -0.121295  2.020283  4.284544  6.484986 
bart_model.summary()
BART Model Summary:
-------------------
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

Summary of sigma^2 posterior: 3000 samples, mean = 0.937, standard deviation = 0.049, quantiles:
    2.5%: 0.843
   10.0%: 0.877
   25.0%: 0.904
   50.0%: 0.935
   75.0%: 0.970
   90.0%: 1.000
   97.5%: 1.037
Summary of leaf scale posterior: 3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
    2.5%: 0.005
   10.0%: 0.006
   25.0%: 0.006
   50.0%: 0.007
   75.0%: 0.008
   90.0%: 0.009
   97.5%: 0.010
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 0.107, standard deviation = 3.287, quantiles:
    2.5%: -6.737
   10.0%: -4.301
   25.0%: -1.934
   50.0%: 0.170
   75.0%: 2.113
   90.0%: 4.387
   97.5%: 6.588

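The prediction summary above can also be reproduced by hand. The sketch below uses the extract_parameter() method (demonstrated further below) and assumes, as the later indexing in this vignette suggests, that the returned array stores observations in rows and posterior draws in columns.

# Reproduce the "in-sample posterior mean predictions" summary manually:
# average over posterior draws for each observation, then summarize across observations
y_hat_train_samples = bart_model.extract_parameter("y_hat_train")
posterior_mean_preds = y_hat_train_samples.mean(axis=1)
print(f"{posterior_mean_preds.shape[0]} observations, "
      f"mean = {posterior_mean_preds.mean():.3f}, "
      f"standard deviation = {posterior_mean_preds.std():.3f}")
print(np.quantile(posterior_mean_preds, [0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975]))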

In R, the plot() function produces a traceplot of model terms such as the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions. In Python, the plot_parameter_trace() function produces a traceplot of a requested model term.

plot(bart_model)

ax = plot_parameter_trace(bart_model, term="global_error_scale")
plt.show()
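
Since the Python call above assigns its return value to ax, we assume plot_parameter_trace() returns a matplotlib Axes, which can be customized with standard matplotlib calls before display. A minimal sketch:

# Customize the traceplot returned by plot_parameter_trace() before displaying it
ax = plot_parameter_trace(bart_model, term="global_error_scale")
ax.set_title("Global Error Scale Traceplot")
# Reference line at the simulated noise variance (the summaries above suggest
# sigma^2 is reported on the original outcome scale)
ax.axhline(noise_sd**2, color="red", linestyle="--", label="True noise variance")
ax.legend()
plt.show()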

For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to pull the posterior distribution of any valid model term (e.g., the global error scale \(\sigma^2\), the leaf scale \(\sigma^2_{\ell}\), or the in-sample mean function predictions y_hat_train) and then plot any subset or transformation of these values.

y_hat_train_samples <- extractParameter(bart_model, "y_hat_train")
obs_index <- 1
plot(
  y_hat_train_samples[obs_index, ],
  type = "l",
  main = paste0("In-Sample Predictions Traceplot, Observation ", obs_index),
  xlab = "Index",
  ylab = "Parameter Values"
)

y_hat_train_samples = bart_model.extract_parameter("y_hat_train")
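# Python indexing is 0-based, so observation 0 here corresponds to observation 1 in the R example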
obs_index = 0
fig, ax = plt.subplots()
ax.plot(y_hat_train_samples[obs_index, :])
ax.set_title(f"In-Sample Predictions Traceplot, Observation {obs_index}")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
plt.show()
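
Because the extracted object is just an array of posterior draws, any transformation can be plotted. The sketch below is plain numpy/matplotlib (no additional stochtree functionality) and overlays a running posterior mean on the traceplot for the same observation.

# Overlay a running (cumulative) posterior mean on the draws for one observation
draws = y_hat_train_samples[obs_index, :]
running_mean = np.cumsum(draws) / np.arange(1, draws.shape[0] + 1)
fig, ax = plt.subplots()
ax.plot(draws, alpha=0.5, label="Posterior draws")
ax.plot(running_mean, color="black", label="Running posterior mean")
ax.set_title(f"In-Sample Predictions, Observation {obs_index}")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
ax.legend(loc="upper right")
plt.show()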

Causal Inference

We now run the same demo for the causal inference use case served by the bcf() function in R and the BCFModel Python class.

Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome.

# Generate covariates and treatment
n <- 1000
p_X <- 5
X <- matrix(runif(n * p_X), ncol = p_X)
pi_X <- 0.25 + 0.5 * X[, 1]
Z <- rbinom(n, 1, pi_X)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X <- pi_X * 5 + 2 * X[, 3]
tau_X <- X[, 2] * 2 - 1

# Generate outcome
epsilon <- rnorm(n, 0, 1)
y <- mu_X + tau_X * Z + epsilon
# Generate covariates and treatment
n = 1000
p_X = 5
X = rng.uniform(size=(n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1

# Generate outcome
epsilon = rng.standard_normal(n)
y = mu_X + tau_X * Z + epsilon
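
Because the treatment effect function tau_X is known in this simulation, the sample average treatment effect can be computed directly and used as a reference point for the CATE summaries below. A short check with plain numpy:

# Average treatment effect implied by the simulation
# (tau_X = 2 * X[:, 1] - 1 with X[:, 1] ~ Uniform(0, 1), so the population ATE is 0)
print(f"Sample ATE in the simulated data: {np.mean(tau_X):.3f}")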

Now we fit a simple BCF model to the data.

num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
  num_threads = 1, 
  num_chains = 3, 
  adaptive_coding = TRUE
)
bcf_model <- stochtree::bcf(
  X_train = X,
  y_train = y,
  Z_train = Z,
  num_gfr = num_gfr,
  num_burnin = num_burnin,
  num_mcmc = num_mcmc,
  general_params = general_params
)
bcf_model = BCFModel()
bcf_model.sample(
    X_train=X,
    Z_train=Z,
    y_train=y,
    propensity_train=pi_X,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
    general_params={
      "num_threads": 1, 
      "num_chains": 3, 
      "adaptive_coding": True
    },
)

We obtain a high-level summary of the BCF model by running print().

print(bcf_model)
stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
print(bcf_model)
BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

For a more detailed summary (including the information above), we use the summary() function in R or the summary() method in Python.

summary(bcf_model)
stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
Summary of sigma^2 posterior: 
3000 samples, mean = 0.950, standard deviation = 0.046, quantiles:
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.8659101 0.8930222 0.9190229 0.9485672 0.9805341 1.0113277 1.0445292 
Summary of prognostic forest leaf scale posterior: 
3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
       2.5%         10%         25%         50%         75%         90% 
0.001002187 0.001177501 0.001341383 0.001549201 0.001822880 0.002086417 
      97.5% 
0.002437231 
Summary of adaptive coding parameters: 
3000 samples, mean (control) = -0.412, mean (treated) = 1.036, standard deviation (control) = 0.336, standard deviation (treated) = 0.331
quantiles (control):
         2.5%           10%           25%           50%           75% 
-1.1668062964 -0.8387171955 -0.6159133938 -0.3932125884 -0.1823605100 
          90%         97.5% 
 0.0008645181  0.2050807144 
quantiles (treated):
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.4376757 0.6368674 0.8028234 1.0146934 1.2417872 1.4655859 1.7583584 
Summary of treatment effect intercept (tau_0) posterior: 
3000 samples, mean = 0.127, standard deviation = 0.814, quantiles:
       2.5%         10%         25%         50%         75%         90% 
-1.42699209 -0.92715753 -0.41169882  0.05445136  0.81308006  1.14882174 
      97.5% 
 1.65344491 
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 3.487, standard deviation = 0.944, quantiles:
    2.5%      10%      25%      50%      75%      90%    97.5% 
1.647473 2.280257 2.834242 3.488781 4.132127 4.715860 5.415987 
Summary of in-sample posterior mean CATEs: 
1000 observations, mean = -0.107, standard deviation = 0.516, quantiles:
      2.5%        10%        25%        50%        75%        90%      97.5% 
-0.9043420 -0.7347875 -0.5639919 -0.1941318  0.3966237  0.5890300  0.7144146 
bcf_model.summary()
BCF Model Summary:
------------------
BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

Summary of sigma^2 posterior: 3000 samples, mean = 0.872, standard deviation = 0.042, quantiles:
    2.5%: 0.792
   10.0%: 0.818
   25.0%: 0.844
   50.0%: 0.870
   75.0%: 0.899
   90.0%: 0.927
   97.5%: 0.959
Summary of prognostic forest leaf scale posterior: 3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
    2.5%: 0.001
   10.0%: 0.001
   25.0%: 0.001
   50.0%: 0.002
   75.0%: 0.002
   90.0%: 0.002
   97.5%: 0.003
Summary of adaptive coding parameters: 
3000 samples, mean (control) = -0.566, mean (treated) = 1.137, standard deviation (control) = 0.323, standard deviation (treated) = 0.371
quantiles (control):
    2.5%: -1.260
   10.0%: -0.974
   25.0%: -0.775
   50.0%: -0.550
   75.0%: -0.350
   90.0%: -0.175
   97.5%: 0.037

quantiles (treated):
    2.5%: 0.430
   10.0%: 0.690
   25.0%: 0.884
   50.0%: 1.119
   75.0%: 1.371
   90.0%: 1.619
   97.5%: 1.900
Summary of treatment effect intercept (tau_0) posterior: 3000 samples, mean = -0.472, standard deviation = 0.976, quantiles:
    2.5%: -2.301
   10.0%: -2.033
   25.0%: -1.446
   50.0%: -0.117
   75.0%: 0.225
   90.0%: 0.591
   97.5%: 0.995
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 3.481, standard deviation = 0.962, quantiles:
    2.5%: 1.833
   10.0%: 2.210
   25.0%: 2.779
   50.0%: 3.472
   75.0%: 4.162
   90.0%: 4.775
   97.5%: 5.335
Summary of in-sample posterior mean CATEs: 
1000 observations, mean = 0.787, standard deviation = 0.626, quantiles:
    2.5%: -0.260
   10.0%: -0.111
   25.0%: 0.240
   50.0%: 0.883
   75.0%: 1.359
   90.0%: 1.554
   97.5%: 1.722


In R, the plot() function produces a traceplot of model terms such as the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions.

In Python, we provide a plot_parameter_trace() function for requesting a traceplot of a specific model parameter.

plot(bcf_model)

ax = plot_parameter_trace(bcf_model, term="global_error_scale")
plt.show()

For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to query the posterior distribution of any valid model term and then plot any subset or transformation of these values. Valid terms include the global error scale \(\sigma^2\), the prognostic forest leaf scale \(\sigma^2_{\mu}\), the CATE forest leaf scale \(\sigma^2_{\tau}\), the adaptive coding parameters \(b_0\) and \(b_1\) for a binary treatment, the in-sample mean function predictions y_hat_train, and the in-sample CATE function predictions tau_hat_train.

adaptive_coding_samples <- extractParameter(bcf_model, "adaptive_coding")
plot(
  adaptive_coding_samples[1, ],
  type = "l",
  main = "Adaptive Coding Parameter Traceplot",
  xlab = "Index",
  ylab = "Parameter Values",
  ylim = range(adaptive_coding_samples),
  col = "blue"
)
lines(adaptive_coding_samples[2, ], col = "orange")
legend(
  "topright",
  legend = c("Control", "Treated"),
  lty = 1,
  col = c("blue", "orange")
)

adaptive_coding_samples = bcf_model.extract_parameter("adaptive_coding")
fig, ax = plt.subplots()
ax.plot(adaptive_coding_samples[0, :], color="blue", label="Control")
ax.plot(adaptive_coding_samples[1, :], color="orange", label="Treated")
ax.set_title("Adaptive Coding Parameter Traceplot")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
ax.legend(loc="upper right")
plt.show()
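
As a final example, the posterior mean CATEs can be extracted and compared to the true treatment effects from the simulation. The sketch below assumes the tau_hat_train term listed above is available via extract_parameter() and, like y_hat_train, stores observations in rows and posterior draws in columns.

# Compare posterior mean CATEs to the true treatment effects from the simulation
# (assumes tau_hat_train is a valid term with observations in rows, draws in columns)
tau_hat_train_samples = bcf_model.extract_parameter("tau_hat_train")
tau_hat_posterior_mean = tau_hat_train_samples.mean(axis=1)
fig, ax = plt.subplots()
ax.scatter(tau_X, tau_hat_posterior_mean, alpha=0.3)
ax.axline((0, 0), slope=1, color="black", linestyle="--")
ax.set_title("Posterior Mean CATEs vs. True Treatment Effects")
ax.set_xlabel("True tau(X)")
ax.set_ylabel("Posterior mean CATE")
plt.show()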