Multivariate Treatment Causal Inference¶
Load necessary libraries
InĀ [1]:
Copied!
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from stochtree import BCFModel
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from stochtree import BCFModel
Generate sample data
InĀ [2]:
Copied!
# RNG
rng = np.random.default_rng()
# Generate covariates and basis
n = 500
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = np.c_[0.25 + 0.5 * X[:, 0], 0.75 - 0.5 * X[:, 1]]
# Z = rng.uniform(0, 1, (n, 2))
Z = rng.binomial(1, pi_X, (n, 2))
# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X[:, 0] * 5 + pi_X[:, 1] * 2 + 2 * X[:, 2]
tau_X = np.stack((X[:, 1], X[:, 2]), axis=-1)
# Generate outcome
epsilon = rng.normal(0, 1, n)
treatment_term = np.multiply(tau_X, Z).sum(axis=1)
y = mu_X + treatment_term + epsilon
# RNG
rng = np.random.default_rng()
# Generate covariates and basis
n = 500
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = np.c_[0.25 + 0.5 * X[:, 0], 0.75 - 0.5 * X[:, 1]]
# Z = rng.uniform(0, 1, (n, 2))
Z = rng.binomial(1, pi_X, (n, 2))
# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X[:, 0] * 5 + pi_X[:, 1] * 2 + 2 * X[:, 2]
tau_X = np.stack((X[:, 1], X[:, 2]), axis=-1)
# Generate outcome
epsilon = rng.normal(0, 1, n)
treatment_term = np.multiply(tau_X, Z).sum(axis=1)
y = mu_X + treatment_term + epsilon
Test-train split
InĀ [3]:
Copied!
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
Z_train = Z[train_inds, :]
Z_test = Z[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]
pi_train = pi_X[train_inds]
pi_test = pi_X[test_inds]
mu_train = mu_X[train_inds]
mu_test = mu_X[test_inds]
tau_train = tau_X[train_inds, :]
tau_test = tau_X[test_inds, :]
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
Z_train = Z[train_inds, :]
Z_test = Z[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]
pi_train = pi_X[train_inds]
pi_test = pi_X[test_inds]
mu_train = mu_X[train_inds]
mu_test = mu_X[test_inds]
tau_train = tau_X[train_inds, :]
tau_test = tau_X[test_inds, :]
Run BCF
InĀ [4]:
Copied!
bcf_model = BCFModel()
bcf_model.sample(
X_train=X_train,
Z_train=Z_train,
y_train=y_train,
pi_train=pi_train,
X_test=X_test,
Z_test=Z_test,
pi_test=pi_test,
num_gfr=10,
num_mcmc=100,
)
bcf_model = BCFModel()
bcf_model.sample(
X_train=X_train,
Z_train=Z_train,
y_train=y_train,
pi_train=pi_train,
X_test=X_test,
Z_test=Z_test,
pi_test=pi_test,
num_gfr=10,
num_mcmc=100,
)
Inspect the MCMC (BART) samples
InĀ [5]:
Copied!
forest_preds_y_mcmc = bcf_model.y_hat_test
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis=1, keepdims=True)
y_df_mcmc = pd.DataFrame(
np.concatenate((np.expand_dims(y_test, 1), y_avg_mcmc), axis=1),
columns=["True outcome", "Average estimated outcome"],
)
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
forest_preds_y_mcmc = bcf_model.y_hat_test
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis=1, keepdims=True)
y_df_mcmc = pd.DataFrame(
np.concatenate((np.expand_dims(y_test, 1), y_avg_mcmc), axis=1),
columns=["True outcome", "Average estimated outcome"],
)
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
InĀ [6]:
Copied!
np.sqrt(np.mean(np.power(y_avg_mcmc - y_test, 2)))
np.sqrt(np.mean(np.power(y_avg_mcmc - y_test, 2)))
Out[6]:
np.float64(1.913979587642076)
InĀ [7]:
Copied!
treatment_idx = 0
forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:, :, treatment_idx])
tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis=1, keepdims=True)
tau_df_mcmc = pd.DataFrame(
np.concatenate(
(np.expand_dims(tau_test[:, treatment_idx], 1), tau_avg_mcmc), axis=1
),
columns=["True tau", "Average estimated tau"],
)
sns.scatterplot(data=tau_df_mcmc, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
treatment_idx = 0
forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:, :, treatment_idx])
tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis=1, keepdims=True)
tau_df_mcmc = pd.DataFrame(
np.concatenate(
(np.expand_dims(tau_test[:, treatment_idx], 1), tau_avg_mcmc), axis=1
),
columns=["True tau", "Average estimated tau"],
)
sns.scatterplot(data=tau_df_mcmc, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
InĀ [8]:
Copied!
treatment_idx = 1
forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:, :, treatment_idx])
tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis=1, keepdims=True)
tau_df_mcmc = pd.DataFrame(
np.concatenate(
(np.expand_dims(tau_test[:, treatment_idx], 1), tau_avg_mcmc), axis=1
),
columns=["True tau", "Average estimated tau"],
)
sns.scatterplot(data=tau_df_mcmc, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
treatment_idx = 1
forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:, :, treatment_idx])
tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis=1, keepdims=True)
tau_df_mcmc = pd.DataFrame(
np.concatenate(
(np.expand_dims(tau_test[:, treatment_idx], 1), tau_avg_mcmc), axis=1
),
columns=["True tau", "Average estimated tau"],
)
sns.scatterplot(data=tau_df_mcmc, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
InĀ [9]:
Copied!
treatment_term_mcmc_test = np.multiply(
np.atleast_3d(Z_test).swapaxes(1, 2), bcf_model.tau_hat_test
).sum(axis=2)
treatment_term_test = np.multiply(tau_test, Z_test).sum(axis=1)
treatment_term_mcmc_avg = np.squeeze(treatment_term_mcmc_test).mean(
axis=1, keepdims=True
)
mu_df_mcmc = pd.DataFrame(
np.concatenate(
(np.expand_dims(treatment_term_test, 1), treatment_term_mcmc_avg), axis=1
),
columns=["True treatment term", "Average estimated treatment term"],
)
sns.scatterplot(
data=mu_df_mcmc, x="True treatment term", y="Average estimated treatment term"
)
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
treatment_term_mcmc_test = np.multiply(
np.atleast_3d(Z_test).swapaxes(1, 2), bcf_model.tau_hat_test
).sum(axis=2)
treatment_term_test = np.multiply(tau_test, Z_test).sum(axis=1)
treatment_term_mcmc_avg = np.squeeze(treatment_term_mcmc_test).mean(
axis=1, keepdims=True
)
mu_df_mcmc = pd.DataFrame(
np.concatenate(
(np.expand_dims(treatment_term_test, 1), treatment_term_mcmc_avg), axis=1
),
columns=["True treatment term", "Average estimated treatment term"],
)
sns.scatterplot(
data=mu_df_mcmc, x="True treatment term", y="Average estimated treatment term"
)
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
InĀ [10]:
Copied!
forest_preds_mu_mcmc = bcf_model.mu_hat_test
mu_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis=1, keepdims=True)
mu_df_mcmc = pd.DataFrame(
np.concatenate((np.expand_dims(mu_test, 1), mu_avg_mcmc), axis=1),
columns=["True mu", "Average estimated mu"],
)
sns.scatterplot(data=mu_df_mcmc, x="True mu", y="Average estimated mu")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
forest_preds_mu_mcmc = bcf_model.mu_hat_test
mu_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis=1, keepdims=True)
mu_df_mcmc = pd.DataFrame(
np.concatenate((np.expand_dims(mu_test, 1), mu_avg_mcmc), axis=1),
columns=["True mu", "Average estimated mu"],
)
sns.scatterplot(data=mu_df_mcmc, x="True mu", y="Average estimated mu")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
InĀ [11]:
Copied!
sigma_df_mcmc = pd.DataFrame(
np.concatenate(
(
np.expand_dims(
np.arange(bcf_model.num_samples - bcf_model.num_gfr), axis=1
),
np.expand_dims(bcf_model.global_var_samples[bcf_model.num_gfr :], axis=1),
),
axis=1,
),
columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()
sigma_df_mcmc = pd.DataFrame(
np.concatenate(
(
np.expand_dims(
np.arange(bcf_model.num_samples - bcf_model.num_gfr), axis=1
),
np.expand_dims(bcf_model.global_var_samples[bcf_model.num_gfr :], axis=1),
),
axis=1,
),
columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()