Extract a vector, matrix, or array of parameter samples from a BART model by name. Random effects are handled by a separate getRandomEffectSamples function because their parameters are more complex. If the requested model term is not found, an error is thrown. The following parameter names are recognized (names listed together refer to the same parameter):

  • Global error variance: "sigma2", "global_error_scale", "sigma2_global"

  • Leaf scale: "sigma2_leaf", "leaf_scale"

  • In-sample mean function predictions: "y_hat_train"

  • Test set mean function predictions: "y_hat_test"

  • In-sample variance forest predictions: "sigma2_x_train", "var_x_train"

  • Test set variance forest predictions: "sigma2_x_test", "var_x_test"

Usage

# S3 method for class 'bartmodel'
extract_parameter(object, term)

Arguments

object

Object of type bartmodel containing draws of a BART model and associated sampling outputs.

term

Name of the parameter to extract (e.g., "sigma2" or "y_hat_train").

Value

Array of parameter samples, with samples indexed along the last dimension. If the underlying parameter is a scalar, this will be a vector of length num_samples. If the underlying parameter is vector-valued, this will be a (parameter_dimension x num_samples) matrix. If the underlying parameter is multidimensional, this will be an array of dimension (parameter_dimension_1 x parameter_dimension_2 x ... x num_samples).
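Because samples always occupy the last dimension, posterior summaries are computed by averaging over that dimension. The sketch below illustrates this with simulated arrays standing in for extract_parameter output; the object names and dimensions are illustrative assumptions, not values produced by the package.

```r
# Simulated stand-ins for extract_parameter() output (illustrative only);
# samples are indexed along the last dimension in each case
set.seed(1)
num_samples <- 10
n_obs <- 4

scalar_draws <- rgamma(num_samples, shape = 2, rate = 2)          # length num_samples
vector_draws <- matrix(rnorm(n_obs * num_samples), nrow = n_obs)  # n_obs x num_samples
array_draws <- array(rnorm(n_obs * 2 * num_samples),
                     dim = c(n_obs, 2, num_samples))              # n_obs x 2 x num_samples

# Posterior means: average over the sample dimension
scalar_mean <- mean(scalar_draws)
vector_mean <- rowMeans(vector_draws)            # one value per observation
array_mean <- apply(array_draws, c(1, 2), mean)  # n_obs x 2 matrix

# 95% interval per observation; rows hold the 2.5% and 97.5% quantiles
vector_ci <- apply(vector_draws, 1, quantile, probs = c(0.025, 0.975))
```

The same pattern (rowMeans for matrices, apply over the leading dimensions for arrays) should carry over to real output as long as it follows the shapes described above.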

Examples

# Generate covariates and a step-function mean in the first covariate
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
f_XW <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
snr <- 3
# Add a random effects term with group-specific coefficients on a two-column basis
group_ids <- rep(c(1,2), n %/% 2)
rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
rfx_basis <- cbind(1, runif(n, -1, 1))
rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
# Simulate the outcome with noise scaled to the signal-to-noise ratio
E_y <- f_XW + rfx_term
y <- E_y + rnorm(n, 0, 1)*(sd(E_y)/snr)
# Split the data into training and test sets
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
rfx_group_ids_test <- group_ids[test_inds]
rfx_group_ids_train <- group_ids[train_inds]
rfx_basis_test <- rfx_basis[test_inds,]
rfx_basis_train <- rfx_basis[train_inds,]
rfx_term_test <- rfx_term[test_inds]
rfx_term_train <- rfx_term[train_inds]
# Fit a BART model with random effects using a short sampler run
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
                   rfx_group_ids_train = rfx_group_ids_train,
                   rfx_group_ids_test = rfx_group_ids_test,
                   rfx_basis_train = rfx_basis_train,
                   rfx_basis_test = rfx_basis_test,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)
# Extract posterior samples of the global error variance
sigma2_samples <- extract_parameter(bart_model, "sigma2")
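The same call can retrieve vector-valued terms. The following continuation is a minimal sketch that assumes the objects bart_model, y_test and n_test from the example above are still in scope, and that "y_hat_test" is returned as an (n_test x num_samples) matrix as described under Value. The term "not_a_parameter" is a made-up name used only to trigger the documented error.

```r
# Assumes `bart_model`, `y_test` and `n_test` from the example above
y_hat_test_samples <- extract_parameter(bart_model, "y_hat_test")

# Average over the sample dimension (columns) to get one prediction per test row
y_hat_test_mean <- rowMeans(y_hat_test_samples)
test_rmse <- sqrt(mean((y_test - y_hat_test_mean)^2))

# An unrecognized term throws an error, which tryCatch can intercept;
# "not_a_parameter" is a hypothetical name chosen so the lookup fails
missing_term <- tryCatch(
    extract_parameter(bart_model, "not_a_parameter"),
    error = function(e) NULL
)
```

Wrapping the call in tryCatch is one way to probe for optional terms, such as the variance forest predictions, without stopping a script when the model was fit without them.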