Summary and Plotting Utilities¶
This notebook demonstrates the summary and plotting utilities available for stochtree models in Python.
We begin by loading all necessary libraries.
import numpy as np
import matplotlib.pyplot as plt
from stochtree import (
    BARTModel,
    BCFModel,
    plot_parameter_trace,
)
We also set a random seed for reproducibility.
random_seed = 1234
rng = np.random.default_rng(random_seed)
Supervised Learning¶
We begin with the supervised learning use case served by the BARTModel class.
Below we simulate a simple regression dataset.
# Generate covariates and basis
n = 1000
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))
# Define the outcome mean function
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.5) & (X[:, 0] < 0.75),
                2.5 * W[:, 0],
                7.5 * W[:, 0],
            ),
        ),
    )
# Generate outcome
epsilon = rng.normal(0, 1, n)
y = outcome_mean(X, W) + epsilon
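The nested np.where calls above implement a step function in X[:, 0]: each quartile bin of the first covariate gets its own slope on W. A minimal standalone check of that pattern, with one hypothetical x value per bin, looks like:

```python
import numpy as np

# One representative x value per quartile bin of X[:, 0]; the nested
# np.where pattern maps the four bins to slopes -7.5, -2.5, 2.5, 7.5 on W.
x = np.array([0.1, 0.3, 0.6, 0.9])
w = np.ones(4)
slopes = np.where(
    (x >= 0.0) & (x < 0.25), -7.5,
    np.where(
        (x >= 0.25) & (x < 0.5), -2.5,
        np.where((x >= 0.5) & (x < 0.75), 2.5, 7.5),
    ),
)
print(slopes * w)  # [-7.5 -2.5  2.5  7.5]
```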
Now we fit a simple BART model to the data.
bart_model = BARTModel()
general_params = {"num_chains": 3}
bart_model.sample(
    X_train=X,
    y_train=y,
    leaf_basis_train=W,
    num_gfr=10,
    num_mcmc=1000,
    general_params=general_params,
)
We can obtain a high-level summary of the BART model by running print().
print(bart_model)
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
For a more detailed summary (including the information above), we use the summary() method of BARTModel.
print(bart_model.summary())
BART Model Summary:
-------------------
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior: 3000 samples, mean = 0.925, standard deviation = 0.049, quantiles:
2.5%: 0.830
10.0%: 0.862
25.0%: 0.892
50.0%: 0.925
75.0%: 0.957
90.0%: 0.985
97.5%: 1.023
Summary of leaf scale posterior: 3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
2.5%: 0.005
10.0%: 0.006
25.0%: 0.006
50.0%: 0.007
75.0%: 0.008
90.0%: 0.009
97.5%: 0.010
Summary of in-sample posterior mean predictions:
1000 observations, mean = 0.222, standard deviation = 3.230, quantiles:
2.5%: -6.731
10.0%: -4.278
25.0%: -1.692
50.0%: 0.352
75.0%: 2.012
90.0%: 4.552
97.5%: 6.687
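The quantile summaries printed above can be reproduced directly with numpy once raw posterior draws are in hand. A sketch, using a synthetic stand-in for the sigma^2 samples (the real draws would come from the fitted model):

```python
import numpy as np

# Synthetic stand-in for 3000 posterior draws of sigma^2 (illustration
# only; real draws would be extracted from the fitted BART model).
rng = np.random.default_rng(0)
sigma2_samples = rng.normal(0.925, 0.049, 3000)

# Mean, standard deviation, and the same quantile grid as summary()
print(f"mean = {sigma2_samples.mean():.3f}, sd = {sigma2_samples.std():.3f}")
for q in (0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975):
    print(f"{100 * q:.1f}%: {np.quantile(sigma2_samples, q):.3f}")
```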
Finally, we can use the plot_parameter_trace utility function to make quick traceplots of any parametric terms, which in this case are the global error scale $\sigma^2$ and the leaf scale $\sigma^2_{\ell}$.
ax = plot_parameter_trace(bart_model, term="global_error_scale")
plt.show()
ax = plot_parameter_trace(bart_model, term="leaf_scale")
plt.show()
For finer-grained control over which parameters to plot, we can also use the extract_parameter() method to pull the posterior distribution of any valid model term (e.g., global error scale $\sigma^2$, leaf scale $\sigma^2_{\ell}$, in-sample mean function predictions y_hat_train) and then plot any subset or transformation of these values.
y_hat_train_samples = bart_model.extract_parameter("y_hat_train")
obs_index = 0
_, ax = plt.subplots()
ax.plot(y_hat_train_samples[obs_index,:])
ax.set_title(f"Parameter Trace: In-Sample Predictions, Observation {obs_index:d}")
ax.set_xlabel("Iteration")
ax.set_ylabel("Parameter Values")
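Beyond traceplots, the same samples matrix supports per-observation posterior summaries via axis reductions. A sketch, assuming (as in the indexing above) a layout of (n_observations, n_samples):

```python
import numpy as np

# Stand-in for a (n_observations, n_samples) matrix of posterior draws,
# such as the array returned by extract_parameter("y_hat_train")
# (that shape is an assumption here, matching the indexing above).
rng = np.random.default_rng(0)
samples = rng.normal(0.0, 1.0, (5, 3000))

post_mean = samples.mean(axis=1)  # one posterior mean per observation
lower, upper = np.quantile(samples, [0.025, 0.975], axis=1)  # 95% bounds
print(post_mean.shape, lower.shape, upper.shape)  # (5,) (5,) (5,)
```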
Causal Inference¶
We now run the same demo for the causal inference use case served by the BCFModel class.
Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome.
# Generate covariates and treatment
n = 1000
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1
# Generate outcome
epsilon = rng.normal(0, 1, n)
y = mu_X + tau_X * Z + epsilon
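Before fitting, it is worth sanity-checking the simulated design: the propensity 0.25 + 0.5 * X[:, 0] stays within [0.25, 0.75], and tau(X) = 2 * X[:, 1] - 1 ranges over (-1, 1) with a true ATE of 0. A quick standalone Monte Carlo check:

```python
import numpy as np

# Monte Carlo check of the simulation's ground truth (illustrative only).
rng = np.random.default_rng(1234)
x = rng.uniform(0, 1, (100000, 2))
pi = 0.25 + 0.5 * x[:, 0]   # propensity bounded away from 0 and 1
tau = 2 * x[:, 1] - 1       # CATEs in (-1, 1), true ATE = 0

print(pi.min() >= 0.25, pi.max() <= 0.75)  # True True
print(round(abs(tau.mean()), 2))
```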
Now we fit a simple BCF model to the data.
bcf_model = BCFModel()
general_params = {"num_chains": 3}
bcf_model.sample(
    X_train=X,
    Z_train=Z,
    y_train=y,
    propensity_train=pi_X,
    num_gfr=10,
    num_mcmc=1000,
    general_params=general_params,
)
As above, we can print() this model for a quick overview.
print(bcf_model)
BCFModel run with prognostic forest, treatment effect forest, global error variance model, and prognostic forest leaf scale model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
And we can use the summary() method for a more detailed look at sampled model terms.
print(bcf_model.summary())
BCF Model Summary:
------------------
BCFModel run with prognostic forest, treatment effect forest, global error variance model, and prognostic forest leaf scale model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior: 3000 samples, mean = 0.873, standard deviation = 0.042, quantiles:
2.5%: 0.794
10.0%: 0.820
25.0%: 0.845
50.0%: 0.872
75.0%: 0.901
90.0%: 0.928
97.5%: 0.959
Summary of prognostic forest leaf scale posterior: 3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
2.5%: 0.001
10.0%: 0.001
25.0%: 0.001
50.0%: 0.002
75.0%: 0.002
90.0%: 0.002
97.5%: 0.002
Summary of adaptive coding parameters:
3000 samples, mean (control) = -0.400, mean (treated) = 0.888, standard deviation (control) = 0.263, standard deviation (treated) = 0.285
quantiles (control):
2.5%: -0.984
10.0%: -0.755
25.0%: -0.562
50.0%: -0.373
75.0%: -0.219
90.0%: -0.084
97.5%: 0.067
quantiles (treated):
2.5%: 0.380
10.0%: 0.535
25.0%: 0.692
50.0%: 0.865
75.0%: 1.063
90.0%: 1.264
97.5%: 1.509
Summary of in-sample posterior mean predictions:
1000 observations, mean = 3.482, standard deviation = 0.961, quantiles:
2.5%: 1.836
10.0%: 2.205
25.0%: 2.773
50.0%: 3.469
75.0%: 4.162
90.0%: 4.758
97.5%: 5.326
Summary of in-sample posterior mean CATEs:
1000 observations, mean = -0.026, standard deviation = 0.634, quantiles:
2.5%: -1.084
10.0%: -0.946
25.0%: -0.564
50.0%: 0.100
75.0%: 0.557
90.0%: 0.754
97.5%: 0.901
Finally, we can also plot parametric terms with plot_parameter_trace().
ax = plot_parameter_trace(bcf_model, term="global_error_scale")
plt.show()
ax = plot_parameter_trace(bcf_model, term="adaptive_coding")
plt.show()
For finer-grained control over which parameters to plot, we can also use the extract_parameter() method to pull the posterior distribution of any valid model term (e.g., global error scale $\sigma^2$, prognostic forest leaf scale $\sigma^2_{\mu}$, CATE forest leaf scale $\sigma^2_{\tau}$, adaptive coding parameters $b_0$ and $b_1$ for binary treatment, in-sample mean function predictions y_hat_train, in-sample CATE function predictions tau_hat_train) and then plot any subset or transformation of these values.
tau_hat_train_samples = bcf_model.extract_parameter("tau_hat_train")
obs_index = 0
_, ax = plt.subplots()
ax.plot(tau_hat_train_samples[obs_index,:])
ax.set_title(f"Parameter Trace: In-Sample CATE Predictions, Observation {obs_index:d}")
ax.set_xlabel("Iteration")
ax.set_ylabel("Parameter Values")
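Averaging CATE draws over observations within each iteration yields posterior draws of the ATE, which can then be summarized directly. A sketch with a synthetic stand-in for tau_hat_train_samples, assuming (as in the indexing above) a shape of (n_observations, n_samples):

```python
import numpy as np

# Synthetic stand-in for CATE posterior draws, shape (n_obs, n_samples)
# assumed; real draws would come from extract_parameter("tau_hat_train").
rng = np.random.default_rng(0)
tau_draws = rng.normal(0.0, 0.6, (1000, 3000))

ate_draws = tau_draws.mean(axis=0)  # one ATE draw per MCMC iteration
ate_mean = ate_draws.mean()
ate_lower, ate_upper = np.quantile(ate_draws, [0.025, 0.975])
print(ate_draws.shape)  # (3000,)
```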