
Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference. This function can compute any contrast of interest: specify the covariates, leaf basis, and random effects bases / IDs for each side of a two-term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or Y0 term and the minuend as the "treatment" or Y1 term, though the requested contrast need not match the "control vs treatment" terminology of a classic two-treatment causal inference problem. The arguments mirror the function calls and terminology of the predict.bartmodel function, with each prediction data term labeled 1 to denote its contribution to the treatment prediction of a contrast or 0 to denote inclusion in the control prediction.

Only valid when there is either a mean forest or a random effects term in the BART model.

Usage

computeContrastBARTModel(
  object,
  X_0,
  X_1,
  leaf_basis_0 = NULL,
  leaf_basis_1 = NULL,
  rfx_group_ids_0 = NULL,
  rfx_group_ids_1 = NULL,
  rfx_basis_0 = NULL,
  rfx_basis_1 = NULL,
  type = "posterior",
  scale = "linear"
)

Arguments

object

Object of class bartmodel (as returned by bart()) containing draws of a regression forest and associated sampling outputs.

X_0

Covariates used for prediction in the "control" case. Must be a matrix or dataframe.

X_1

Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.

leaf_basis_0

(Optional) Bases used for prediction in the "control" case (e.g., by dot product with leaf values). Default: NULL.

leaf_basis_1

(Optional) Bases used for prediction in the "treatment" case (e.g., by dot product with leaf values). Default: NULL.

rfx_group_ids_0

(Optional) Test set group labels used for prediction from an additive random effects model in the "control" case. Test set evaluation for group labels that were not in the training set is not currently supported, though support is planned. Must be a vector.

rfx_group_ids_1

(Optional) Test set group labels used for prediction from an additive random effects model in the "treatment" case. Test set evaluation for group labels that were not in the training set is not currently supported, though support is planned. Must be a vector.

rfx_basis_0

(Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.

rfx_basis_1

(Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.

type

(Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".

scale

(Optional) Scale of the contrast. Options are "linear", which returns the contrast of predictions on the original scale of the mean forest / RFX terms, and "probability". scale = "probability" is only valid for models fit with a probit or cloglog link on binary or ordinal outcomes. For binary outcome models, scale = "probability" will return contrasts of the probability that y == 1. For ordinal outcome models, scale = "probability" will return contrasts over the "survival function" P(y > k) for k = 1, 2, ..., K-1, where K is the total number of categories. Default: "linear".
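
To illustrate how the two scales differ for a binary probit model, the sketch below uses simulated stand-in draws rather than a fitted model. The names eta_0 and eta_1 are hypothetical, as is the layout (observations in rows, posterior draws in columns). It also assumes the probability-scale contrast is the difference of the two transformed predictions, which matches the description of "contrasts of the probability that y == 1" but is not taken from the function's source.

```r
# Hypothetical stand-ins for linear-scale predictions
# (rows = observations, columns = posterior draws); not fitted-model output.
set.seed(1)
n_obs <- 4
n_draws <- 6
eta_0 <- matrix(rnorm(n_obs * n_draws, mean = -0.5), nrow = n_obs)
eta_1 <- matrix(rnorm(n_obs * n_draws, mean = 0.5), nrow = n_obs)

# Contrast on the linear (latent) scale
contrast_linear <- eta_1 - eta_0

# Contrast on the probability scale under a probit link (assumed form):
# P(y = 1 | term 1) - P(y = 1 | term 0), computed draw by draw
contrast_prob <- pnorm(eta_1) - pnorm(eta_0)

# Probability contrasts are bounded in [-1, 1]; linear contrasts are not
range(contrast_prob)
```

Note that under this assumption the probability contrast is not pnorm(eta_1 - eta_0): the link is applied to each side before differencing.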

Value

A vector of contrasts when type = "mean", or a matrix of posterior contrast draws when type = "posterior".
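
The relationship between the two return types can be sketched with a simulated stand-in matrix. This assumes the posterior matrix has one row per observation and one column per retained draw, so that the "mean" result corresponds to averaging across columns; the object names are hypothetical and no fitted model is involved.

```r
# Hypothetical stand-in for a type = "posterior" result
# (rows = observations, columns = posterior draws)
set.seed(2)
n_obs <- 5
n_draws <- 10
contrast_posterior <- matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws)

# Averaging over draws gives one contrast estimate per observation,
# which is the shape described for type = "mean"
contrast_mean <- rowMeans(contrast_posterior)
length(contrast_mean)
```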

Examples

# Simulate covariates X, a leaf basis W, and an outcome that is piecewise linear in W
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
W <- matrix(runif(n*1), ncol = 1)
f_XW <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, noise_sd)
# Split the data into training and test sets
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
W_test <- W[test_inds,]
W_train <- W[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
# Fit a BART model with a leaf basis (short run, for illustration only)
bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)
# Contrast predictions at leaf basis = 1 vs leaf basis = 0, holding covariates fixed
contrast_test <- computeContrastBARTModel(
    bart_model,
    X_0 = X_test,
    X_1 = X_test,
    leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
    leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
    type = "posterior",
    scale = "linear"
)
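
Because the simulated outcome above is linear in W within each bin of X[,1], the true contrast between a basis of 1 and a basis of 0 is the slope on W for that bin. The sketch below encodes that ground truth and a small helper for summarizing posterior draws. Both true_contrast and summarize_contrast are hypothetical helper names, and the demonstration uses simulated stand-in draws (assumed layout: observations in rows, draws in columns) rather than the fitted model.

```r
# True contrast implied by f_XW in the example: the slope on W,
# which depends on which quarter of [0, 1) X[, 1] falls in
true_contrast <- function(x1) {
  slopes <- c(-7.5, -2.5, 2.5, 7.5)
  slopes[findInterval(x1, c(0, 0.25, 0.5, 0.75))]
}

# Posterior mean and equal-tailed credible interval per observation,
# assuming rows are observations and columns are posterior draws
summarize_contrast <- function(draws, level = 0.95) {
  alpha <- 1 - level
  data.frame(
    mean = rowMeans(draws),
    lower = apply(draws, 1, quantile, probs = alpha / 2),
    upper = apply(draws, 1, quantile, probs = 1 - alpha / 2)
  )
}

# Stand-in draws centered on the truth (not output from a fitted model)
set.seed(3)
x1_demo <- c(0.1, 0.3, 0.6, 0.9)
draws_demo <- matrix(
  rnorm(length(x1_demo) * 200, mean = true_contrast(x1_demo), sd = 0.5),
  nrow = length(x1_demo)
)
summary_demo <- summarize_contrast(draws_demo)
cbind(truth = true_contrast(x1_demo), summary_demo)
```

With the objects from the example, the same helpers could be applied as summarize_contrast(contrast_test) and true_contrast(X_test[,1]), provided contrast_test has the assumed layout.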