
Decision tree ensembles can be represented in part by a "kernel" function whose distance metric is based on the extent to which two observations are mapped to the same leaf nodes. This function group offers utilities for evaluating this kernel.
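As a rough illustration, one simple version of such a kernel scores two observations by the fraction of trees in which they fall into the same leaf. This form is an assumption made for the sketch, not a description of stochtree's internals. The base R code below uses a made-up matrix of tree-local leaf assignments, with one row per observation and one column per tree. The helper name leaf_kernel is hypothetical.

```r
# Made-up tree-local leaf assignments: 4 observations (rows) by 2 trees (columns)
leaf_ids <- matrix(c(1, 1, 2, 2,
                     1, 2, 2, 3), nrow = 4)

# Hypothetical helper: K[i, j] is the fraction of trees in which i and j share a leaf
leaf_kernel <- function(leaf_ids) {
  n <- nrow(leaf_ids)
  K <- matrix(0, nrow = n, ncol = n)
  for (t in seq_len(ncol(leaf_ids))) {
    # n x n logical matrix: TRUE where two observations land in the same leaf of tree t
    K <- K + outer(leaf_ids[, t], leaf_ids[, t], "==")
  }
  K / ncol(leaf_ids)
}

K <- leaf_kernel(leaf_ids)
print(K)
```

Every entry lies between 0 and 1, and the diagonal is always 1 because an observation always shares a leaf with itself.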

computeForestLeafIndices computes and returns a vector representation of a forest's leaf predictions for every observation in a dataset. The resulting vector has a "row-major" format that can easily be re-represented as a CSR sparse matrix: the first n elements hold the first tree's leaf predictions for all n observations in the dataset, the next n elements hold the second tree's predictions, and so on. The "data" for each element is a uniquely mapped column index that identifies a single leaf of a single tree (e.g. if tree 1 has 3 leaves, its column indices range from 0 to 2, and tree 2's leaf indices begin at 3).
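The layout just described can be sketched in base R with a small hand-written vector standing in for real computeForestLeafIndices output (n = 3 observations, 2 trees; all values are made up). The code reshapes the vector so that each column holds one tree's leaves, derives 0-based CSR components, and builds the equivalent dense indicator matrix. Variable names are illustrative only, and the final kernel line assumes the shared-leaf-fraction form rather than any stochtree function.

```r
# Made-up flat output for n = 3 observations and 2 trees:
# elements 1-3 hold tree 1's leaves (global indices 0-2),
# elements 4-6 hold tree 2's leaves (global indices starting at 3)
n <- 3
leaf_vec <- c(0, 0, 2, 3, 4, 3)
num_trees <- length(leaf_vec) / n

# Filling an n-row matrix column by column puts tree t's leaves in column t
leaf_mat <- matrix(leaf_vec, nrow = n)

# 0-based CSR components: every row (observation) has exactly one nonzero per tree
indptr  <- seq(0, n * num_trees, by = num_trees)  # row i spans indptr[i] to indptr[i + 1] - 1
indices <- as.vector(t(leaf_mat))                 # column indices, listed row by row
values  <- rep(1, n * num_trees)                  # indicator values

# Dense equivalent for checking: W[i, k + 1] = 1 if observation i lands in global leaf k
num_cols <- max(leaf_vec) + 1
W <- matrix(0, nrow = n, ncol = num_cols)
W[cbind(rep(seq_len(n), num_trees), leaf_vec + 1)] <- 1

# Shared-leaf fraction kernel (assumed kernel form, not a stochtree function)
K <- tcrossprod(W) / num_trees
```

Using max(leaf_vec) + 1 as the column count is only safe here because the highest-numbered leaf happens to be occupied. If it were empty, the observed maximum would understate the matrix width, which is where the largest possible index reported by computeForestMaxLeafIndex is useful.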

computeForestLeafVariances returns each forest's leaf node scale parameters. If the leaf scale is not sampled for the forest in question, the function throws an error stating that the leaf model does not have a stochastic scale parameter.

computeForestMaxLeafIndex computes and returns the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.
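To make the global index mapping concrete, the base R sketch below computes each tree's offset from hypothetical per-tree leaf counts, matching the example in which tree 2's indices begin where tree 1's end. Under this mapping the largest index for a single forest is the total leaf count minus one. That computeForestMaxLeafIndex uses exactly this convention is an assumption, not something this page states. The helper to_global_leaf is hypothetical.

```r
# Hypothetical forest of three trees with 3, 2 and 4 leaves respectively
leaves_per_tree <- c(3, 2, 4)

# Global index of each tree's first leaf: tree 1 starts at 0, tree 2 at 3, tree 3 at 5
offsets <- c(0, cumsum(leaves_per_tree)[-length(leaves_per_tree)])

# Hypothetical helper: map a 0-based, tree-local leaf position to its global column index
to_global_leaf <- function(tree, local_leaf) offsets[tree] + local_leaf

# Largest global index under this mapping (the last leaf of the last tree)
max_index <- sum(leaves_per_tree) - 1
```

A sparse leaf-indicator matrix built from these indices would need max_index + 1 columns, because indices start at 0.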

These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking is performed; users are responsible for providing correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/

Usage

computeForestLeafIndices(
  model_object,
  covariates,
  forest_type = NULL,
  propensity = NULL,
  forest_inds = NULL
)

computeForestLeafVariances(model_object, forest_type, forest_inds = NULL)

computeForestMaxLeafIndex(model_object, forest_type = NULL, forest_inds = NULL)

Arguments

model_object

Object of type bartmodel or bcfmodel corresponding to a BART / BCF model with at least one forest sample, or a low-level ForestSamples object.

covariates

Covariates to use for prediction. Must have the same number of columns and the same column types as the data used to train the forest.

forest_type

Which forest to use from model_object. Valid inputs depend on the model type and on whether a given forest was sampled in that model:

1. BART

  • 'mean': Extracts leaf indices for the mean forest

  • 'variance': Extracts leaf indices for the variance forest

2. BCF

  • 'prognostic': Extracts leaf indices for the prognostic forest

  • 'treatment': Extracts leaf indices for the treatment effect forest

  • 'variance': Extracts leaf indices for the variance forest

3. ForestSamples

  • NULL: It is not necessary to disambiguate when this function is called directly on a ForestSamples object. This is the default value of this argument.

propensity

(Optional) Propensities used for prediction (BCF-only).

forest_inds

(Optional) Indices of the forest sample(s) for which to compute leaf indices, leaf variances, or max leaf indices. If not provided, results are returned for every sample of the forest. These indices are 0-based, so the first forest sample corresponds to forest_inds = 0, and so on.
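Because R itself indexes from 1, it is easy to pass an off-by-one value to forest_inds. The hypothetical helper below (not part of stochtree) converts 1-based sample positions into the 0-based values this argument expects, assuming you already know the number of retained forest samples.

```r
# Hypothetical helper (not part of stochtree): convert R-style 1-based sample
# positions into the 0-based values expected by forest_inds
to_forest_inds <- function(positions, num_samples) {
  # Reject positions outside 1..num_samples before shifting
  stopifnot(all(positions >= 1), all(positions <= num_samples))
  as.integer(positions - 1)
}

num_samples <- 10  # e.g. a model with 10 retained forest samples
first_and_last <- to_forest_inds(c(1, num_samples), num_samples)
print(first_and_last)
```

With 10 retained samples, the last sample therefore corresponds to forest_inds = 9.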

Value

computeForestLeafIndices returns a vector of size num_obs * num_trees, where num_obs = nrow(covariates) and num_trees is the number of trees in the relevant forest of model_object.

computeForestLeafVariances returns a vector of size length(forest_inds) with the leaf scale parameter for each requested forest.

computeForestMaxLeafIndex returns a vector containing the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.

Examples

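library(stochtree)
# Simulate a simple step-function regression problem
X <- matrix(runif(10*100), ncol = 10)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)

# Fit BART with 10 MCMC samples (forest sample indices 0 through 9)
bart_model <- bart(X, y, num_gfr=0, num_mcmc=10)

# forest_inds must be named here: the fourth positional argument is propensity
leaf_indices <- computeForestLeafIndices(bart_model, X, "mean")
leaf_indices <- computeForestLeafIndices(bart_model, X, "mean", forest_inds = 0)
leaf_indices <- computeForestLeafIndices(bart_model, X, "mean", forest_inds = c(1,3,9))

# Leaf scale parameters for all samples, the first sample, and a subset of samples
leaf_variances <- computeForestLeafVariances(bart_model, "mean")
leaf_variances <- computeForestLeafVariances(bart_model, "mean", 0)
leaf_variances <- computeForestLeafVariances(bart_model, "mean", c(1,3,5))

# Largest possible leaf index for all samples, the first sample, and a subset of samples
max_leaf_index <- computeForestMaxLeafIndex(bart_model, "mean")
max_leaf_index <- computeForestMaxLeafIndex(bart_model, "mean", 0)
max_leaf_index <- computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9))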