Model Serialization
Demo 1: Supervised Learning
Load necessary libraries
In [1]:
Copied!
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from stochtree import BARTModel
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from stochtree import BARTModel
Generate sample data
In [2]:
Copied!
# RNG: a seeded generator so the simulated data are identical on every run
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates and basis (draw order X, W, epsilon is kept fixed so the
# seeded stream yields the same data)
n = 100
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))


# Define the outcome mean function
def outcome_mean(X, W):
    """Return the noiseless outcome mean for every row of ``X`` / ``W``.

    The mean is a step function of the first covariate ``X[:, 0]`` that scales
    the first basis column ``W[:, 0]``:

    * ``0.00 <= X[:, 0] < 0.25``  ->  ``-7.5 * W[:, 0]``
    * ``0.25 <= X[:, 0] < 0.50``  ->  ``-2.5 * W[:, 0]``
    * ``0.50 <= X[:, 0] < 0.75``  ->  ``2.5 * W[:, 0]``
    * anything else               ->  ``7.5 * W[:, 0]``

    Parameters
    ----------
    X : 2-D array; only column 0 is read.
    W : 2-D array with the same number of rows as ``X``; only column 0 is read.

    Returns
    -------
    1-D array with one mean value per row.
    """
    x0 = X[:, 0]
    w0 = W[:, 0]
    # Build the piecewise function from the innermost fall-through outwards;
    # this is the original nested np.where, just unrolled.
    top_two = np.where((x0 >= 0.5) & (x0 < 0.75), 2.5 * w0, 7.5 * w0)
    top_three = np.where((x0 >= 0.25) & (x0 < 0.5), -2.5 * w0, top_two)
    return np.where((x0 >= 0.0) & (x0 < 0.25), -7.5 * w0, top_three)


# Generate outcome: mean function plus standard-normal noise
epsilon = rng.normal(0, 1, n)
y = outcome_mean(X, W) + epsilon

# Standardize outcome (not referenced again in the visible part of this
# notebook; the model below is fit on the unstandardized y)
y_bar = np.mean(y)
y_std = np.std(y)
resid = (y - y_bar) / y_std
Test-train split
In [3]:
Copied!
# Split the row indices 50/50 into train and test sets. The notebook's seed is
# passed as random_state so the partition is the same on every run.
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(
    sample_inds, test_size=0.5, random_state=random_seed
)

# Slice covariates, basis and outcome with the same index sets
X_train = X[train_inds, :]
X_test = X[test_inds, :]
basis_train = W[train_inds, :]
basis_test = W[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]
Run BART
In [4]:
Copied!
# Fit the BART model on the training split and request test-set predictions
# in the same call (X_test / leaf_basis_test).
# NOTE(review): num_gfr / num_mcmc are presumably the numbers of
# grow-from-root warm-start and MCMC iterations -- confirm against the
# stochtree documentation.
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=10,
    num_mcmc=10,
    # Seed the sampler with the notebook's seed so the draws are reproducible
    general_params={"random_seed": random_seed},
)
Inspect the MCMC (BART) samples
In [5]:
Copied!
# Test-set predictions stored on the fitted model, averaged over axis 1
# (presumably the retained-sample axis -- confirm) into one column.
forest_preds_y_mcmc = bart_model.y_hat_test
y_avg_mcmc = np.mean(np.squeeze(forest_preds_y_mcmc), axis=1, keepdims=True)

# Pair each held-out outcome with its averaged prediction
y_df_mcmc = pd.DataFrame(
    {
        "True outcome": y_test,
        "Average estimated outcome": y_avg_mcmc[:, 0],
    }
)

# Scatter against the dashed 45-degree line; points on the line are
# perfectly predicted.
sns.scatterplot(x="Average estimated outcome", y="True outcome", data=y_df_mcmc)
plt.axline((0, 0), slope=1, linestyle=(0, (3, 3)), color="black")
plt.show()
In [6]:
Copied!
# Trace of the global error variance across retained samples.
# global_var_samples holds variance draws, so the axis is labelled "Sigma^2"
# (the later comparison cells store the same quantity in `sigma2_df`).
sigma_df_mcmc = pd.DataFrame(
    {
        "Sample": np.arange(bart_model.num_samples),
        "Sigma^2": bart_model.global_var_samples,
    }
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
plt.show()
Compute the test set RMSE
In [7]:
Copied!
# Root-mean-squared error of the averaged prediction on the test set; kept as
# the cell's last expression so the notebook displays the value.
np.sqrt(np.mean((y_test - np.squeeze(y_avg_mcmc)) ** 2))
Out[7]:
np.float64(1.3023253795721579)
Serialize the BART model to JSON
In [8]:
Copied!
# Serialize the fitted model to a JSON string held in memory
bart_json_string = bart_model.to_json()
Deserialize BART model from JSON string
In [9]:
Copied!
# Rebuild a model from the in-memory JSON string: start from an empty
# BARTModel and populate it with from_json.
bart_model_deserialized = BARTModel()
bart_model_deserialized.from_json(bart_json_string)
Compare predictions
In [10]:
Copied!
# Predict on the test set with the model rebuilt from the JSON string, then
# average over axis 1 exactly as was done for the original model.
y_hat_deserialized = bart_model_deserialized.predict(X_test, basis_test)
y_avg_mcmc_deserialized = np.mean(
    np.squeeze(y_hat_deserialized), axis=1, keepdims=True
)

# Side-by-side averaged predictions of the two models
y_df = pd.DataFrame(
    {
        "Original model": y_avg_mcmc[:, 0],
        "Deserialized model": y_avg_mcmc_deserialized[:, 0],
    }
)

# A lossless round trip puts every point on the dashed 45-degree line
sns.scatterplot(x="Original model", y="Deserialized model", data=y_df)
plt.axline((0, 0), slope=1, linestyle=(0, (3, 3)), color="black")
plt.show()
Compare parameter samples
In [11]:
Copied!
# Global variance samples of the original model next to those of the model
# rebuilt from the JSON string.
sigma2_df = pd.DataFrame(
    {
        "Original model": bart_model.global_var_samples,
        "Deserialized model": bart_model_deserialized.global_var_samples,
    }
)

# Matching samples fall on the dashed 45-degree line
sns.scatterplot(x="Original model", y="Deserialized model", data=sigma2_df)
plt.axline((0, 0), slope=1, linestyle=(0, (3, 3)), color="black")
plt.show()
Save to JSON file
In [12]:
Copied!
# Parse the JSON string before opening the file: if parsing fails, we stop
# here and "bart.json" is never created or truncated.
bart_json_python = json.loads(bart_json_string)

# Write the parsed object to disk
with open("bart.json", "w") as f:
    json.dump(bart_json_python, f)
Reload from JSON file
In [13]:
Copied!
# Read the saved JSON back into a Python object; only the read needs the
# file handle, so it is the only statement inside the `with` block.
with open("bart.json", "r") as f:
    bart_json_python_reload = json.load(f)

# from_json is given a string, so re-encode the object before rebuilding
bart_json_string_reload = json.dumps(bart_json_python_reload)
bart_model_file_deserialized = BARTModel()
bart_model_file_deserialized.from_json(bart_json_string_reload)
Compare predictions (model reloaded from file)
In [14]:
Copied!
# Predict on the test set with the model reloaded from "bart.json", then
# average over axis 1 exactly as was done for the original model.
y_hat_file_deserialized = bart_model_file_deserialized.predict(X_test, basis_test)
y_avg_mcmc_file_deserialized = np.mean(
    np.squeeze(y_hat_file_deserialized), axis=1, keepdims=True
)

# Side-by-side averaged predictions of the two models
y_df = pd.DataFrame(
    {
        "Original model": y_avg_mcmc[:, 0],
        "Deserialized model": y_avg_mcmc_file_deserialized[:, 0],
    }
)

# A lossless round trip puts every point on the dashed 45-degree line
sns.scatterplot(x="Original model", y="Deserialized model", data=y_df)
plt.axline((0, 0), slope=1, linestyle=(0, (3, 3)), color="black")
plt.show()
Compare parameter samples (model reloaded from file)
In [15]:
Copied!
# Global variance samples of the original model next to those of the model
# reloaded from "bart.json".
sigma2_df = pd.DataFrame(
    {
        "Original model": bart_model.global_var_samples,
        "Deserialized model": bart_model_file_deserialized.global_var_samples,
    }
)

# Matching samples fall on the dashed 45-degree line
sns.scatterplot(x="Original model", y="Deserialized model", data=sigma2_df)
plt.axline((0, 0), slope=1, linestyle=(0, (3, 3)), color="black")
plt.show()
Clean up JSON file
In [16]:
Copied!
# Delete the temporary file written earlier. The existence check keeps this
# cell safe to re-run once the file is already gone (a bare second
# os.remove would raise FileNotFoundError).
if os.path.exists("bart.json"):
    os.remove("bart.json")