Reparameterized Causal Inference¶
The classic BCF model of Hahn, Murray, and Carvalho (2020) is defined as
$$ Y_i \mid x_i, z_i \sim \mathrm{N}\!\left(f_0(x_i) + \tau(x_i)\, z_i,\, \sigma^2\right) $$
where $f_0$ and $\tau$ each have BART priors. Separating the prognostic function $f_0(x)$ from the CATE function $\tau(x)$ can improve estimation in settings with strong confounding and treatment effect heterogeneity.
stochtree implements a modification of this model that decomposes the treatment effect function into parametric and nonparametric components:
$$ Y_i \mid x_i, z_i \sim \mathrm{N}\!\left(f_0(x_i) + (\tau_0 + t(x_i))\, z_i,\, \sigma^2\right) $$
where $\tau_0 \sim \mathrm{N}(0,\, \sigma_{\tau_0}^2)$ is a global treatment effect intercept and $t(x_i)$ is a BART forest capturing heterogeneity around it. This allows the forest term to focus on heterogeneity "offsets" relative to a parametric average effect.
Load necessary libraries
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
from stochtree import BCFModel
Set random seed for reproducibility
# Fix the generator seed once so every sampler run below is reproducible.
random_seed = 1234
rng = np.random.default_rng(seed=random_seed)
Binary Treatment with Homogeneous Treatment Effect¶
Consider the following data generating process:
$$ \begin{aligned} y &= \mu(X) + \tau(X)\, Z + \epsilon \\ \mu(X) &= 2\sin(2\pi X_1) - 2(2X_3 - 1) \\ \tau(X) &= 5 \\ \pi(X) &= \Phi\!\left(\mu(X)/4\right) \\ X_1,\ldots,X_p &\sim \mathrm{Uniform}(0,1) \\ Z &\sim \mathrm{Bernoulli}(\pi(X)) \\ \epsilon &\sim \mathrm{N}(0, \sigma^2) \end{aligned} $$
Simulation¶
We draw from the DGP defined above
# Simulate the binary-treatment DGP: confounded assignment, constant tau = 5.
n = 500
p = 20
snr = 2  # signal-to-noise ratio used to set the error scale below
X = rng.uniform(low=0.0, high=1.0, size=(n, p))
# Prognostic function depends only on the first and third covariates.
mu_X = 2 * np.sin(2 * np.pi * X[:, 0]) - 2 * (2 * X[:, 2] - 1)
tau_X = 5.0  # homogeneous treatment effect
# Propensity is a monotone transform of the prognostic score, inducing confounding.
pi_X = norm.cdf(mu_X / 4)
Z = rng.binomial(n=1, p=pi_X, size=n).astype(float)
E_XZ = mu_X + Z * tau_X
# Calibrate the noise scale so that sd(signal) / sd(noise) = snr.
sigma_true = np.std(E_XZ) / snr
noise = rng.standard_normal(n)
y = E_XZ + noise * sigma_true
And split data into test and train sets
# Hold out half the sample for out-of-sample evaluation.
n_test = round(0.5 * n)
n_train = n - n_test
all_inds = np.arange(n)
test_inds = np.sort(rng.choice(n, size=n_test, replace=False))
train_inds = np.setdiff1d(all_inds, test_inds)
# Slice every array with the same index sets so rows stay aligned.
X_train, X_test = X[train_inds], X[test_inds]
Z_train, Z_test = Z[train_inds], Z[test_inds]
y_train, y_test = y[train_inds], y[test_inds]
pi_train, pi_test = pi_X[train_inds], pi_X[test_inds]
mu_train, mu_test = mu_X[train_inds], mu_X[test_inds]
# Sampler settings: pure MCMC (no grow-from-root warm-start iterations).
num_gfr = 0
num_burnin = 1000
num_mcmc = 500
num_trees_tau = 50
general_params = {
    "random_seed": random_seed,
    "num_chains": 4,
    "num_threads": 1,
    "adaptive_coding": True,  # binary treatment: adaptively recode the two Z levels
}
treatment_effect_forest_params = {
    "num_trees": num_trees_tau,
    "sigma2_leaf_init": 1 / num_trees_tau,
    "sample_intercept": False,  # classic BCF: no separate tau_0 intercept
}
# Fit the classic BCF model (treatment effect captured by the forest alone).
bcf_model_classic = BCFModel()
fit_args = dict(
    X_train=X_train,
    Z_train=Z_train,
    y_train=y_train,
    propensity_train=pi_train,
    X_test=X_test,
    Z_test=Z_test,
    propensity_test=pi_test,
    num_gfr=num_gfr,
    num_burnin=num_burnin,
    num_mcmc=num_mcmc,
    general_params=general_params,
    treatment_effect_forest_params=treatment_effect_forest_params,
)
bcf_model_classic.sample(**fit_args)
Compare the posterior distribution of the ATE to its true value
# Posterior CATE draws on the test set; averaging over rows yields ATE draws.
cate_posterior_classic = bcf_model_classic.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="cate"
)
ate_posterior_classic = np.mean(cate_posterior_classic, axis=0)
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(ate_posterior_classic, density=True, bins=30, color="steelblue", edgecolor="white")
ax.axvline(tau_X, color="red", linestyle="dotted", linewidth=2, label=f"True ATE = {tau_X}")
ax.set_xlabel("ATE")
ax.set_ylabel("Density")
ax.set_title("Posterior Distribution of ATE (Classic BCF)")
ax.legend()
plt.show()
As a rough convergence check, inspect the traceplot of the global error variance $\sigma^2$
# Quick mixing diagnostic: trace of the global error variance draws.
sigma2_samples = bcf_model_classic.global_var_samples
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(sigma2_samples, color="steelblue", linewidth=0.8)
ax.axhline(sigma_true**2, color="red", linestyle="dotted", linewidth=2, label=f"True $\\sigma^2$ = {sigma_true**2:.3f}")
ax.set_xlabel("Iteration")
ax.set_ylabel("$\\sigma^2$")
ax.set_title("Traceplot of $\\sigma^2$ (Classic BCF)")
ax.legend()
plt.show()
Reparameterized BCF Model¶
Now we fit the reparameterized model. Because the standard normal prior on $\tau_0$ already absorbs much of the treatment-effect scale, we regularize the $t(x)$ forest more heavily (via a smaller leaf scale) so that it captures only heterogeneity around the intercept.
# Reparameterized model: sampled tau_0 intercept plus a more tightly shrunk forest.
num_trees_tau = 50
general_params = {
    "random_seed": random_seed,
    "num_chains": 4,
    "num_threads": 1,
    "adaptive_coding": False,  # the intercept takes over from adaptive coding
}
treatment_effect_forest_params = {
    "num_trees": num_trees_tau,
    "sigma2_leaf_init": 0.25 / num_trees_tau,  # smaller leaf scale than the classic fit
    "sample_intercept": True,
    "tau_0_prior_var": 1.0,  # standard normal prior on tau_0
}
# Fit the reparameterized model with the same data and sampler budget.
bcf_model_reparam = BCFModel()
fit_args = dict(
    X_train=X_train,
    Z_train=Z_train,
    y_train=y_train,
    propensity_train=pi_train,
    X_test=X_test,
    Z_test=Z_test,
    propensity_test=pi_test,
    num_gfr=num_gfr,
    num_burnin=num_burnin,
    num_mcmc=num_mcmc,
    general_params=general_params,
    treatment_effect_forest_params=treatment_effect_forest_params,
)
bcf_model_reparam.sample(**fit_args)
Compare the posterior distribution of the ATE to its true value
# ATE posterior under the reparameterized model: tau_0 + t(x) combined.
cate_posterior_reparam = bcf_model_reparam.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="cate"
)
ate_posterior_reparam = np.mean(cate_posterior_reparam, axis=0)
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(ate_posterior_reparam, density=True, bins=30, color="steelblue", edgecolor="white")
ax.axvline(tau_X, color="red", linestyle="dotted", linewidth=2, label=f"True ATE = {tau_X}")
ax.set_xlabel("ATE")
ax.set_ylabel("Density")
ax.set_title("Posterior Distribution of ATE (Reparameterized BCF)")
ax.legend()
plt.show()
Convergence check: traceplot of $\sigma^2$
# Same mixing diagnostic for the reparameterized fit.
sigma2_samples = bcf_model_reparam.global_var_samples
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(sigma2_samples, color="steelblue", linewidth=0.8)
ax.axhline(sigma_true**2, color="red", linestyle="dotted", linewidth=2, label=f"True $\\sigma^2$ = {sigma_true**2:.3f}")
ax.set_xlabel("Iteration")
ax.set_ylabel("$\\sigma^2$")
ax.set_title("Traceplot of $\\sigma^2$ (Reparameterized BCF)")
ax.legend()
plt.show()
Since $t(X)$ is not constrained to sum to zero, $\tau_0$ does not directly identify the ATE. We can see this by comparing the posteriors of $\tau_0$ and $\bar{t}(X)$ (the test-set mean of $t(X)$ for each posterior draw) — they are strongly negatively correlated, reflecting the partial non-identifiability between the intercept and the forest mean.
# Scatter tau_0 draws against the per-draw test-set mean of the forest term t(X).
# NOTE(review): indexing row 0 assumes tau_0_samples has shape (1, num_draws) -- confirm against the stochtree API.
tau_0_posterior = bcf_model_reparam.tau_0_samples[0, :]
tau_x_posterior = bcf_model_reparam.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="tau"
)
t_x_mean = np.mean(tau_x_posterior, axis=0)  # average t(X) over test rows, per draw
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(tau_0_posterior, t_x_mean, alpha=0.3, s=10, color="steelblue")
ax.set_xlabel("$\\tau_0$")
ax.set_ylabel("$\\bar{t}(X)$")
ax.set_title("Posterior of $\\tau_0$ vs $\\bar{t}(X)$")
plt.show()
While stochtree does not currently support constraining $t(X)$ to sum to zero over the training set, we can more heavily regularize $t(X)$ so its values stay close to zero. Using a single tree with a very small leaf scale effectively collapses the forest to a constant near zero, making $\tau_0$ the primary vehicle for the treatment effect.
# Collapse the forest: one tree with a tiny leaf scale, so tau_0 carries the effect.
general_params = {
    "random_seed": random_seed,
    "num_chains": 4,
    "num_threads": 1,
    "adaptive_coding": False,
}
treatment_effect_forest_params = {
    "num_trees": 1,  # single tree
    "sigma2_leaf_init": 1e-6,  # near-degenerate leaf prior pins t(X) close to zero
    "sample_intercept": True,
    "tau_0_prior_var": 1.0,
}
# Fit the shrunk-forest variant.
bcf_model_reparam_shrunk = BCFModel()
fit_args = dict(
    X_train=X_train,
    Z_train=Z_train,
    y_train=y_train,
    propensity_train=pi_train,
    X_test=X_test,
    Z_test=Z_test,
    propensity_test=pi_test,
    num_gfr=num_gfr,
    num_burnin=num_burnin,
    num_mcmc=num_mcmc,
    general_params=general_params,
    treatment_effect_forest_params=treatment_effect_forest_params,
)
bcf_model_reparam_shrunk.sample(**fit_args)
# ATE posterior when the forest is effectively collapsed to a constant.
cate_posterior_shrunk = bcf_model_reparam_shrunk.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="cate"
)
ate_posterior_shrunk = np.mean(cate_posterior_shrunk, axis=0)
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(ate_posterior_shrunk, density=True, bins=30, color="steelblue", edgecolor="white")
ax.axvline(tau_X, color="red", linestyle="dotted", linewidth=2, label=f"True ATE = {tau_X}")
ax.set_xlabel("ATE")
ax.set_ylabel("Density")
ax.set_title("Posterior Distribution of ATE (Shrunk Forest)")
ax.legend()
plt.show()
With the forest heavily regularized, $\tau_0$ and $\bar{t}(X)$ are no longer correlated — $\bar{t}(X)$ is near zero and $\tau_0$ directly captures the treatment effect.
# With t(X) shrunk toward zero, tau_0 and the forest mean should decouple.
# NOTE(review): row-0 indexing assumes tau_0_samples is (1, num_draws) -- confirm.
tau_0_posterior_shrunk = bcf_model_reparam_shrunk.tau_0_samples[0, :]
tau_x_posterior_shrunk = bcf_model_reparam_shrunk.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="tau"
)
t_x_mean_shrunk = np.mean(tau_x_posterior_shrunk, axis=0)
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(tau_0_posterior_shrunk, t_x_mean_shrunk, alpha=0.3, s=10, color="steelblue")
ax.set_xlabel("$\\tau_0$")
ax.set_ylabel("$\\bar{t}(X)$")
ax.set_title("Posterior of $\\tau_0$ vs $\\bar{t}(X)$ (Shrunk Forest)")
plt.show()
We can further regularize estimation of the ATE by reducing the intercept's prior variance $\sigma_{\tau_0}^2$.
# Tighten the intercept prior: same collapsed forest, smaller tau_0 prior variance.
general_params = {
    "random_seed": random_seed,
    "num_chains": 4,
    "num_threads": 1,
    "adaptive_coding": False,
}
treatment_effect_forest_params = {
    "num_trees": 1,
    "sigma2_leaf_init": 1e-6,
    "sample_intercept": True,
    "tau_0_prior_var": 0.05,  # much tighter than the 1.0 used above
}
# Fit the tight-prior variant.
bcf_model_tight_prior = BCFModel()
fit_args = dict(
    X_train=X_train,
    Z_train=Z_train,
    y_train=y_train,
    propensity_train=pi_train,
    X_test=X_test,
    Z_test=Z_test,
    propensity_test=pi_test,
    num_gfr=num_gfr,
    num_burnin=num_burnin,
    num_mcmc=num_mcmc,
    general_params=general_params,
    treatment_effect_forest_params=treatment_effect_forest_params,
)
bcf_model_tight_prior.sample(**fit_args)
# ATE posterior under the tighter tau_0 prior.
cate_posterior_tight = bcf_model_tight_prior.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="cate"
)
ate_posterior_tight = np.mean(cate_posterior_tight, axis=0)
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(ate_posterior_tight, density=True, bins=30, color="steelblue", edgecolor="white")
ax.axvline(tau_X, color="red", linestyle="dotted", linewidth=2, label=f"True ATE = {tau_X}")
ax.set_xlabel("ATE")
ax.set_ylabel("Density")
ax.set_title("Posterior Distribution of ATE")
ax.legend()
plt.show()
Continuous Treatment with Homogeneous Treatment Effect¶
The $\tau_0 + t(x)$ reparameterization generalizes naturally to continuous treatment. With a continuous $Z$, $\tau(x)$ represents the marginal effect of a one-unit increase in $Z$, and $\tau_0$ captures the homogeneous component of that effect.
Consider the following data generating process:
$$ \begin{aligned} y &= \mu(X) + \tau(X)\, Z + \epsilon \\ \mu(X) &= 2\sin(2\pi X_1) - 2(2X_3 - 1) \\ \tau(X) &= 2 \\ \pi(X) &= \mathrm{E}[Z \mid X] = \mu(X)/8 \\ Z \mid X &\sim \mathrm{N}(\pi(X),\, 1) \\ \epsilon &\sim \mathrm{N}(0, \sigma^2) \end{aligned} $$
Simulation¶
We draw from the DGP defined above
# Simulate the continuous-treatment DGP: Z is Gaussian around pi(X), tau = 2.
n = 500
p = 20
snr = 2  # signal-to-noise ratio for the error scale
X = rng.uniform(low=0.0, high=1.0, size=(n, p))
mu_X = 2 * np.sin(2 * np.pi * X[:, 0]) - 2 * (2 * X[:, 2] - 1)
tau_X = 2.0  # homogeneous marginal effect of Z
pi_X = mu_X / 8  # E[Z | X]: treatment mean depends on the prognostic score
Z = pi_X + rng.standard_normal(n)
E_XZ = mu_X + Z * tau_X
sigma_true = np.std(E_XZ) / snr
noise = rng.standard_normal(n)
y = E_XZ + noise * sigma_true
And split data into test and train sets
# Hold out half the sample, slicing all arrays with the same index sets.
n_test = round(0.5 * n)
n_train = n - n_test
all_inds = np.arange(n)
test_inds = np.sort(rng.choice(n, size=n_test, replace=False))
train_inds = np.setdiff1d(all_inds, test_inds)
X_train, X_test = X[train_inds], X[test_inds]
Z_train, Z_test = Z[train_inds], Z[test_inds]
y_train, y_test = y[train_inds], y[test_inds]
pi_train, pi_test = pi_X[train_inds], pi_X[test_inds]
mu_train, mu_test = mu_X[train_inds], mu_X[test_inds]
# Classic BCF settings for the continuous treatment (adaptive coding off:
# it is a binary-treatment device).
num_gfr = 0
num_burnin = 1000
num_mcmc = 500
num_trees_tau = 50
general_params = {
    "random_seed": random_seed,
    "num_chains": 4,
    "num_threads": 1,
    "adaptive_coding": False,
}
treatment_effect_forest_params = {
    "num_trees": num_trees_tau,
    "sigma2_leaf_init": 1 / num_trees_tau,
    "sample_intercept": False,
}
# Fit the classic model on the continuous-treatment data.
bcf_model_classic = BCFModel()
fit_args = dict(
    X_train=X_train,
    Z_train=Z_train,
    y_train=y_train,
    propensity_train=pi_train,
    X_test=X_test,
    Z_test=Z_test,
    propensity_test=pi_test,
    num_gfr=num_gfr,
    num_burnin=num_burnin,
    num_mcmc=num_mcmc,
    general_params=general_params,
    treatment_effect_forest_params=treatment_effect_forest_params,
)
bcf_model_classic.sample(**fit_args)
We compare the posterior distribution of the ATE to its true value
# ATE posterior (continuous treatment, classic model).
cate_posterior_classic = bcf_model_classic.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="cate"
)
ate_posterior_classic = np.mean(cate_posterior_classic, axis=0)
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(ate_posterior_classic, density=True, bins=30, color="steelblue", edgecolor="white")
ax.axvline(tau_X, color="red", linestyle="dotted", linewidth=2, label="True ATE")
ax.set_xlabel("ATE")
ax.set_ylabel("Density")
ax.set_title("Posterior Distribution of ATE")
ax.legend()
plt.show()
As a rough convergence check, we inspect the traceplot of $\sigma^2$
# Mixing diagnostic for the continuous-treatment classic fit.
sigma2_samples = bcf_model_classic.global_var_samples
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(sigma2_samples, color="steelblue", linewidth=0.8)
ax.axhline(sigma_true**2, color="red", linestyle="dotted", linewidth=2, label="True $\\sigma^2$")
ax.set_xlabel("Iteration")
ax.set_ylabel("$\\sigma^2$")
ax.set_title("Traceplot of $\\sigma^2$")
ax.legend()
plt.show()
Reparameterized BCF Model¶
# Reparameterized settings for the continuous treatment: intercept on,
# forest shrunk relative to the classic fit.
general_params = {
    "random_seed": random_seed,
    "num_chains": 4,
    "num_threads": 1,
    "adaptive_coding": False,
}
treatment_effect_forest_params = {
    "num_trees": num_trees_tau,
    "sigma2_leaf_init": 0.25 / num_trees_tau,
    "sample_intercept": True,
    "tau_0_prior_var": 1.0,
}
# Fit the reparameterized model on the continuous-treatment data.
bcf_model_reparam = BCFModel()
fit_args = dict(
    X_train=X_train,
    Z_train=Z_train,
    y_train=y_train,
    propensity_train=pi_train,
    X_test=X_test,
    Z_test=Z_test,
    propensity_test=pi_test,
    num_gfr=num_gfr,
    num_burnin=num_burnin,
    num_mcmc=num_mcmc,
    general_params=general_params,
    treatment_effect_forest_params=treatment_effect_forest_params,
)
bcf_model_reparam.sample(**fit_args)
And we compare the posterior distribution of the ATE to its true value
# ATE posterior (continuous treatment, reparameterized model).
cate_posterior_reparam = bcf_model_reparam.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="cate"
)
ate_posterior_reparam = np.mean(cate_posterior_reparam, axis=0)
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(ate_posterior_reparam, density=True, bins=30, color="steelblue", edgecolor="white")
ax.axvline(tau_X, color="red", linestyle="dotted", linewidth=2, label=f"True ATE = {tau_X}")
ax.set_xlabel("ATE")
ax.set_ylabel("Density")
ax.set_title("Posterior Distribution of ATE (Reparameterized BCF, Continuous Treatment)")
ax.legend()
plt.show()
As above, we check convergence by inspecting the traceplot of $\sigma^2$
# Mixing diagnostic for the continuous-treatment reparameterized fit.
sigma2_samples = bcf_model_reparam.global_var_samples
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(sigma2_samples, color="steelblue", linewidth=0.8)
ax.axhline(sigma_true**2, color="red", linestyle="dotted", linewidth=2, label="True $\\sigma^2$")
ax.set_xlabel("Iteration")
ax.set_ylabel("$\\sigma^2$")
ax.set_title("Traceplot of $\\sigma^2$")
ax.legend()
plt.show()
As in the binary treatment case, $\tau_0$ and $\bar{t}(X)$ are negatively correlated across posterior draws
# Correlation between tau_0 and the per-draw test-set mean of t(X), as in the binary case.
# NOTE(review): row-0 indexing assumes tau_0_samples is (1, num_draws) -- confirm.
tau_0_posterior = bcf_model_reparam.tau_0_samples[0, :]
tau_x_posterior = bcf_model_reparam.predict(
    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="tau"
)
t_x_mean = np.mean(tau_x_posterior, axis=0)
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(tau_0_posterior, t_x_mean, alpha=0.3, s=10, color="steelblue")
ax.set_xlabel("$\\tau_0$")
ax.set_ylabel("$\\bar{t}(X)$")
ax.set_title("Posterior of $\\tau_0$ vs $\\bar{t}(X)$")
plt.show()
References¶
Hahn, P Richard, Jared S Murray, and Carlos M Carvalho. 2020. “Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Effects (with Discussion).” Bayesian Analysis 15 (3): 965–1056.