Heteroskedastic Supervised Learning
Load necessary libraries
In [1]:
Copied!
from math import sqrt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from stochtree import BARTModel
from math import sqrt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from stochtree import BARTModel
Generate sample data
In [2]:
Copied!
# Reproducibility: a single seeded generator drives every random draw below.
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Problem dimensions: sample size, number of covariates, width of the leaf basis.
n = 1000
p_X = 10
p_W = 1

# Covariates are drawn first, then the basis, both i.i.d. Uniform(0, 1).
# The order of the two draws determines which part of the stream each array gets.
X = rng.uniform(low=0, high=1, size=(n, p_X))
W = rng.uniform(low=0, high=1, size=(n, p_W))
# Define the outcome mean function
def outcome_mean(X, W):
    """Evaluate the true conditional mean of the simulated outcome.

    The mean is ``slope * W[:, 0]``, where the slope is chosen row by row from
    the value ``x`` of the first covariate ``X[:, 0]``:

    * ``0.0 <= x < 0.25`` -> ``-7.5``
    * ``0.25 <= x < 0.5`` -> ``-2.5``
    * ``0.5 <= x < 0.75`` -> ``2.5``
    * any other value     -> ``7.5``

    Parameters
    ----------
    X : array indexed as ``X[:, 0]``; only that column is read.
    W : array indexed as ``W[:, 0]``; only that column is read.

    Returns
    -------
    1-D array holding one mean value per row.
    """
    split_var = X[:, 0]
    basis = W[:, 0]
    # Bands are mutually exclusive; rows matching none of them use the default.
    bands = [
        (split_var >= 0.0) & (split_var < 0.25),
        (split_var >= 0.25) & (split_var < 0.5),
        (split_var >= 0.5) & (split_var < 0.75),
    ]
    slopes = np.select(bands, [-7.5, -2.5, 2.5], default=7.5)
    return slopes * basis
# Define the outcome standard deviation function
def outcome_stddev(X):
    """Evaluate the true conditional standard deviation of the simulated noise.

    The value is chosen row by row from the value ``x`` of the second covariate
    ``X[:, 1]``:

    * ``0.0 <= x < 0.25`` -> ``sqrt(0.5)``
    * ``0.25 <= x < 0.5`` -> ``1.0``
    * ``0.5 <= x < 0.75`` -> ``2.0``
    * any other value     -> ``3.0``

    Parameters
    ----------
    X : array indexed as ``X[:, 1]``; only that column is read.

    Returns
    -------
    1-D array holding one standard deviation per row.
    """
    split_var = X[:, 1]
    # Bands are mutually exclusive; rows matching none of them use the default.
    bands = [
        (split_var >= 0.0) & (split_var < 0.25),
        (split_var >= 0.25) & (split_var < 0.5),
        (split_var >= 0.5) & (split_var < 0.75),
    ]
    return np.select(bands, [sqrt(0.5), 1.0, 2.0], default=3.0)
# Standard-normal noise, one draw per observation.
epsilon = rng.normal(loc=0, scale=1, size=n)

# True mean and true noise scale evaluated at every row.
f_x = outcome_mean(X, W)
s_x = outcome_stddev(X)

# Heteroskedastic outcome: the noise is scaled by the row-specific standard deviation.
y = f_x + s_x * epsilon

# Standardized copy of the outcome: centered by its mean, scaled by its standard deviation.
y_bar = y.mean()
y_std = y.std()
resid = (y - y_bar) / y_std
# Reproducibility: a single seeded generator drives every random draw below.
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Problem dimensions: sample size, number of covariates, width of the leaf basis.
n = 1000
p_X = 10
p_W = 1

# Covariates are drawn first, then the basis, both i.i.d. Uniform(0, 1).
# The order of the two draws determines which part of the stream each array gets.
X = rng.uniform(low=0, high=1, size=(n, p_X))
W = rng.uniform(low=0, high=1, size=(n, p_W))
# Define the outcome mean function
def outcome_mean(X, W):
    """Evaluate the true conditional mean of the simulated outcome.

    The mean is ``slope * W[:, 0]``, where the slope is chosen row by row from
    the value ``x`` of the first covariate ``X[:, 0]``:

    * ``0.0 <= x < 0.25`` -> ``-7.5``
    * ``0.25 <= x < 0.5`` -> ``-2.5``
    * ``0.5 <= x < 0.75`` -> ``2.5``
    * any other value     -> ``7.5``

    Parameters
    ----------
    X : array indexed as ``X[:, 0]``; only that column is read.
    W : array indexed as ``W[:, 0]``; only that column is read.

    Returns
    -------
    1-D array holding one mean value per row.
    """
    split_var = X[:, 0]
    basis = W[:, 0]
    # Bands are mutually exclusive; rows matching none of them use the default.
    bands = [
        (split_var >= 0.0) & (split_var < 0.25),
        (split_var >= 0.25) & (split_var < 0.5),
        (split_var >= 0.5) & (split_var < 0.75),
    ]
    slopes = np.select(bands, [-7.5, -2.5, 2.5], default=7.5)
    return slopes * basis
# Define the outcome standard deviation function
def outcome_stddev(X):
    """Evaluate the true conditional standard deviation of the simulated noise.

    The value is chosen row by row from the value ``x`` of the second covariate
    ``X[:, 1]``:

    * ``0.0 <= x < 0.25`` -> ``sqrt(0.5)``
    * ``0.25 <= x < 0.5`` -> ``1.0``
    * ``0.5 <= x < 0.75`` -> ``2.0``
    * any other value     -> ``3.0``

    Parameters
    ----------
    X : array indexed as ``X[:, 1]``; only that column is read.

    Returns
    -------
    1-D array holding one standard deviation per row.
    """
    split_var = X[:, 1]
    # Bands are mutually exclusive; rows matching none of them use the default.
    bands = [
        (split_var >= 0.0) & (split_var < 0.25),
        (split_var >= 0.25) & (split_var < 0.5),
        (split_var >= 0.5) & (split_var < 0.75),
    ]
    return np.select(bands, [sqrt(0.5), 1.0, 2.0], default=3.0)
# Standard-normal noise, one draw per observation.
epsilon = rng.normal(loc=0, scale=1, size=n)

# True mean and true noise scale evaluated at every row.
f_x = outcome_mean(X, W)
s_x = outcome_stddev(X)

# Heteroskedastic outcome: the noise is scaled by the row-specific standard deviation.
y = f_x + s_x * epsilon

# Standardized copy of the outcome: centered by its mean, scaled by its standard deviation.
y_bar = y.mean()
y_std = y.std()
resid = (y - y_bar) / y_std
Test-train split
In [3]:
Copied!
# Split row indices rather than the arrays themselves, so the covariates, basis,
# outcome, and the true mean / standard deviation all share one partition.
sample_inds = np.arange(n)

# Seed the shuffle: without random_state the partition (and therefore every
# figure and the reported RMSE) changes on each run, even though the simulated
# data are generated from a seeded generator.
train_inds, test_inds = train_test_split(
    sample_inds, test_size=0.5, random_state=random_seed
)

# Covariates and leaf basis
X_train = X[train_inds, :]
X_test = X[test_inds, :]
basis_train = W[train_inds, :]
basis_test = W[test_inds, :]

# Observed outcome
y_train = y[train_inds]
y_test = y[test_inds]

# Ground-truth mean and standard deviation, kept for evaluating the fit
f_x_train = f_x[train_inds]
f_x_test = f_x[test_inds]
s_x_train = s_x[train_inds]
s_x_test = s_x[test_inds]
# Split row indices rather than the arrays themselves, so the covariates, basis,
# outcome, and the true mean / standard deviation all share one partition.
sample_inds = np.arange(n)

# Seed the shuffle: without random_state the partition (and therefore every
# figure and the reported RMSE) changes on each run, even though the simulated
# data are generated from a seeded generator.
train_inds, test_inds = train_test_split(
    sample_inds, test_size=0.5, random_state=random_seed
)

# Covariates and leaf basis
X_train = X[train_inds, :]
X_test = X[test_inds, :]
basis_train = W[train_inds, :]
basis_test = W[test_inds, :]

# Observed outcome
y_train = y[train_inds]
y_test = y[test_inds]

# Ground-truth mean and standard deviation, kept for evaluating the fit
f_x_train = f_x[train_inds]
f_x_test = f_x[test_inds]
s_x_train = s_x[train_inds]
s_x_test = s_x[test_inds]
Run BART
In [4]:
Copied!
bart_model = BARTModel()

# Seed the sampler too: with only the data generator seeded, the posterior
# draws (and every plot / metric derived from them) differ on each run.
global_params = {"sample_sigma2_global": True, "random_seed": random_seed}
# Settings passed for the mean forest and for the variance forest.
mean_params = {"num_trees": 100, "sample_sigma2_leaf": False}
variance_params = {"num_trees": 50}

# Fit on the training half; test covariates and basis are supplied so that
# test-set predictions are available on the model afterwards.
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    leaf_basis_train=basis_train,
    leaf_basis_test=basis_test,
    num_gfr=10,
    num_mcmc=100,
    general_params=global_params,
    mean_forest_params=mean_params,
    variance_forest_params=variance_params,
)
bart_model = BARTModel()

# Seed the sampler too: with only the data generator seeded, the posterior
# draws (and every plot / metric derived from them) differ on each run.
global_params = {"sample_sigma2_global": True, "random_seed": random_seed}
# Settings passed for the mean forest and for the variance forest.
mean_params = {"num_trees": 100, "sample_sigma2_leaf": False}
variance_params = {"num_trees": 50}

# Fit on the training half; test covariates and basis are supplied so that
# test-set predictions are available on the model afterwards.
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    leaf_basis_train=basis_train,
    leaf_basis_test=basis_test,
    num_gfr=10,
    num_mcmc=100,
    general_params=global_params,
    mean_forest_params=mean_params,
    variance_forest_params=variance_params,
)
Inspect the MCMC (BART) samples
In [5]:
Copied!
# Test-set outcome predictions stored on the fitted model.
forest_preds_y_mcmc = bart_model.y_hat_test

# Average over axis 1 (presumably the posterior-draw axis -- TODO confirm),
# keeping a column shape so it can sit next to the true outcome.
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis=1, keepdims=True)

# Two-column frame: true outcome beside the averaged prediction.
y_df_mcmc = pd.DataFrame(
    np.column_stack((y_test, y_avg_mcmc)),
    columns=["True outcome", "Average estimated outcome"],
)

ax = sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
# Dashed 45-degree reference line: points on it are predicted exactly.
ax.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
# Test-set outcome predictions stored on the fitted model.
forest_preds_y_mcmc = bart_model.y_hat_test

# Average over axis 1 (presumably the posterior-draw axis -- TODO confirm),
# keeping a column shape so it can sit next to the true outcome.
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis=1, keepdims=True)

# Two-column frame: true outcome beside the averaged prediction.
y_df_mcmc = pd.DataFrame(
    np.column_stack((y_test, y_avg_mcmc)),
    columns=["True outcome", "Average estimated outcome"],
)

ax = sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
# Dashed 45-degree reference line: points on it are predicted exactly.
ax.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
In [6]:
Copied!
# Square root of the model's sigma2_x_test, so the comparison below is made
# against the true standard deviation rather than a squared quantity.
forest_preds_s_x_mcmc = np.sqrt(bart_model.sigma2_x_test)

# Average over axis 1 (presumably the posterior-draw axis -- TODO confirm),
# keeping a column shape so it can sit next to the true values.
s_x_avg_mcmc = np.squeeze(forest_preds_s_x_mcmc).mean(axis=1, keepdims=True)

# Two-column frame: true standard deviation beside the averaged estimate.
s_x_df_mcmc = pd.DataFrame(
    np.column_stack((s_x_test, s_x_avg_mcmc)),
    columns=["True standard deviation", "Average estimated standard deviation"],
)

ax = sns.scatterplot(
    data=s_x_df_mcmc,
    x="Average estimated standard deviation",
    y="True standard deviation",
)
# Dashed 45-degree reference line: points on it are estimated exactly.
ax.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
# Square root of the model's sigma2_x_test, so the comparison below is made
# against the true standard deviation rather than a squared quantity.
forest_preds_s_x_mcmc = np.sqrt(bart_model.sigma2_x_test)

# Average over axis 1 (presumably the posterior-draw axis -- TODO confirm),
# keeping a column shape so it can sit next to the true values.
s_x_avg_mcmc = np.squeeze(forest_preds_s_x_mcmc).mean(axis=1, keepdims=True)

# Two-column frame: true standard deviation beside the averaged estimate.
s_x_df_mcmc = pd.DataFrame(
    np.column_stack((s_x_test, s_x_avg_mcmc)),
    columns=["True standard deviation", "Average estimated standard deviation"],
)

ax = sns.scatterplot(
    data=s_x_df_mcmc,
    x="Average estimated standard deviation",
    y="True standard deviation",
)
# Dashed 45-degree reference line: points on it are estimated exactly.
ax.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
In [7]:
Copied!
# Trace of the model's global_var_samples against the draw index.
# NOTE(review): the column is labeled "Sigma" while the attribute name suggests
# these are variance draws -- confirm which scale is intended before relabeling.
global_var_draws = bart_model.global_var_samples
draw_index = np.arange(global_var_draws.shape[0])
sigma_df_mcmc = pd.DataFrame(
    np.column_stack((draw_index, global_var_draws)),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()
# Trace of the model's global_var_samples against the draw index.
# NOTE(review): the column is labeled "Sigma" while the attribute name suggests
# these are variance draws -- confirm which scale is intended before relabeling.
global_var_draws = bart_model.global_var_samples
draw_index = np.arange(global_var_draws.shape[0])
sigma_df_mcmc = pd.DataFrame(
    np.column_stack((draw_index, global_var_draws)),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()
Compute the test set RMSE
In [8]:
Copied!
# Root mean squared error between the held-out outcomes and the averaged predictions.
test_residuals = y_test - y_avg_mcmc.squeeze()
np.sqrt(np.mean(np.power(test_residuals, 2)))
# Root mean squared error between the held-out outcomes and the averaged predictions.
test_residuals = y_test - y_avg_mcmc.squeeze()
np.sqrt(np.mean(np.power(test_residuals, 2)))
Out[8]:
np.float64(1.878250353960839)