Posterior Summary and Visualization Utilities

This vignette demonstrates the summary and plotting utilities available for stochtree models.

Setup

Load necessary packages

library(stochtree)
import numpy as np
import matplotlib.pyplot as plt
from stochtree import BARTModel, BCFModel, plot_parameter_trace

Set a seed for reproducibility

random_seed <- 1234
set.seed(random_seed)
random_seed = 1234
rng = np.random.default_rng(random_seed)

Supervised Learning

We begin with the supervised learning use case served by the bart() function in R and the BARTModel class in Python.

Below we simulate a simple regression dataset.

n <- 1000
p_x <- 10
p_w <- 1
X <- matrix(runif(n * p_x), ncol = p_x)
W <- matrix(runif(n * p_w), ncol = p_w)
f_XW <- (((0 <= X[, 10]) & (0.25 > X[, 10])) *
  (-7.5 * W[, 1]) +
  ((0.25 <= X[, 10]) & (0.5 > X[, 10])) * (-2.5 * W[, 1]) +
  ((0.5 <= X[, 10]) & (0.75 > X[, 10])) * (2.5 * W[, 1]) +
  ((0.75 <= X[, 10]) & (1 > X[, 10])) * (7.5 * W[, 1]))
noise_sd <- 1
y <- f_XW + rnorm(n, 0, 1) * noise_sd
n = 1000
p_x = 10
p_w = 1
X = rng.uniform(size=(n, p_x))
W = rng.uniform(size=(n, p_w))
# R uses X[,10] (1-indexed) = Python X[:,9]
f_XW = (
    ((X[:, 9] >= 0)    & (X[:, 9] < 0.25)) * (-7.5 * W[:, 0]) +
    ((X[:, 9] >= 0.25) & (X[:, 9] < 0.5))  * (-2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.5)  & (X[:, 9] < 0.75)) * ( 2.5 * W[:, 0]) +
    ((X[:, 9] >= 0.75) & (X[:, 9] < 1.0))  * ( 7.5 * W[:, 0])
)
noise_sd = 1.0
y = f_XW + rng.standard_normal(n) * noise_sd
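Because the step-function coefficients (ranging from -7.5 to 7.5) are large relative to the unit noise, the true mean function dominates the outcome variance. As a quick sanity check on the simulated signal-to-noise ratio, the DGP above can be replicated and its empirical R^2 computed (a sketch using only numpy; it does not require a fitted model):

```python
import numpy as np

# Replicate the simulated regression DGP to inspect its signal-to-noise ratio
rng = np.random.default_rng(1234)
n, p_x, p_w = 1000, 10, 1
X = rng.uniform(size=(n, p_x))
W = rng.uniform(size=(n, p_w))
f_XW = (
    ((X[:, 9] >= 0)    & (X[:, 9] < 0.25)) * (-7.5 * W[:, 0])
    + ((X[:, 9] >= 0.25) & (X[:, 9] < 0.5))  * (-2.5 * W[:, 0])
    + ((X[:, 9] >= 0.5)  & (X[:, 9] < 0.75)) * ( 2.5 * W[:, 0])
    + ((X[:, 9] >= 0.75) & (X[:, 9] < 1.0))  * ( 7.5 * W[:, 0])
)
noise_sd = 1.0
y = f_XW + rng.standard_normal(n) * noise_sd

# Fraction of outcome variance explained by the true mean function
r_squared = np.var(f_XW) / np.var(y)
print(f"Empirical R^2 of the true mean function: {r_squared:.3f}")
```

The theoretical value is roughly 0.91, so the mean forest has a strong signal to recover.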

Now we fit a simple BART model to the data.

num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
  num_threads = 1, 
  num_chains = 3
)
bart_model <- stochtree::bart(
  X_train = X,
  y_train = y,
  leaf_basis_train = W,
  num_gfr = num_gfr,
  num_burnin = num_burnin,
  num_mcmc = num_mcmc,
  general_params = general_params
)
bart_model = BARTModel()
bart_model.sample(
    X_train=X,
    y_train=y,
    leaf_basis_train=W,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
    general_params={
      "num_threads": 1, 
      "num_chains": 3
    },
)

We obtain a high-level summary of the BART model by running print().

print(bart_model)
stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
print(bart_model)
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

For a more detailed summary (including the information above), we use the summary() function in R or the summary() method in Python.

summary(bart_model)
stochtree::bart() run with mean forest, global error variance model, and mean forest leaf scale model
Continuous outcome was modeled as Gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
Summary of sigma^2 posterior: 
3000 samples, mean = 0.900, standard deviation = 0.047, quantiles:
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.8107628 0.8407881 0.8675898 0.8980279 0.9307229 0.9611398 0.9942541 
Summary of leaf scale posterior: 
3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
       2.5%         10%         25%         50%         75%         90% 
0.004485701 0.005126561 0.005641411 0.006285762 0.007238720 0.008307585 
      97.5% 
0.010510942 
Summary of in-sample posterior mean predictions: 
1000 observations, mean = -0.063, standard deviation = 3.283, quantiles:
      2.5%        10%        25%        50%        75%        90%      97.5% 
-6.8552153 -4.9346264 -1.8170496 -0.1200449  2.0538402  4.2706207  6.4520534 
print(bart_model.summary())
BART Model Summary:
-------------------
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

Summary of sigma^2 posterior: 3000 samples, mean = 0.937, standard deviation = 0.047, quantiles:
    2.5%: 0.849
   10.0%: 0.879
   25.0%: 0.906
   50.0%: 0.935
   75.0%: 0.968
   90.0%: 0.998
   97.5%: 1.031
Summary of leaf scale posterior: 3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
    2.5%: 0.005
   10.0%: 0.006
   25.0%: 0.006
   50.0%: 0.007
   75.0%: 0.008
   90.0%: 0.009
   97.5%: 0.010
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 0.107, standard deviation = 3.286, quantiles:
    2.5%: -6.716
   10.0%: -4.329
   25.0%: -1.935
   50.0%: 0.170
   75.0%: 2.087
   90.0%: 4.432
   97.5%: 6.600

None
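The quantile tables printed above are simple functions of the posterior draws. As an illustration, the same table layout can be reproduced with np.quantile; here synthetic draws stand in for samples that would, assuming the same term name used by plot_parameter_trace(), come from bart_model.extract_parameter("global_error_scale"):

```python
import numpy as np

# Synthetic stand-in for sigma^2 posterior draws; with a fitted model these
# would be pulled from the model object instead
rng = np.random.default_rng(0)
sigma2_samples = rng.gamma(shape=200.0, scale=0.9 / 200.0, size=3000)

# The same probability levels reported by summary()
probs = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975]
quantiles = np.quantile(sigma2_samples, probs)

print(f"{len(sigma2_samples)} samples, mean = {sigma2_samples.mean():.3f}, "
      f"standard deviation = {sigma2_samples.std():.3f}, quantiles:")
for p, q in zip(probs, quantiles):
    print(f"  {100 * p:5.1f}%: {q:.3f}")
```

This makes it easy to report custom probability levels (e.g., a 50% interval) that the built-in summary does not print.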

In R, the plot() function produces a traceplot of model terms like the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions. In Python, the plot_parameter_trace() function produces a traceplot of a specific model term.

plot(bart_model)

ax = plot_parameter_trace(bart_model, term="global_error_scale")
plt.show()

For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to pull the posterior distribution of any valid model term (e.g., global error scale \(\sigma^2\), leaf scale \(\sigma^2_{\ell}\), in-sample mean function predictions y_hat_train) and then plot any subset or transformation of these values.

y_hat_train_samples <- extractParameter(bart_model, "y_hat_train")
obs_index <- 1
plot(
  y_hat_train_samples[obs_index, ],
  type = "l",
  main = paste0("In-Sample Predictions Traceplot, Observation ", obs_index),
  xlab = "Index",
  ylab = "Parameter Values"
)

y_hat_train_samples = bart_model.extract_parameter("y_hat_train")
obs_index = 0
fig, ax = plt.subplots()
ax.plot(y_hat_train_samples[obs_index, :])
ax.set_title(f"In-Sample Predictions Traceplot, Observation {obs_index}")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
plt.show()
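Beyond traceplots, the extracted sample matrix supports other transformations, such as a posterior mean and credible interval for each observation. A minimal sketch, using a synthetic matrix shaped like the y_hat_train draws above (one row per training observation, one column per posterior draw):

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in shaped like the extracted y_hat_train samples:
# rows index observations, columns index posterior draws
rng = np.random.default_rng(0)
n_obs, n_samples = 50, 3000
true_mean = np.linspace(-5.0, 5.0, n_obs)
y_hat_train_samples = true_mean[:, None] + 0.5 * rng.standard_normal((n_obs, n_samples))

# Posterior mean and 95% credible interval for each observation
post_mean = y_hat_train_samples.mean(axis=1)
lower, upper = np.quantile(y_hat_train_samples, [0.025, 0.975], axis=1)

# Plot observations sorted by posterior mean, with interval whiskers
fig, ax = plt.subplots()
order = np.argsort(post_mean)
ax.errorbar(
    np.arange(n_obs), post_mean[order],
    yerr=[post_mean[order] - lower[order], upper[order] - post_mean[order]],
    fmt="o", markersize=3, capsize=2,
)
ax.set_title("Posterior Mean and 95% Interval by Observation")
ax.set_xlabel("Observation (sorted by posterior mean)")
ax.set_ylabel("y_hat")
plt.show()
```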

Causal Inference

We now run the same demo for the causal inference use case served by the bcf() function in R and the BCFModel Python class.

Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome.

# Generate covariates and treatment
n <- 1000
p_X <- 5
X <- matrix(runif(n * p_X), ncol = p_X)
pi_X <- 0.25 + 0.5 * X[, 1]
Z <- rbinom(n, 1, pi_X)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X <- pi_X * 5 + 2 * X[, 3]
tau_X <- X[, 2] * 2 - 1

# Generate outcome
epsilon <- rnorm(n, 0, 1)
y <- mu_X + tau_X * Z + epsilon
# Generate covariates and treatment
n = 1000
p_X = 5
X = rng.uniform(size=(n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1

# Generate outcome
epsilon = rng.standard_normal(n)
y = mu_X + tau_X * Z + epsilon

Now we fit a simple BCF model to the data.

num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
general_params <- list(
  num_threads = 1, 
  num_chains = 3, 
  adaptive_coding = TRUE
)
bcf_model <- stochtree::bcf(
  X_train = X,
  y_train = y,
  Z_train = Z,
  num_gfr = num_gfr,
  num_burnin = num_burnin,
  num_mcmc = num_mcmc,
  general_params = general_params
)
bcf_model = BCFModel()
bcf_model.sample(
    X_train=X,
    Z_train=Z,
    y_train=y,
    propensity_train=pi_X,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=1000,
    general_params={
      "num_threads": 1, 
      "num_chains": 3, 
      "adaptive_coding": True
    },
)

We obtain a high-level summary of the BCF model by running print().

print(bcf_model)
stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
print(bcf_model)
BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

For a more detailed summary (including the information above), we use the summary() function in R or the summary() method in Python.

summary(bcf_model)
stochtree::bcf() run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
An internal propensity model was fit using stochtree::bart() in lieu of user-provided propensity scores
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning) 
Summary of sigma^2 posterior: 
3000 samples, mean = 0.946, standard deviation = 0.046, quantiles:
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.8602718 0.8884035 0.9141399 0.9435992 0.9753794 1.0056469 1.0410952 
Summary of prognostic forest leaf scale posterior: 
3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
        2.5%          10%          25%          50%          75%          90% 
0.0009160522 0.0010774573 0.0012350772 0.0014498265 0.0016907946 0.0019845796 
       97.5% 
0.0025799952 
Summary of adaptive coding parameters: 
3000 samples, mean (control) = -0.461, mean (treated) = 1.019, standard deviation (control) = 0.353, standard deviation (treated) = 0.328
quantiles (control):
       2.5%         10%         25%         50%         75%         90% 
-1.23065170 -0.91900031 -0.68260202 -0.44082441 -0.21038925 -0.02613343 
      97.5% 
 0.15527754 
quantiles (treated):
     2.5%       10%       25%       50%       75%       90%     97.5% 
0.4182956 0.5986667 0.7902708 1.0229387 1.2282032 1.4285370 1.7057750 
Summary of treatment effect intercept (tau_0) posterior: 
3000 samples, mean = 0.198, standard deviation = 0.820, quantiles:
      2.5%        10%        25%        50%        75%        90%      97.5% 
-1.5270194 -1.0527016 -0.2388395  0.1543970  0.8088434  1.2127939  1.7210237 
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 3.488, standard deviation = 0.944, quantiles:
    2.5%      10%      25%      50%      75%      90%    97.5% 
1.696618 2.269925 2.806166 3.483249 4.140743 4.736517 5.435912 
Summary of in-sample posterior mean CATEs: 
1000 observations, mean = -0.216, standard deviation = 0.523, quantiles:
      2.5%        10%        25%        50%        75%        90%      97.5% 
-1.0219221 -0.8474964 -0.6778336 -0.2933795  0.2888800  0.4871224  0.6127401 
print(bcf_model.summary())
BCF Model Summary:
------------------
BCFModel run with prognostic forest, treatment effect forest, global error variance model, prognostic forest leaf scale model, and treatment effect intercept model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)

Summary of sigma^2 posterior: 3000 samples, mean = 0.873, standard deviation = 0.041, quantiles:
    2.5%: 0.797
   10.0%: 0.821
   25.0%: 0.844
   50.0%: 0.871
   75.0%: 0.901
   90.0%: 0.926
   97.5%: 0.956
Summary of prognostic forest leaf scale posterior: 3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
    2.5%: 0.001
   10.0%: 0.001
   25.0%: 0.001
   50.0%: 0.002
   75.0%: 0.002
   90.0%: 0.002
   97.5%: 0.003
Summary of adaptive coding parameters: 
3000 samples, mean (control) = -0.540, mean (treated) = 1.151, standard deviation (control) = 0.324, standard deviation (treated) = 0.339
quantiles (control):
    2.5%: -1.230
   10.0%: -0.964
   25.0%: -0.743
   50.0%: -0.514
   75.0%: -0.324
   90.0%: -0.156
   97.5%: 0.058

quantiles (treated):
    2.5%: 0.530
   10.0%: 0.732
   25.0%: 0.907
   50.0%: 1.140
   75.0%: 1.370
   90.0%: 1.592
   97.5%: 1.881
Summary of treatment effect intercept (tau_0) posterior: 3000 samples, mean = -0.243, standard deviation = 0.592, quantiles:
    2.5%: -1.592
   10.0%: -1.207
   25.0%: -0.509
   50.0%: -0.166
   75.0%: 0.197
   90.0%: 0.443
   97.5%: 0.671
Summary of in-sample posterior mean predictions: 
1000 observations, mean = 3.482, standard deviation = 0.961, quantiles:
    2.5%: 1.847
   10.0%: 2.199
   25.0%: 2.781
   50.0%: 3.468
   75.0%: 4.163
   90.0%: 4.771
   97.5%: 5.313
Summary of in-sample posterior mean CATEs: 
1000 observations, mean = 0.359, standard deviation = 0.625, quantiles:
    2.5%: -0.693
   10.0%: -0.530
   25.0%: -0.186
   50.0%: 0.447
   75.0%: 0.926
   90.0%: 1.133
   97.5%: 1.298

None

In R, the plot() function produces a traceplot of model terms like the global error scale \(\sigma^2\) or (if \(\sigma^2\) is not sampled) the first observation of cached train set predictions.

In Python, we provide a plot_parameter_trace() function for requesting a traceplot of a specific model parameter.

plot(bcf_model)

ax = plot_parameter_trace(bcf_model, term="global_error_scale")
plt.show()

For finer-grained control over which parameters to plot, we can also use the extractParameter() function in R or the extract_parameter() method in Python to query the posterior distribution of any valid model term (e.g., global error scale \(\sigma^2\), prognostic forest leaf scale \(\sigma^2_{\mu}\), CATE forest leaf scale \(\sigma^2_{\tau}\), adaptive coding parameters \(b_0\) and \(b_1\) for binary treatment, in-sample mean function predictions y_hat_train, in-sample CATE function predictions tau_hat_train) and then plot any subset or transformation of these values.

adaptive_coding_samples <- extractParameter(bcf_model, "adaptive_coding")
plot(
  adaptive_coding_samples[1, ],
  type = "l",
  main = "Adaptive Coding Parameter Traceplot",
  xlab = "Index",
  ylab = "Parameter Values",
  ylim = range(adaptive_coding_samples),
  col = "blue"
)
lines(adaptive_coding_samples[2, ], col = "orange")
legend(
  "topright",
  legend = c("Control", "Treated"),
  lty = 1,
  col = c("blue", "orange")
)

adaptive_coding_samples = bcf_model.extract_parameter("adaptive_coding")
fig, ax = plt.subplots()
ax.plot(adaptive_coding_samples[0, :], color="blue", label="Control")
ax.plot(adaptive_coding_samples[1, :], color="orange", label="Treated")
ax.set_title("Adaptive Coding Parameter Traceplot")
ax.set_xlabel("Index")
ax.set_ylabel("Parameter Values")
ax.legend(loc="upper right")
plt.show()
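One useful transformation of the extracted CATE draws is the posterior of the (sample) average treatment effect: averaging tau_hat_train over observations within each posterior draw yields one ATE draw per sample. A minimal sketch, using a synthetic matrix shaped like the tau_hat_train term mentioned above (one row per observation, one column per draw):

```python
import numpy as np

# Synthetic stand-in shaped like extracted tau_hat_train samples;
# tau_true mirrors the simulated effect tau(X) = 2 * X_2 - 1
rng = np.random.default_rng(0)
n_obs, n_samples = 1000, 3000
tau_true = 2.0 * rng.uniform(size=n_obs) - 1.0
tau_hat_train_samples = tau_true[:, None] + 0.3 * rng.standard_normal((n_obs, n_samples))

# Averaging over observations within each draw gives the ATE posterior
ate_samples = tau_hat_train_samples.mean(axis=0)
lower, upper = np.quantile(ate_samples, [0.025, 0.975])
print(f"Posterior mean ATE: {ate_samples.mean():.3f}, "
      f"95% interval: ({lower:.3f}, {upper:.3f})")
```

The same pattern (reduce over observations within each draw, then summarize across draws) applies to any subgroup average, e.g., restricting the rows to observations with a particular covariate value before taking the mean.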