Predict from a sampled BCF model on new data
Usage
# S3 method for class 'bcfmodel'
predict(
  object,
  X,
  Z,
  propensity = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  type = "posterior",
  terms = "all",
  scale = "linear",
  ...
)
Arguments
- object
Object of type bcfmodel containing draws of a Bayesian causal forest model and associated sampling outputs.
- X
Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
- Z
Treatments used for prediction.
- propensity
(Optional) Propensities used for prediction.
- rfx_group_ids
(Optional) Test set group labels used for an additive random effects model. Test set evaluation for group labels that were not in the training set is not currently supported (support is planned for the near future).
- rfx_basis
(Optional) Test set basis for "random-slope" regression in an additive random effects model. If the model was sampled with a random effects model_spec of "intercept_only" or "intercept_plus_treatment", this argument is optional, but it will be used if provided.
- type
(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
- terms
(Optional) Which model terms to include in the prediction. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all".
The treatment effect terms follow a three-level hierarchy:
"tau" returns tau_0 + tau(X): the parametric treatment intercept (if sampled) plus the treatment forest. This matches model$tau_hat_train / model$tau_hat_test.
"cate" additionally folds in the random slope on treatment when random effects are fit with rfx_model_spec = "intercept_plus_treatment"; otherwise it is identical to "tau".
The raw forest-only component (without tau_0) is not directly returned by this method; use model$forests_tau to access it.
Similarly for the prognostic term: "mu" returns the prognostic forest only, while "prognostic_function" additionally folds in the random intercept when rfx_model_spec is "intercept_only" or "intercept_plus_treatment"; otherwise the two are identical.
If a model does not have random effects or variance forest predictions but one of those terms is requested, the request is simply ignored. If none of the requested terms are present, this function returns NULL along with a warning. Default: "all".
- scale
(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing y == 1. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
- ...
(Optional) Other prediction parameters.
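The type and scale arguments only change how per-draw predictions are summarized and transformed. The base-R sketch below mimics that post-processing on a simulated stand-in for a posterior prediction matrix. It assumes rows index observations and columns index posterior draws, and that the probit "probability" scale is the standard normal CDF (pnorm) applied to each draw; it illustrates the idea and is not stochtree's internal code.

```r
# Hypothetical stand-in for a type = "posterior" result on the linear scale:
# rows = observations, columns = posterior draws (assumed layout).
set.seed(42)
n_obs <- 4
n_draws <- 100
true_means <- c(-1, 0, 0.5, 2)
post_linear <- matrix(rnorm(n_obs * n_draws, mean = true_means),
                      nrow = n_obs, ncol = n_draws)

# type = "mean": average each observation's predictions across all draws
post_mean_linear <- rowMeans(post_linear)

# scale = "probability" (probit models only), sketched as the standard normal
# CDF applied to every draw, giving P(y == 1) per draw
post_prob <- pnorm(post_linear)
post_mean_prob <- rowMeans(post_prob)

# Uncertainty summaries need the full posterior matrix, e.g. 95% intervals
# (result is 2 x n_obs: lower and upper quantile per observation)
post_interval <- apply(post_prob, 1, quantile, probs = c(0.025, 0.975))
```

Because pnorm is nonlinear, rowMeans(pnorm(post_linear)) generally differs from pnorm(rowMeans(post_linear)); this sketch transforms before averaging, which is an assumption rather than documented behavior.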
Value
List of prediction matrices or single prediction matrix / vector, depending on the terms requested.
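The terms hierarchy described under Arguments can be read as sums of per-draw components. The sketch below builds hypothetical draws (forest_mu, forest_tau, tau_0, rfx_intercept, rfx_slope and the list names are illustrative, not fields of a bcfmodel object or the names predict returns) for a model with rfx_model_spec = "intercept_plus_treatment", and collects the resulting matrices in a named list, mirroring the multi-term return value. It assumes each group's random intercept and treatment slope enter additively.

```r
set.seed(7)
n_obs <- 5; n_draws <- 50; n_groups <- 2
group <- c(1, 1, 2, 2, 1)  # hypothetical random effects group label per observation

# Hypothetical forest draws: rows = observations, columns = posterior draws
forest_mu  <- matrix(rnorm(n_obs * n_draws), n_obs, n_draws)
forest_tau <- matrix(rnorm(n_obs * n_draws, mean = 1), n_obs, n_draws)
# Hypothetical parametric treatment intercept tau_0: one value per draw
tau_0 <- rnorm(n_draws, mean = 0.2, sd = 0.05)

# Hypothetical group-level random intercepts and treatment slopes per draw
rfx_intercept_g <- matrix(rnorm(n_groups * n_draws, sd = 0.3), n_groups, n_draws)
rfx_slope_g     <- matrix(rnorm(n_groups * n_draws, sd = 0.3), n_groups, n_draws)
# Expand group-level effects to observations by indexing rows with group labels
rfx_intercept <- rfx_intercept_g[group, , drop = FALSE]
rfx_slope     <- rfx_slope_g[group, , drop = FALSE]

# "tau": treatment forest plus tau_0 (tau_0[j] added to every entry of column j)
tau <- sweep(forest_tau, 2, tau_0, "+")
# "cate": "tau" plus the random slope on treatment
cate <- tau + rfx_slope
# "mu": prognostic forest only; "prognostic_function" adds the random intercept
mu <- forest_mu
prognostic_function <- mu + rfx_intercept

preds <- list(tau = tau, cate = cate, mu = mu,
              prognostic_function = prognostic_function)
```

Without random effects, rfx_slope and rfx_intercept would be zero, so "cate" would equal "tau" and "prognostic_function" would equal "mu", matching the documented fallbacks.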
Examples
library(stochtree)
# Simulate n observations of p uniform covariates
n <- 500
p <- 5
X <- matrix(runif(n*p), ncol = p)
# Prognostic function mu(X): a step function of the first covariate
mu_x <- (
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
# True propensity score pi(X) = P(Z = 1 | X), also a step function of X[,1]
pi_x <- (
((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
)
# Treatment effect tau(X): a step function of the second covariate
tau_x <- (
((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
)
# Draw binary treatments from the true propensity, then a Gaussian outcome
Z <- rbinom(n, 1, pi_x)
noise_sd <- 1
y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
# Hold out 20% of observations as a test set
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- setdiff(1:n, test_inds)
X_test <- X[test_inds,]
X_train <- X[train_inds,]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]
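# Fit BCF with 10 grow-from-root warm-start iterations and 10 MCMC draws
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                 propensity_train = pi_train, num_gfr = 10,
                 num_burnin = 0, num_mcmc = 10)
# Predict on the held-out set; with the defaults (type = "posterior",
# terms = "all") the result is a list of prediction matrices
preds <- predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test)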