Heteroskedastic Supervised Learning
Load necessary libraries
In [1]:
from math import sqrt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from stochtree import BARTModel
Generate sample data
In [2]:
# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)
# Generate covariates and basis
n = 1000
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))
# Define the outcome mean function
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]),
        ),
    )

# Define the outcome standard deviation function
def outcome_stddev(X):
    return np.where(
        (X[:, 1] >= 0.0) & (X[:, 1] < 0.25),
        sqrt(0.5),
        np.where(
            (X[:, 1] >= 0.25) & (X[:, 1] < 0.5),
            1.0,
            np.where((X[:, 1] >= 0.5) & (X[:, 1] < 0.75), 2.0, 3.0),
        ),
    )
# Generate outcome
epsilon = rng.normal(0, 1, n)
f_x = outcome_mean(X, W)
s_x = outcome_stddev(X)
y = f_x + epsilon * s_x
# Standardize outcome
y_bar = np.mean(y)
y_std = np.std(y)
resid = (y - y_bar) / y_std
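As a quick sanity check on the data-generating process, the empirical noise scale within each bin of X[:, 1] should be close to the true values sqrt(0.5), 1, 2, and 3. A minimal sketch using only the arrays defined above:

# Empirical noise scale by variance regime: group the residual y - f(x)
# by the four bins of X[:, 1] and compare to the true standard deviations.
bins = np.digitize(X[:, 1], [0.25, 0.5, 0.75])
for b, true_sd in zip(range(4), [sqrt(0.5), 1.0, 2.0, 3.0]):
    emp_sd = np.std(y[bins == b] - f_x[bins == b])
    print(f"bin {b}: true sd = {true_sd:.3f}, empirical sd = {emp_sd:.3f}")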
Test-train split
In [3]:
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
basis_train = W[train_inds, :]
basis_test = W[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]
f_x_train = f_x[train_inds]
f_x_test = f_x[test_inds]
s_x_train = s_x[train_inds]
s_x_test = s_x[test_inds]
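Note that train_test_split uses its own source of randomness unless a seed is passed, so this split is not tied to random_seed above. If a reproducible split is wanted, scikit-learn accepts the seed directly; a minimal variant:

# Optional: fix the split across runs by passing the seed explicitly
train_inds, test_inds = train_test_split(
    sample_inds, test_size=0.5, random_state=random_seed
)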
Run BART
In [4]:
bart_model = BARTModel()
global_params = {"sample_sigma2_global": True}
mean_params = {"num_trees": 100, "sample_sigma2_leaf": False}
variance_params = {"num_trees": 50}
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    leaf_basis_train=basis_train,
    leaf_basis_test=basis_test,
    num_gfr=10,
    num_mcmc=100,
    general_params=global_params,
    mean_forest_params=mean_params,
    variance_forest_params=variance_params,
)
/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/stochtree/bart.py:895: UserWarning: Sampling global error variance not yet supported for models with variance forests, so the global error variance parameter will not be sampled in this model.
  warnings.warn(
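As the warning notes, the variance forest supplies the error variance directly, so the global variance parameter is not sampled even though sample_sigma2_global was set to True. Both forests produce per-observation posterior draws on the test set; a minimal sketch printing their shapes (one row per test observation, one column per retained draw, matching the axis=1 averaging in the cells below):

# Posterior draws from the mean and variance forests on the test set
print(bart_model.y_hat_test.shape)     # outcome mean draws
print(bart_model.sigma2_x_test.shape)  # conditional error-variance draws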
Inspect the MCMC (BART) samples
In [5]:
forest_preds_y_mcmc = bart_model.y_hat_test
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis=1, keepdims=True)
y_df_mcmc = pd.DataFrame(
    np.concatenate((np.expand_dims(y_test, 1), y_avg_mcmc), axis=1),
    columns=["True outcome", "Average estimated outcome"],
)
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
In [6]:
forest_preds_s_x_mcmc = np.sqrt(bart_model.sigma2_x_test)
s_x_avg_mcmc = np.squeeze(forest_preds_s_x_mcmc).mean(axis=1, keepdims=True)
s_x_df_mcmc = pd.DataFrame(
    np.concatenate((np.expand_dims(s_x_test, 1), s_x_avg_mcmc), axis=1),
    columns=["True standard deviation", "Average estimated standard deviation"],
)
sns.scatterplot(
    data=s_x_df_mcmc,
    x="Average estimated standard deviation",
    y="True standard deviation",
)
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
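The same one-number summary used for the outcome below works for the variance surface as well; a quick check of the posterior-average standard deviation against the true s(x) on the test set:

# RMSE of the average estimated standard deviation against the true s(x)
np.sqrt(np.mean(np.power(s_x_test - np.squeeze(s_x_avg_mcmc), 2)))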
In [7]:
# The global error variance is not sampled when a variance forest is used
# (see the warning above), so guard against the missing attribute.
if hasattr(bart_model, "global_var_samples"):
    sigma_df_mcmc = pd.DataFrame(
        {
            "Sample": np.arange(bart_model.global_var_samples.shape[0]),
            "Sigma": bart_model.global_var_samples,
        }
    )
    sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
    plt.show()
else:
    print("Global error variance was not sampled for this model.")
Global error variance was not sampled for this model.
Compute the test set RMSE
In [8]:
np.sqrt(np.mean(np.power(y_test - np.squeeze(y_avg_mcmc), 2)))
Out[8]:
np.float64(1.991940817359479)
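Since the sampler returns draws of both the conditional mean and the conditional error variance, the two can be combined into heteroskedastic posterior predictive intervals. A minimal sketch, assuming normal errors as in the data-generating process above (the 95% level is an illustrative choice):

# Simulate from the posterior predictive: for each test point and each draw,
# add normal noise scaled by that draw's conditional standard deviation.
pred_draws = bart_model.y_hat_test + np.sqrt(bart_model.sigma2_x_test) * rng.normal(
    size=bart_model.sigma2_x_test.shape
)
lo = np.quantile(pred_draws, 0.025, axis=1)
hi = np.quantile(pred_draws, 0.975, axis=1)
coverage = np.mean((y_test >= lo) & (y_test <= hi))
print(f"95% posterior predictive interval coverage: {coverage:.3f}")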