Predict from a sampled BCF model on new data

Usage

# S3 method for class 'bcfmodel'
predict(
  object,
  X,
  Z,
  propensity = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  type = "posterior",
  terms = "all",
  scale = "linear",
  ...
)

Arguments

object

Object of type bcfmodel containing draws of a Bayesian causal forest model and associated sampling outputs.

X

Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or data frame.

Z

Treatments used for prediction.

propensity

(Optional) Propensities used for prediction.

rfx_group_ids

(Optional) Test set group labels used for an additive random effects model. Test set evaluation for group labels that were not in the training set is not currently supported, though support is planned for the near future.

rfx_basis

(Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects model_spec of "intercept_only" or "intercept_plus_treatment", this argument is optional; if it is provided, it will be used.

type

(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
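
The relationship between the two options can be sketched in base R. The matrix below is simulated rather than produced by a fitted model, and its layout (one row per observation, one column per posterior draw) is an assumption made for illustration.

```r
# Stand-in for a "posterior" prediction matrix (simulated, not from a real model).
# Assumed layout: rows are observations, columns are posterior draws.
set.seed(1)
n_obs <- 4
n_draws <- 10
posterior_preds <- matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws)

# Averaging over draws gives one value per observation, which is what
# type = "mean" is described as returning.
mean_preds <- rowMeans(posterior_preds)
```

Under this assumed layout, the "mean" result is a vector of length n_obs, while the "posterior" result keeps every draw.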

terms

(Optional) Which model terms to include in the prediction. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all".

The treatment effect terms follow a three-level hierarchy:

  • "tau" returns tau_0 + tau(X): the parametric treatment intercept (if sampled) plus the treatment forest. This matches model$tau_hat_train / model$tau_hat_test.

  • "cate" additionally folds in the random slope on treatment when random effects are fit with rfx_model_spec = "intercept_plus_treatment"; otherwise it is identical to "tau".

  • The raw forest-only component (without tau_0) is not directly returned by this method; use model$forests_tau to access it.

Similarly for the prognostic term: "mu" returns the prognostic forest only, while "prognostic_function" additionally folds in the random intercept when rfx_model_spec is "intercept_only" or "intercept_plus_treatment"; otherwise the two are identical.
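
How these terms nest can be sketched with made-up numbers for a single posterior draw. Every value and variable name below is hypothetical, and the final line (combining the prognostic and treatment pieces into an outcome prediction) is an assumption about how "y_hat" relates to the other terms, not something stated above.

```r
# Hypothetical per-observation components for one posterior draw (made-up values).
Z <- c(0, 1, 1)                      # treatment indicator
mu_forest <- c(-2.5, 2.5, 7.5)       # prognostic forest only
tau_0 <- 0.5                         # parametric treatment intercept
tau_forest <- c(0.1, 0.6, 1.1)       # raw treatment forest, without tau_0
rfx_intercept <- c(0.2, -0.2, 0.2)   # random intercept for each observation's group
rfx_slope <- c(0.05, -0.05, 0.05)    # random slope on treatment

mu <- mu_forest                            # "mu"
prognostic_function <- mu + rfx_intercept  # "prognostic_function"
tau <- tau_0 + tau_forest                  # "tau"
cate <- tau + rfx_slope                    # "cate"

# Assumption: the outcome prediction is prognostic part plus treatment part.
y_hat <- prognostic_function + cate * Z
```

Without random effects, the two rfx vectors would be zero, so "cate" collapses to "tau" and "prognostic_function" collapses to "mu".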

If a model doesn't have random effects or variance forest predictions but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present, this function will return NULL along with a warning. Default: "all".

scale

(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing y == 1. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
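
For a probit outcome model, the mapping from the linear scale to the probability scale is the standard normal CDF. The snippet below is a sketch of that transformation on made-up values; it does not show how predict() applies it internally (for example, whether draws are transformed before or after averaging).

```r
# Hypothetical linear-scale predictions from a probit model (made-up values).
linear_preds <- c(-1.5, 0, 0.8)

# The probit link maps the linear predictor to P(y == 1) through the
# standard normal CDF.
prob_preds <- pnorm(linear_preds)
```

A linear prediction of 0 corresponds to a probability of 0.5, and larger linear predictions give larger probabilities.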

...

(Optional) Other prediction parameters.

Value

List of prediction matrices or single prediction matrix / vector, depending on the terms requested.

Examples

# Simulate covariates
n <- 500
p <- 5
X <- matrix(runif(n*p), ncol = p)
# Piecewise-constant prognostic function, propensity, and treatment effect
mu_x <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
pi_x <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
)
tau_x <- (
    ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
    ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
    ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
    ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
)
Z <- rbinom(n, 1, pi_x)
noise_sd <- 1
y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
# Hold out 20% of the observations as a test set
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- setdiff(1:n, test_inds)
X_test <- X[test_inds,]
X_train <- X[train_inds,]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]
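# Fit a small BCF model (very few draws, for illustration only)
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                 propensity_train = pi_train, num_gfr = 10,
                 num_burnin = 0, num_mcmc = 10)
# Predict on the held-out test set
preds <- predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test)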
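
Once a matrix of posterior draws for a term is in hand, it can be summarized with base R. The sketch below uses a simulated stand-in for treatment effect draws rather than the output of the model above; the names tau_true and tau_draws are hypothetical, and the layout (rows are observations, columns are draws) is an assumption.

```r
set.seed(2)
n_test <- 5
n_draws <- 200
tau_true <- c(0.5, 1.0, 1.5, 2.0, 1.0)

# Simulated stand-in for posterior draws: truth plus noise.
# Adding the vector to the matrix recycles it down each column (one draw per column).
tau_draws <- tau_true + matrix(rnorm(n_test * n_draws, sd = 0.3), nrow = n_test)

# Posterior mean and 95% credible interval for each observation
tau_mean <- rowMeans(tau_draws)
tau_ci <- t(apply(tau_draws, 1, quantile, probs = c(0.025, 0.975)))

# Root mean squared error of the posterior mean against the known truth
rmse <- sqrt(mean((tau_mean - tau_true)^2))
```

In a simulation study like the example above, the same summaries could be compared against tau_test, since the true treatment effect is known there.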