Summary and Plotting Utilities¶
This notebook demonstrates the summary and plotting utilities available for stochtree models in Python.
We begin by loading all necessary libraries.
import numpy as np
import matplotlib.pyplot as plt
from stochtree import (
    BARTModel,
    BCFModel,
    plot_parameter_trace,
)
We also set a random seed for reproducibility.
random_seed = 1234
rng = np.random.default_rng(random_seed)
Supervised Learning¶
We begin with the supervised learning use case served by the BARTModel class.
Below we simulate a simple regression dataset.
# Generate covariates and basis
n = 1000
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))
# Define the outcome mean function
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.5) & (X[:, 0] < 0.75),
                2.5 * W[:, 0],
                7.5 * W[:, 0],
            ),
        ),
    )
# Generate outcome
epsilon = rng.normal(0, 1, n)
y = outcome_mean(X, W) + epsilon
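As a quick sanity check (not part of the original notebook), we can verify the piecewise structure of the mean function on hand-picked inputs; the function below simply restates the definition above, using only the first column of X and a single basis column.

```python
import numpy as np

# Restate the piecewise outcome mean from the cell above
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.5) & (X[:, 0] < 0.75),
                2.5 * W[:, 0],
                7.5 * W[:, 0],
            ),
        ),
    )

# One point in each of the four X[:, 0] bins, all with W = 1
X_check = np.array([[0.1], [0.3], [0.6], [0.9]])
W_check = np.ones((4, 1))
print(outcome_mean(X_check, W_check))  # [-7.5 -2.5  2.5  7.5]
```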
Now we fit a simple BART model to the data.
bart_model = BARTModel()
general_params = {"num_chains": 3}
bart_model.sample(
    X_train=X,
    y_train=y,
    leaf_basis_train=W,
    num_gfr=10,
    num_mcmc=1000,
    general_params=general_params,
)
We obtain a high-level summary of the BART model by running print().
print(bart_model)
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
For a more detailed summary (including the information above), we use the summary() method of BARTModel.
print(bart_model.summary())
BART Model Summary:
-------------------
BARTModel run with mean forest, global error variance model, and mean forest leaf scale model
Outcome was modeled as gaussian with a leaf regression prior with 1 bases for the mean forest
Outcome was standardized
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior: 3000 samples, mean = 0.935, standard deviation = 0.050, quantiles:
2.5%: 0.841
10.0%: 0.872
25.0%: 0.901
50.0%: 0.935
75.0%: 0.966
90.0%: 1.001
97.5%: 1.034
Summary of leaf scale posterior: 3000 samples, mean = 0.007, standard deviation = 0.001, quantiles:
2.5%: 0.005
10.0%: 0.006
25.0%: 0.006
50.0%: 0.007
75.0%: 0.008
90.0%: 0.009
97.5%: 0.011
Summary of in-sample posterior mean predictions:
1000 observations, mean = 0.223, standard deviation = 3.230, quantiles:
2.5%: -6.720
10.0%: -4.291
25.0%: -1.667
50.0%: 0.353
75.0%: 2.024
90.0%: 4.597
97.5%: 6.634
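The quantile tables above are computed from the posterior draws; the same summaries can be reproduced for any vector of samples with np.quantile. A minimal self-contained sketch using synthetic draws (the array here is illustrative, not the fitted model's output):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a parameter's posterior draws, e.g. 3000 sigma^2 samples
draws = rng.normal(loc=0.935, scale=0.05, size=3000)

# Same quantile grid as the summary() output above
probs = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975]
print(f"mean = {draws.mean():.3f}, sd = {draws.std():.3f}")
for p, q in zip(probs, np.quantile(draws, probs)):
    print(f"{100 * p:.1f}%: {q:.3f}")
```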
Finally, we can use the plot_parameter_trace utility function to make quick traceplots of any parametric terms, which in this case involves the global error scale $\sigma^2$ and the leaf scale $\sigma^2_{\ell}$.
ax = plot_parameter_trace(bart_model, term="global_error_scale")
plt.show()
ax = plot_parameter_trace(bart_model, term="leaf_scale")
plt.show()
For finer-grained control over which parameters to plot, we can also use the extract_parameter() method to pull the posterior distribution of any valid model term (e.g., global error scale $\sigma^2$, leaf scale $\sigma^2_{\ell}$, in-sample mean function predictions y_hat_train) and then plot any subset or transformation of these values.
y_hat_train_samples = bart_model.extract_parameter("y_hat_train")
obs_index = 0
_, ax = plt.subplots()
ax.plot(y_hat_train_samples[obs_index, :])
ax.set_title(f"Parameter Trace: In-Sample Predictions, Observation {obs_index:d}")
ax.set_xlabel("Iteration")
ax.set_ylabel("Parameter Values")
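With num_chains = 3 and 1000 retained MCMC iterations, the extracted array holds 3000 samples per observation. If samples from the separate chains are concatenated along the sample axis (an assumption worth verifying against the stochtree documentation), the flattened trace can be reshaped to overlay per-chain traces. A self-contained sketch with a synthetic draw matrix standing in for y_hat_train_samples:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
num_chains, num_mcmc = 3, 1000
# Synthetic stand-in for extract_parameter("y_hat_train"): (observations, samples)
samples = rng.normal(size=(1000, num_chains * num_mcmc))

obs_index = 0
# Split the flattened sample axis into (chains, iterations)
per_chain = samples[obs_index, :].reshape(num_chains, num_mcmc)

_, ax = plt.subplots()
for chain_idx in range(num_chains):
    ax.plot(per_chain[chain_idx], alpha=0.6, label=f"Chain {chain_idx}")
ax.set_xlabel("Iteration")
ax.set_ylabel("Parameter Values")
ax.legend()
```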
Causal Inference¶
We now run the same demo for the causal inference use case served by the BCFModel class.
Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome.
# Generate covariates and treatment
n = 1000
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1
# Generate outcome
epsilon = rng.normal(0, 1, n)
y = mu_X + tau_X * Z + epsilon
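Since the propensity pi_X is known here, we can sanity-check overlap and the realized treatment rate before fitting. This standalone check (not part of the original notebook) regenerates the same quantities with its own generator:

```python
import numpy as np

rng = np.random.default_rng(1234)
n, p_X = 1000, 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)

# Propensities are bounded inside [0.25, 0.75], so overlap holds by construction
print(f"propensity range: [{pi_X.min():.3f}, {pi_X.max():.3f}]")
# The realized treated fraction should be close to the mean propensity
print(f"mean propensity = {pi_X.mean():.3f}, treated fraction = {Z.mean():.3f}")
```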
Now we fit a simple BCF model to the data.
bcf_model = BCFModel()
general_params = {"num_chains": 3}
bcf_model.sample(
    X_train=X,
    Z_train=Z,
    y_train=y,
    propensity_train=pi_X,
    num_gfr=10,
    num_mcmc=1000,
    general_params=general_params,
)
As above, we can print() this model for a quick overview.
print(bcf_model)
BCFModel run with prognostic forest, treatment effect forest, global error variance model, and prognostic forest leaf scale model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
And we can use the summary() method for a more detailed look at sampled model terms.
print(bcf_model.summary())
BCF Model Summary:
------------------
BCFModel run with prognostic forest, treatment effect forest, global error variance model, and prognostic forest leaf scale model
Outcome was modeled as gaussian
Treatment was binary and its effect was estimated with adaptive coding
Outcome was standardized
User-provided propensity scores were included in the model
The sampler was run for 10 GFR iterations, with 3 chains of 0 burn-in iterations and 1000 MCMC iterations, retaining every iteration (i.e. no thinning)
Summary of sigma^2 posterior: 3000 samples, mean = 0.871, standard deviation = 0.041, quantiles:
2.5%: 0.796
10.0%: 0.821
25.0%: 0.843
50.0%: 0.870
75.0%: 0.898
90.0%: 0.924
97.5%: 0.958
Summary of prognostic forest leaf scale posterior: 3000 samples, mean = 0.002, standard deviation = 0.000, quantiles:
2.5%: 0.001
10.0%: 0.001
25.0%: 0.002
50.0%: 0.002
75.0%: 0.002
90.0%: 0.002
97.5%: 0.003
Summary of adaptive coding parameters:
3000 samples, mean (control) = -0.406, mean (treated) = 0.846, standard deviation (control) = 0.269, standard deviation (treated) = 0.286
quantiles (control):
2.5%: -0.978
10.0%: -0.773
25.0%: -0.568
50.0%: -0.373
75.0%: -0.227
90.0%: -0.096
97.5%: 0.067
quantiles (treated):
2.5%: 0.344
10.0%: 0.498
25.0%: 0.654
50.0%: 0.826
75.0%: 1.012
90.0%: 1.210
97.5%: 1.500
Summary of in-sample posterior mean predictions:
1000 observations, mean = 3.482, standard deviation = 0.964, quantiles:
2.5%: 1.828
10.0%: 2.212
25.0%: 2.773
50.0%: 3.472
75.0%: 4.181
90.0%: 4.773
97.5%: 5.338
Summary of in-sample posterior mean CATEs:
1000 observations, mean = -0.029, standard deviation = 0.633, quantiles:
2.5%: -1.077
10.0%: -0.940
25.0%: -0.579
50.0%: 0.098
75.0%: 0.548
90.0%: 0.746
97.5%: 0.896
Finally, we can also plot parametric terms with plot_parameter_trace().
ax = plot_parameter_trace(bcf_model, term="global_error_scale")
plt.show()
ax = plot_parameter_trace(bcf_model, term="adaptive_coding")
plt.show()
For finer-grained control over which parameters to plot, we can also use the extract_parameter() method to pull the posterior distribution of any valid model term (e.g., global error scale $\sigma^2$, prognostic forest leaf scale $\sigma^2_{\mu}$, CATE forest leaf scale $\sigma^2_{\tau}$, adaptive coding parameters $b_0$ and $b_1$ for binary treatment, in-sample mean function predictions y_hat_train, in-sample CATE function predictions tau_hat_train) and then plot any subset or transformation of these values.
tau_hat_train_samples = bcf_model.extract_parameter("tau_hat_train")
obs_index = 0
_, ax = plt.subplots()
ax.plot(tau_hat_train_samples[obs_index, :])
ax.set_title(f"Parameter Trace: In-Sample CATE Predictions, Observation {obs_index:d}")
ax.set_xlabel("Iteration")
ax.set_ylabel("Parameter Values")
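The per-observation CATE draws can also be aggregated: averaging over observations within each posterior draw yields a posterior for the average treatment effect (ATE). A self-contained sketch with a synthetic draw matrix (shape observations × samples, indexed as above) standing in for tau_hat_train_samples:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for extract_parameter("tau_hat_train"): (observations, samples)
tau_samples = rng.normal(loc=0.5, scale=0.3, size=(1000, 3000))

# Average over observations within each posterior draw -> one ATE value per draw
ate_draws = tau_samples.mean(axis=0)
lo, hi = np.quantile(ate_draws, [0.025, 0.975])
print(f"posterior mean ATE = {ate_draws.mean():.3f}, 95% interval = [{lo:.3f}, {hi:.3f}]")
```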