Extract a vector, matrix or array of parameter samples from a BART model by name.
Random effects are handled by a separate getRandomEffectSamples function due to the complexity of the random effects parameters.
If the requested model term is not found, an error is thrown.
The following conventions are used for parameter names:
Global error variance: "sigma2", "global_error_scale", "sigma2_global"
Leaf scale: "sigma2_leaf", "leaf_scale"
In-sample mean function predictions: "y_hat_train"
Test set mean function predictions: "y_hat_test"
In-sample variance forest predictions: "sigma2_x_train", "var_x_train"
Test set variance forest predictions: "sigma2_x_test", "var_x_test"
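Each parameter accepts several aliases, and any name outside these lists is an error. The sketch below illustrates that lookup behaviour in base R. It is a hypothetical illustration, not the package's internal code: `term_aliases`, `resolve_term()` and the canonical keys are names invented here.

```r
# Hypothetical alias table: canonical key -> accepted term names.
# These keys and the helper below are illustrative, not part of the package.
term_aliases <- list(
  sigma2_global  = c("sigma2", "global_error_scale", "sigma2_global"),
  sigma2_leaf    = c("sigma2_leaf", "leaf_scale"),
  y_hat_train    = c("y_hat_train"),
  y_hat_test     = c("y_hat_test"),
  sigma2_x_train = c("sigma2_x_train", "var_x_train"),
  sigma2_x_test  = c("sigma2_x_test", "var_x_test")
)

# Return the canonical key for an alias, or stop() if the term is unknown
resolve_term <- function(term) {
  for (canonical in names(term_aliases)) {
    if (term %in% term_aliases[[canonical]]) {
      return(canonical)
    }
  }
  stop(sprintf("Model term '%s' not found", term))
}

resolve_term("leaf_scale")  # "sigma2_leaf"
resolve_term("var_x_test")  # "sigma2_x_test"
```

Mapping every alias to one key keeps the error path in a single place, which mirrors the documented rule that any unrecognised term raises an error.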
Usage
# S3 method for class 'bartmodel'
extract_parameter(object, term)
Value
Array of parameter samples. If the underlying parameter is a scalar, this is a vector of length num_samples.
If the underlying parameter is vector-valued, this is a (parameter_dimension x num_samples) matrix, and if the underlying
parameter is multidimensional, this is an array of dimension (parameter_dimension_1 x parameter_dimension_2 x ... x num_samples).
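Because samples always occupy the last dimension, posterior summaries reduce over that dimension. The sketch below shows this on simulated stand-ins (random numbers, not model output); `summarise_draws()` is a hypothetical helper written for illustration.

```r
# Simulated stand-ins for extracted samples (random numbers, not model output)
set.seed(42)
num_samples <- 50
n_train <- 8
sigma2_draws <- rgamma(num_samples, shape = 2, rate = 2)        # scalar -> vector
y_hat_draws <- matrix(rnorm(n_train * num_samples),
                      nrow = n_train, ncol = num_samples)       # vector -> matrix
coef_draws <- array(rnorm(3 * 2 * num_samples),
                    dim = c(3, 2, num_samples))                 # multidimensional -> array

# Scalar parameter: summarise the whole vector
sigma2_mean <- mean(sigma2_draws)
# Vector parameter: one summary per row (i.e. per observation)
y_hat_mean <- rowMeans(y_hat_draws)
y_hat_ci <- apply(y_hat_draws, 1, quantile, probs = c(0.025, 0.975))

# Hypothetical helper: apply `fun` over every margin except the last (samples)
summarise_draws <- function(draws, fun = mean) {
  if (is.null(dim(draws))) {
    return(fun(draws))
  }
  d <- length(dim(draws))
  apply(draws, seq_len(d - 1), fun)
}

dim(summarise_draws(coef_draws))  # 3 2
```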
Examples
library(stochtree)

# Simulate a step-function mean in the first covariate
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
f_XW <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
snr <- 3

# Two groups, each with its own random intercept and slope
group_ids <- rep(c(1,2), n %/% 2)
rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
rfx_basis <- cbind(1, runif(n, -1, 1))
rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)

# Outcome = mean function + random effects + noise scaled to the target SNR
E_y <- f_XW + rfx_term
y <- E_y + rnorm(n, 0, 1)*(sd(E_y)/snr)

# Hold out 20% of the observations as a test set
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
rfx_group_ids_test <- group_ids[test_inds]
rfx_group_ids_train <- group_ids[train_inds]
rfx_basis_test <- rfx_basis[test_inds,]
rfx_basis_train <- rfx_basis[train_inds,]
rfx_term_test <- rfx_term[test_inds]
rfx_term_train <- rfx_term[train_inds]

# Fit a BART model with random effects (short run for illustration)
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
                   rfx_group_ids_train = rfx_group_ids_train,
                   rfx_group_ids_test = rfx_group_ids_test,
                   rfx_basis_train = rfx_basis_train,
                   rfx_basis_test = rfx_basis_test,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)

# Global error variance samples: a vector with one entry per retained sample
sigma2_samples <- extract_parameter(bart_model, "sigma2")
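The same accessor returns test-set predictions as a matrix with one row per test observation and one column per retained sample, so posterior summaries are row-wise operations. The sketch below refits a small model without random effects so it runs on its own; the simulated data and sampler settings are arbitrary choices made for illustration, and it assumes the stochtree package is installed.

```r
library(stochtree)

# Arbitrary simulated data, for illustration only
set.seed(1)
n <- 100
X <- matrix(runif(n * 5), ncol = 5)
y <- 5 * X[, 1] + rnorm(n)
test_inds <- 1:20

bart_model <- bart(X_train = X[-test_inds, ], y_train = y[-test_inds],
                   X_test = X[test_inds, ],
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)

# Test-set predictions: an (n_test x num_samples) matrix
y_hat_test_samples <- extract_parameter(bart_model, "y_hat_test")

# Posterior mean and 95% interval for each test observation
y_hat_test_mean <- rowMeans(y_hat_test_samples)
y_hat_test_ci <- apply(y_hat_test_samples, 1, quantile, probs = c(0.025, 0.975))

# An unrecognised term name raises an error, which can be caught
err <- tryCatch(extract_parameter(bart_model, "not_a_term"),
                error = function(e) e)
inherits(err, "error")
```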