Low-Level Interface¶
While the functions bart() and bcf() provide simple and performant interfaces for supervised learning and causal inference, stochtree also offers access to many of the "low-level" data structures that are typically implemented in C++. This low-level interface is not designed for performance or even simplicity --- rather, the intent is to provide a "prototype" interface to the C++ code that doesn't require modifying any C++.
To illustrate when such a prototype interface might be useful, consider that the "classic" BART algorithm is essentially a Metropolis-within-Gibbs sampler, in which the forest is sampled by MCMC, conditional on all of the other model parameters, and then those parameters are updated in Gibbs steps.
While the algorithm itself is conceptually simple, much of the core computation is carried out in low-level languages such as C or C++ because of the tree data structures. As a result, any changes to this algorithm, such as supporting heteroskedasticity and categorical outcomes (Murray 2021) or causal effect estimation (Hahn et al. 2020), require modifying low-level code.
The prototype interface exposes the core components of this sampling loop at the Python level, thus making it possible to interchange C++ computation for steps like "update forest via Metropolis-Hastings" with Python computation for a custom variance model, other user-specified additive mean model components, and so on. A schematic of the loop is sketched below.
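To make the loop's structure concrete, here is a minimal Python-style sketch; the function names are illustrative placeholders, not stochtree APIs (the real calls are demonstrated in the two scenarios below):

# Schematic Metropolis-within-Gibbs loop for a BART-style model.
# sample_forest_mcmc, sample_global_variance, and sample_leaf_variance
# are hypothetical stand-ins for the low-level steps used below.
for i in range(num_samples):
    # (1) Update the forest, conditional on all other parameters (MCMC)
    forest = sample_forest_mcmc(forest, dataset, residual, sigma2, leaf_scale)
    # (2) Update the global error variance, conditional on the forest (Gibbs)
    sigma2 = sample_global_variance(residual, a_global, b_global)
    # (3) Update the leaf scale parameter, conditional on the forest (Gibbs)
    leaf_scale = sample_leaf_variance(forest, a_leaf, b_leaf)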
Scenario 1: Supervised Learning¶
Load necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from stochtree import (
    RNG,
    Dataset,
    Forest,
    ForestContainer,
    ForestSampler,
    GlobalVarianceModel,
    LeafVarianceModel,
    Residual,
    ForestModelConfig,
    GlobalModelConfig,
)
Generate sample data
# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)
# Generate covariates and basis
n = 500
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))
# Define the outcome mean function
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]),
        ),
    )
# Generate outcome
epsilon = rng.normal(0, 1, n)
y = outcome_mean(X, W) + epsilon
# Standardize outcome
y_bar = np.mean(y)
y_std = np.std(y)
resid = (y - y_bar) / y_std
Set some sampling parameters
alpha = 0.9
beta = 1.25
min_samples_leaf = 1
max_depth = -1
num_trees = 100
cutpoint_grid_size = 100
global_variance_init = 1.0
tau_init = 0.5
leaf_prior_scale = np.array([[tau_init]], order="C")
a_global = 4.0
b_global = 2.0
a_leaf = 2.0
b_leaf = 0.5
leaf_regression = True
feature_types = np.repeat(0, p_X).astype(int) # 0 = numeric
var_weights = np.repeat(1 / p_X, p_X)
if not leaf_regression:
    # Constant leaf model
    leaf_model = 0
    leaf_dimension = 1
elif leaf_regression and p_W == 1:
    # Univariate leaf regression on the basis W
    leaf_model = 1
    leaf_dimension = 1
else:
    # Multivariate leaf regression on the basis W
    leaf_model = 2
    leaf_dimension = p_W
Convert data from numpy to StochTree representation
# Dataset (covariates and basis)
dataset = Dataset()
dataset.add_covariates(X)
dataset.add_basis(W)
# Residual
residual = Residual(resid)
Initialize tracking and sampling classes
forest_container = ForestContainer(num_trees, W.shape[1], False, False)
active_forest = Forest(num_trees, W.shape[1], False, False)
global_model_config = GlobalModelConfig(global_error_variance=global_variance_init)
forest_model_config = ForestModelConfig(
    num_trees=num_trees,
    num_features=p_X,
    num_observations=n,
    feature_types=feature_types,
    variable_weights=var_weights,
    leaf_dimension=leaf_dimension,
    alpha=alpha,
    beta=beta,
    min_samples_leaf=min_samples_leaf,
    max_depth=max_depth,
    leaf_model_type=leaf_model,
    leaf_model_scale=leaf_prior_scale,
    cutpoint_grid_size=cutpoint_grid_size,
)
forest_sampler = ForestSampler(dataset, global_model_config, forest_model_config)
cpp_rng = RNG(random_seed)
global_var_model = GlobalVarianceModel()
leaf_var_model = LeafVarianceModel()
# Initialize the leaves of each tree in the mean forest
if leaf_regression:
    forest_init_val = np.repeat(0.0, W.shape[1])
else:
    forest_init_val = np.array([0.0])
forest_sampler.prepare_for_sampler(
    dataset,
    residual,
    active_forest,
    leaf_model,
    forest_init_val,
)
Prepare to run the sampler
num_warmstart = 10
num_mcmc = 100
num_samples = num_warmstart + num_mcmc
global_var_samples = np.concatenate(
    (np.array([global_variance_init]), np.repeat(0, num_samples))
)
leaf_scale_samples = np.concatenate((np.array([tau_init]), np.repeat(0, num_samples)))
Run the "grow-from-root" (XBART) sampler
for i in range(num_warmstart):
    forest_sampler.sample_one_iteration(
        forest_container,
        active_forest,
        dataset,
        residual,
        cpp_rng,
        global_model_config,
        forest_model_config,
        True,  # retain this draw in forest_container
        True,  # grow-from-root (XBART) update
    )
    global_var_samples[i + 1] = global_var_model.sample_one_iteration(
        residual, cpp_rng, a_global, b_global
    )
    leaf_scale_samples[i + 1] = leaf_var_model.sample_one_iteration(
        active_forest, cpp_rng, a_leaf, b_leaf
    )
    leaf_prior_scale[0, 0] = leaf_scale_samples[i + 1]
Run the MCMC (BART) sampler, initialized at the last XBART sample
for i in range(num_warmstart, num_samples):
    forest_sampler.sample_one_iteration(
        forest_container,
        active_forest,
        dataset,
        residual,
        cpp_rng,
        global_model_config,
        forest_model_config,
        True,  # retain this draw in forest_container
        False,  # MCMC update rather than grow-from-root
    )
    global_var_samples[i + 1] = global_var_model.sample_one_iteration(
        residual, cpp_rng, a_global, b_global
    )
    leaf_scale_samples[i + 1] = leaf_var_model.sample_one_iteration(
        active_forest, cpp_rng, a_leaf, b_leaf
    )
    leaf_prior_scale[0, 0] = leaf_scale_samples[i + 1]
Extract mean function and error variance posterior samples
# Forest predictions
forest_preds = forest_container.predict(dataset) * y_std + y_bar
forest_preds_gfr = forest_preds[:, :num_warmstart]
forest_preds_mcmc = forest_preds[:, num_warmstart:num_samples]
# Global error variance
sigma_samples = np.sqrt(global_var_samples) * y_std
sigma_samples_gfr = sigma_samples[:num_warmstart]
sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]
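Each column of forest_preds_gfr and forest_preds_mcmc is a single posterior draw of the mean function, so pointwise posterior summaries follow from ordinary numpy reductions. As a small illustrative sketch (not part of the workflow below), pointwise 95% credible intervals can be computed as:

# Pointwise posterior mean and 95% credible interval of the mean function
forest_pred_mean_mcmc = forest_preds_mcmc.mean(axis=1)
forest_pred_lb_mcmc = np.quantile(forest_preds_mcmc, 0.025, axis=1)
forest_pred_ub_mcmc = np.quantile(forest_preds_mcmc, 0.975, axis=1)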
Inspect the GFR (XBART) samples
forest_pred_avg_gfr = forest_preds_gfr.mean(axis=1, keepdims=True)
forest_pred_df_gfr = pd.DataFrame(
    np.concatenate((np.expand_dims(y, axis=1), forest_pred_avg_gfr), axis=1),
    columns=["True y", "Average predicted y"],
)
sns.scatterplot(data=forest_pred_df_gfr, x="True y", y="Average predicted y")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
sigma_df_gfr = pd.DataFrame(
    np.concatenate(
        (
            np.expand_dims(np.arange(num_warmstart), axis=1),
            np.expand_dims(sigma_samples_gfr, axis=1),
        ),
        axis=1,
    ),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_gfr, x="Sample", y="Sigma")
plt.show()
Inspect the MCMC (BART) samples
forest_pred_avg_mcmc = forest_preds_mcmc.mean(axis=1, keepdims=True)
forest_pred_df_mcmc = pd.DataFrame(
    np.concatenate((np.expand_dims(y, axis=1), forest_pred_avg_mcmc), axis=1),
    columns=["True y", "Average predicted y"],
)
sns.scatterplot(data=forest_pred_df_mcmc, x="True y", y="Average predicted y")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
sigma_df_mcmc = pd.DataFrame(
    np.concatenate(
        (
            np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),
            np.expand_dims(sigma_samples_mcmc, axis=1),
        ),
        axis=1,
    ),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()
Scenario 2: Causal Inference¶
Generate sample data
# RNG
random_seed = 101
rng = np.random.default_rng(random_seed)
# Generate covariates and basis
n = 500
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.35 + 0.3 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
# Define the outcome mean functions (prognostic and treatment effects)
mu_X = (pi_X - 0.5) * 30
# tau_X = np.sin(X[:,1]*2*np.pi)
tau_X = X[:, 1] * 2
# Generate outcome
epsilon = rng.normal(0, 1, n)
y = mu_X + tau_X * Z + epsilon
# Standardize outcome
y_bar = np.mean(y)
y_std = np.std(y)
resid = (y - y_bar) / y_std
Set some sampling parameters
# Prognostic forest parameters
alpha_mu = 0.95
beta_mu = 2.0
min_samples_leaf_mu = 1
max_depth_mu = -1
num_trees_mu = 200
cutpoint_grid_size_mu = 100
tau_init_mu = 1 / num_trees_mu
leaf_prior_scale_mu = np.array([[tau_init_mu]], order="C")
a_leaf_mu = 3.0
b_leaf_mu = 1 / num_trees_mu
leaf_regression_mu = False
feature_types_mu = np.repeat(0, p_X + 1).astype(int) # 0 = numeric
var_weights_mu = np.repeat(1 / (p_X + 1), p_X + 1)
leaf_model_mu = 0
leaf_dimension_mu = 1
# Treatment forest parameters
alpha_tau = 0.75
beta_tau = 3.0
min_samples_leaf_tau = 1
max_depth_tau = -1
num_trees_tau = 100
cutpoint_grid_size_tau = 100
tau_init_tau = 1 / num_trees_tau
leaf_prior_scale_tau = np.array([[tau_init_tau]], order="C")
a_leaf_tau = 3.0
b_leaf_tau = 1 / num_trees_tau
leaf_regression_tau = True
feature_types_tau = np.repeat(0, p_X).astype(int) # 0 = numeric
var_weights_tau = np.repeat(1 / p_X, p_X)
leaf_model_tau = 1
leaf_dimension_tau = 1
# Global parameters
a_global = 2.0
b_global = 1.0
global_variance_init = 1.0
Convert data from numpy to StochTree representation
# Prognostic Forest Dataset (covariates)
dataset_mu = Dataset()
dataset_mu.add_covariates(np.c_[X, pi_X])
# Treatment Forest Dataset (covariates and treatment variable)
dataset_tau = Dataset()
dataset_tau.add_covariates(X)
dataset_tau.add_basis(Z)
# Residual
residual = Residual(resid)
Initialize tracking and sampling classes
# Global classes
global_model_config = GlobalModelConfig(global_error_variance=global_variance_init)
cpp_rng = RNG(random_seed)
global_var_model = GlobalVarianceModel()
# Prognostic forest sampling classes
forest_container_mu = ForestContainer(num_trees_mu, 1, True, False)
active_forest_mu = Forest(num_trees_mu, 1, True, False)
forest_model_config_mu = ForestModelConfig(
    num_trees=num_trees_mu,
    num_features=p_X + 1,
    num_observations=n,
    feature_types=feature_types_mu,
    variable_weights=var_weights_mu,
    leaf_dimension=leaf_dimension_mu,
    alpha=alpha_mu,
    beta=beta_mu,
    min_samples_leaf=min_samples_leaf_mu,
    max_depth=max_depth_mu,
    leaf_model_type=leaf_model_mu,
    leaf_model_scale=leaf_prior_scale_mu,
    cutpoint_grid_size=cutpoint_grid_size_mu,
)
forest_sampler_mu = ForestSampler(dataset_mu, global_model_config, forest_model_config_mu)
leaf_var_model_mu = LeafVarianceModel()
# Treatment forest sampling classes
forest_container_tau = ForestContainer(
    num_trees_tau, 1 if np.ndim(Z) == 1 else Z.shape[1], False, False
)
active_forest_tau = Forest(
    num_trees_tau, 1 if np.ndim(Z) == 1 else Z.shape[1], False, False
)
forest_model_config_tau = ForestModelConfig(
    num_trees=num_trees_tau,
    num_features=p_X,
    num_observations=n,
    feature_types=feature_types_tau,
    variable_weights=var_weights_tau,
    leaf_dimension=leaf_dimension_tau,
    alpha=alpha_tau,
    beta=beta_tau,
    min_samples_leaf=min_samples_leaf_tau,
    max_depth=max_depth_tau,
    leaf_model_type=leaf_model_tau,
    leaf_model_scale=leaf_prior_scale_tau,
    cutpoint_grid_size=cutpoint_grid_size_tau,
)
forest_sampler_tau = ForestSampler(dataset_tau, global_model_config, forest_model_config_tau)
leaf_var_model_tau = LeafVarianceModel()
Initialize the leaves of the prognostic and treatment forests
init_mu = np.array([np.squeeze(np.mean(resid))])
forest_sampler_mu.prepare_for_sampler(
    dataset_mu,
    residual,
    active_forest_mu,
    leaf_model_mu,
    init_mu,
)
init_tau = np.array([0.0])
forest_sampler_tau.prepare_for_sampler(
    dataset_tau,
    residual,
    active_forest_tau,
    leaf_model_tau,
    init_tau,
)
Prepare to run the sampler
num_warmstart = 10
num_mcmc = 100
num_samples = num_warmstart + num_mcmc
global_var_samples = np.empty(num_samples)
leaf_scale_samples_mu = np.empty(num_samples)
leaf_scale_samples_tau = np.empty(num_samples)
leaf_prior_scale_mu = np.array([[tau_init_mu]])
leaf_prior_scale_tau = np.array([[tau_init_tau]])
current_b0 = -0.5
current_b1 = 0.5
b_0_samples = np.empty(num_samples)
b_1_samples = np.empty(num_samples)
tau_basis = (1 - Z) * current_b0 + Z * current_b1
dataset_tau.update_basis(tau_basis)
Run the "grow-from-root" (XBART) sampler
for i in range(num_warmstart):
    # Sample the prognostic forest
    forest_sampler_mu.sample_one_iteration(
        forest_container_mu,
        active_forest_mu,
        dataset_mu,
        residual,
        cpp_rng,
        global_model_config,
        forest_model_config_mu,
        True,
        True,
    )
    # Sample global variance
    current_sigma2 = global_var_model.sample_one_iteration(
        residual, cpp_rng, a_global, b_global
    )
    global_model_config.update_global_error_variance(current_sigma2)
    # Sample prognostic forest leaf scale
    leaf_prior_scale_mu[0, 0] = leaf_var_model_mu.sample_one_iteration(
        active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu
    )
    leaf_scale_samples_mu[i] = leaf_prior_scale_mu[0, 0]
    forest_model_config_mu.update_leaf_model_scale(leaf_prior_scale_mu)
    # Sample the treatment effect forest
    forest_sampler_tau.sample_one_iteration(
        forest_container_tau,
        active_forest_tau,
        dataset_tau,
        residual,
        cpp_rng,
        global_model_config,
        forest_model_config_tau,
        True,
        True,
    )
    # Sample adaptive coding parameters
    mu_x = active_forest_mu.predict_raw(dataset_mu)
    tau_x = np.squeeze(active_forest_tau.predict_raw(dataset_tau))
    s_tt0 = np.sum(tau_x * tau_x * (Z == 0))
    s_tt1 = np.sum(tau_x * tau_x * (Z == 1))
    partial_resid_mu = resid - np.squeeze(mu_x)
    s_ty0 = np.sum(tau_x * partial_resid_mu * (Z == 0))
    s_ty1 = np.sum(tau_x * partial_resid_mu * (Z == 1))
    current_b0 = rng.normal(
        loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),
        scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)),
        size=1,
    )[0]
    current_b1 = rng.normal(
        loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)),
        scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)),
        size=1,
    )[0]
    tau_basis = (1 - Z) * current_b0 + Z * current_b1
    dataset_tau.update_basis(tau_basis)
    forest_sampler_tau.propagate_basis_update(dataset_tau, residual, active_forest_tau)
    b_0_samples[i] = current_b0
    b_1_samples[i] = current_b1
    # Sample global variance
    current_sigma2 = global_var_model.sample_one_iteration(
        residual, cpp_rng, a_global, b_global
    )
    global_model_config.update_global_error_variance(current_sigma2)
    global_var_samples[i] = current_sigma2
    # Sample treatment forest leaf scale
    leaf_prior_scale_tau[0, 0] = leaf_var_model_tau.sample_one_iteration(
        active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau
    )
    leaf_scale_samples_tau[i] = leaf_prior_scale_tau[0, 0]
    forest_model_config_tau.update_leaf_model_scale(leaf_prior_scale_tau)
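A note on the "adaptive coding" step in the loop above: with the prognostic fit subtracted out, each observation's partial residual is modeled as $b_{Z_i}\,\tau(x_i)$ plus $N(0, \sigma^2)$ noise, and the `2 * current_sigma2` terms in the code are consistent with independent $N(0, 1/2)$ priors on the coding parameters (on the standardized outcome scale). The draws therefore come from the conditional posterior

$$ b_z \mid \cdot \sim \mathcal{N}\!\left(\frac{s_{ty,z}}{s_{tt,z} + 2\sigma^2},\; \frac{\sigma^2}{s_{tt,z} + 2\sigma^2}\right), \qquad z \in \{0, 1\}, $$

where $s_{tt,z} = \sum_{i: Z_i = z} \tau(x_i)^2$ and $s_{ty,z} = \sum_{i: Z_i = z} \tau(x_i)\,(r_i - \mu(x_i))$, matching s_tt0/s_tt1 and s_ty0/s_ty1 in the code.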
Run the MCMC (BART) sampler, initialized at the last XBART sample
for i in range(num_warmstart, num_samples):
    # Sample the prognostic forest
    forest_sampler_mu.sample_one_iteration(
        forest_container_mu,
        active_forest_mu,
        dataset_mu,
        residual,
        cpp_rng,
        global_model_config,
        forest_model_config_mu,
        True,
        False,
    )
    # Sample global variance
    current_sigma2 = global_var_model.sample_one_iteration(
        residual, cpp_rng, a_global, b_global
    )
    global_model_config.update_global_error_variance(current_sigma2)
    # Sample prognostic forest leaf scale
    leaf_prior_scale_mu[0, 0] = leaf_var_model_mu.sample_one_iteration(
        active_forest_mu, cpp_rng, a_leaf_mu, b_leaf_mu
    )
    leaf_scale_samples_mu[i] = leaf_prior_scale_mu[0, 0]
    forest_model_config_mu.update_leaf_model_scale(leaf_prior_scale_mu)
    # Sample the treatment effect forest
    forest_sampler_tau.sample_one_iteration(
        forest_container_tau,
        active_forest_tau,
        dataset_tau,
        residual,
        cpp_rng,
        global_model_config,
        forest_model_config_tau,
        True,
        False,
    )
    # Sample adaptive coding parameters
    mu_x = active_forest_mu.predict_raw(dataset_mu)
    tau_x = np.squeeze(active_forest_tau.predict_raw(dataset_tau))
    s_tt0 = np.sum(tau_x * tau_x * (Z == 0))
    s_tt1 = np.sum(tau_x * tau_x * (Z == 1))
    partial_resid_mu = resid - np.squeeze(mu_x)
    s_ty0 = np.sum(tau_x * partial_resid_mu * (Z == 0))
    s_ty1 = np.sum(tau_x * partial_resid_mu * (Z == 1))
    current_b0 = rng.normal(
        loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),
        scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)),
        size=1,
    )[0]
    current_b1 = rng.normal(
        loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)),
        scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)),
        size=1,
    )[0]
    tau_basis = (1 - Z) * current_b0 + Z * current_b1
    dataset_tau.update_basis(tau_basis)
    forest_sampler_tau.propagate_basis_update(dataset_tau, residual, active_forest_tau)
    b_0_samples[i] = current_b0
    b_1_samples[i] = current_b1
    # Sample global variance
    current_sigma2 = global_var_model.sample_one_iteration(
        residual, cpp_rng, a_global, b_global
    )
    global_model_config.update_global_error_variance(current_sigma2)
    global_var_samples[i] = current_sigma2
    # Sample treatment forest leaf scale
    leaf_prior_scale_tau[0, 0] = leaf_var_model_tau.sample_one_iteration(
        active_forest_tau, cpp_rng, a_leaf_tau, b_leaf_tau
    )
    leaf_scale_samples_tau[i] = leaf_prior_scale_tau[0, 0]
    forest_model_config_tau.update_leaf_model_scale(leaf_prior_scale_tau)
Extract mean function and error variance posterior samples
# Forest predictions
forest_preds_mu = forest_container_mu.predict(dataset_mu) * y_std + y_bar
forest_preds_mu_gfr = forest_preds_mu[:, :num_warmstart]
forest_preds_mu_mcmc = forest_preds_mu[:, num_warmstart:num_samples]
treatment_coding_samples = b_1_samples - b_0_samples
forest_preds_tau = (
    forest_container_tau.predict_raw(dataset_tau)
    * y_std
    * np.expand_dims(treatment_coding_samples, axis=(0, 2))
)
forest_preds_tau_gfr = forest_preds_tau[:, :num_warmstart]
forest_preds_tau_mcmc = forest_preds_tau[:, num_warmstart:num_samples]
# Global error variance
sigma_samples = np.sqrt(global_var_samples) * y_std
sigma_samples_gfr = sigma_samples[:num_warmstart]
sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]
# Adaptive coding parameters
b_1_samples_gfr = b_1_samples[:num_warmstart] * y_std
b_0_samples_gfr = b_0_samples[:num_warmstart] * y_std
b_1_samples_mcmc = b_1_samples[num_warmstart:] * y_std
b_0_samples_mcmc = b_0_samples[num_warmstart:] * y_std
Inspect the GFR (XBART) samples
forest_preds_tau_avg_gfr = np.squeeze(forest_preds_tau_gfr).mean(axis=1, keepdims=True)
forest_pred_tau_df_gfr = pd.DataFrame(
    np.concatenate((np.expand_dims(tau_X, 1), forest_preds_tau_avg_gfr), axis=1),
    columns=["True tau", "Average estimated tau"],
)
sns.scatterplot(data=forest_pred_tau_df_gfr, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
forest_pred_avg_gfr = np.squeeze(forest_preds_mu_gfr).mean(axis=1, keepdims=True)
forest_pred_df_gfr = pd.DataFrame(
    np.concatenate((np.expand_dims(mu_X, 1), forest_pred_avg_gfr), axis=1),
    columns=["True mu", "Average estimated mu"],
)
sns.scatterplot(data=forest_pred_df_gfr, x="True mu", y="Average estimated mu")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
sigma_df_gfr = pd.DataFrame(
    np.concatenate(
        (
            np.expand_dims(np.arange(num_warmstart), axis=1),
            np.expand_dims(sigma_samples_gfr, axis=1),
        ),
        axis=1,
    ),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_gfr, x="Sample", y="Sigma")
plt.show()
b_df_gfr = pd.DataFrame(
    np.concatenate(
        (
            np.expand_dims(np.arange(num_warmstart), axis=1),
            np.expand_dims(b_0_samples_gfr, axis=1),
            np.expand_dims(b_1_samples_gfr, axis=1),
        ),
        axis=1,
    ),
    columns=["Sample", "Beta_0", "Beta_1"],
)
sns.scatterplot(data=b_df_gfr, x="Sample", y="Beta_0")
sns.scatterplot(data=b_df_gfr, x="Sample", y="Beta_1")
plt.show()
Inspect the MCMC (BART) samples
forest_pred_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis=1, keepdims=True)
forest_pred_df_mcmc = pd.DataFrame(
    np.concatenate((np.expand_dims(tau_X, 1), forest_pred_avg_mcmc), axis=1),
    columns=["True tau", "Average estimated tau"],
)
sns.scatterplot(data=forest_pred_df_mcmc, x="True tau", y="Average estimated tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
forest_pred_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis=1, keepdims=True)
forest_pred_df_mcmc = pd.DataFrame(
    np.concatenate((np.expand_dims(mu_X, 1), forest_pred_avg_mcmc), axis=1),
    columns=["True mu", "Average estimated mu"],
)
sns.scatterplot(data=forest_pred_df_mcmc, x="True mu", y="Average estimated mu")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()
sigma_df_mcmc = pd.DataFrame(
    np.concatenate(
        (
            np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),
            np.expand_dims(sigma_samples_mcmc, axis=1),
        ),
        axis=1,
    ),
    columns=["Sample", "Sigma"],
)
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()
b_df_mcmc = pd.DataFrame(
    np.concatenate(
        (
            np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),
            np.expand_dims(b_0_samples_mcmc, axis=1),
            np.expand_dims(b_1_samples_mcmc, axis=1),
        ),
        axis=1,
    ),
    columns=["Sample", "Beta_0", "Beta_1"],
)
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
plt.show()
References¶
Murray, Jared S. "Log-linear Bayesian additive regression trees for multinomial logistic and count regression models." Journal of the American Statistical Association 116, no. 534 (2021): 756-769.
Hahn, P. Richard, Jared S. Murray, and Carlos M. Carvalho. "Bayesian regression tree models for causal inference: Regularization, confounding, and heterogeneous effects (with discussion)." Bayesian Analysis 15, no. 3 (2020): 965-1056.