
Compute posterior credible intervals for specified terms from a fitted BART model. Supports intervals for mean functions, variance functions, random effects, and overall outcome predictions.
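The sketch below illustrates what an equal-tailed credible interval is. It summarizes a matrix of posterior draws (rows = observations, columns = MCMC draws) by its `(1 - level)/2` and `(1 + level)/2` quantiles. The draws are simulated stand-ins, and the helper `equal_tailed_interval()` is hypothetical: it is not part of the package, and the function documented here may compute its intervals differently.

```r
# Hypothetical helper (not a package function): equal-tailed credible
# interval for each row of a matrix of posterior draws
# (rows = observations, columns = MCMC draws).
equal_tailed_interval <- function(draws, level = 0.95) {
  stopifnot(is.matrix(draws), level > 0, level < 1)
  probs <- c((1 - level) / 2, (1 + level) / 2)
  # apply() over rows returns a 2 x nrow(draws) matrix, one column per observation
  bounds <- apply(draws, 1, quantile, probs = probs, names = FALSE)
  list(lower = bounds[1, ], upper = bounds[2, ])
}

# Simulated stand-in for posterior draws: 10 observations, 500 draws each,
# where observation i has draws centered at i
set.seed(1)
draws <- matrix(rnorm(10 * 500, mean = 1:10), nrow = 10, ncol = 500)
ci <- equal_tailed_interval(draws, level = 0.90)
str(ci)
```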

Usage

computeBARTPosteriorInterval(
  model_object,
  terms,
  level = 0.95,
  scale = "linear",
  X = NULL,
  leaf_basis = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL
)

Arguments

model_object

A fitted BART model object of class bartmodel.

terms

A character vector specifying the model term(s) for which to compute intervals. Options for BART models are "mean_forest", "variance_forest", "rfx", and "y_hat".

level

A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).

scale

(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability". scale = "probability" is only valid for models fit with a probit or cloglog link on binary or ordinal outcomes. For binary outcome models, it returns an interval over the probability that y == 1. For ordinal outcome models, it returns intervals over the "survival function" P(y > k) for k = 1, 2, ..., K-1, where K is the total number of categories. Default: "linear".

X

A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).

leaf_basis

An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.

rfx_group_ids

An optional vector of group IDs for random effects. Required if the requested term includes random effects.

rfx_basis

An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
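For a binary probit model, the link maps a latent mean value f to P(y = 1) = Φ(f), the standard normal CDF. The sketch below shows one plausible way to get a probability-scale interval: transform every posterior draw with `pnorm()` and then take quantiles. The draws are simulated, and whether computeBARTPosteriorInterval() works exactly this way internally is an assumption. Cloglog and ordinal models use different transformations that are not shown here.

```r
# Simulated latent (linear-scale) posterior draws for 3 observations,
# centered at -1, 0 and 1 respectively
set.seed(2)
latent_draws <- matrix(rnorm(3 * 1000, mean = c(-1, 0, 1), sd = 0.3),
                       nrow = 3, ncol = 1000)

# Probit link: P(y == 1) = pnorm(latent value), applied draw by draw
# (pnorm keeps the matrix dimensions)
prob_draws <- pnorm(latent_draws)

# 95% equal-tailed interval on the probability scale, one row per observation
level <- 0.95
probs <- c((1 - level) / 2, (1 + level) / 2)
prob_interval <- t(apply(prob_draws, 1, quantile, probs = probs, names = FALSE))
colnames(prob_interval) <- c("lower", "upper")
prob_interval
```

Because pnorm() is monotone increasing, transforming the draws first gives essentially the same bounds as transforming the linear-scale quantiles (up to quantile interpolation), and the bounds always stay inside [0, 1].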

Value

A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned.
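When several terms are requested, the result is described as a named list keyed by term. The sketch below mimics that shape with simulated draws so you can see how such a structure is built and indexed. The element names lower and upper and the helper `interval_by_term()` are assumptions made for illustration. They are not the documented return format.

```r
# Hypothetical helper mimicking a per-term list of interval bounds
interval_by_term <- function(draws_by_term, level = 0.95) {
  probs <- c((1 - level) / 2, (1 + level) / 2)
  # lapply() keeps the term names of the input list
  lapply(draws_by_term, function(draws) {
    bounds <- apply(draws, 1, quantile, probs = probs, names = FALSE)
    list(lower = bounds[1, ], upper = bounds[2, ])
  })
}

# Simulated draws for two terms: 4 observations x 200 draws each
set.seed(3)
draws_by_term <- list(
  mean_forest = matrix(rnorm(4 * 200), nrow = 4),
  y_hat       = matrix(rnorm(4 * 200, sd = 2), nrow = 4)
)
res <- interval_by_term(draws_by_term, level = 0.90)
names(res)        # "mean_forest" "y_hat"
res$y_hat$upper   # upper bounds for the 4 observations
```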

Examples

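library(stochtree)

# Simulate a simple regression dataset
n <- 100
p <- 5
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
y <- 2 * X[, 1] + rnorm(n)

# Fit a BART model, then compute 90% intervals for the mean forest and y_hat
bart_model <- bart(y_train = y, X_train = X)
intervals <- computeBARTPosteriorInterval(
  model_object = bart_model,
  terms = c("mean_forest", "y_hat"),
  X = X,
  level = 0.90
)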
#> Warning: Multiple posterior dimensions matching the number of posterior draws found in the array, using the last one as the MCMC index
#> Warning: Multiple posterior dimensions matching the number of posterior draws found in the array, using the last one as the MCMC index