
Return each forest's leaf node scale parameters.

If the leaf scale is not sampled for the requested forest, the function throws an error stating that the leaf model does not have a stochastic scale parameter.
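Code that queries several forest types in a loop may want to catch this error rather than stop. The following is a minimal base R sketch, not part of stochtree: `tryLeafScales` is a hypothetical helper, and `fake_compute` is a stand-in that only mimics the failure so the snippet runs without a fitted model. It assumes the real function signals an ordinary R error condition that `tryCatch` can intercept.

```r
# Hypothetical helper (not part of stochtree): call `compute_fn` and return
# NULL instead of stopping when it signals an error.
tryLeafScales <- function(compute_fn, ...) {
  tryCatch(
    compute_fn(...),
    error = function(e) {
      message("Leaf scales unavailable: ", conditionMessage(e))
      NULL
    }
  )
}

# Stand-in for a forest whose leaf scale was not sampled (assumption: the real
# function raises a standard R error in that case).
fake_compute <- function(model_object, forest_type) {
  stop("leaf model does not have a stochastic scale parameter")
}

result <- tryLeafScales(fake_compute, NULL, "variance")
is.null(result)
```

With a fitted model, the same helper would presumably be called as `tryLeafScales(computeForestLeafVariances, bart_model, "variance")`.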

This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at stochtree.ai.

Usage

computeForestLeafVariances(model_object, forest_type, forest_inds = NULL)

Arguments

model_object

Object of type bartmodel or bcfmodel corresponding to a BART / BCF model with at least one forest sample

forest_type

Which forest to use from model_object. Valid inputs depend on the model type, and whether or not a given forest was sampled in that model.

1. BART

  • 'mean': Returns leaf scale parameters for the mean forest

  • 'variance': Returns leaf scale parameters for the variance forest

2. BCF

  • 'prognostic': Returns leaf scale parameters for the prognostic forest

  • 'treatment': Returns leaf scale parameters for the treatment effect forest

  • 'variance': Returns leaf scale parameters for the variance forest

forest_inds

(Optional) Indices of the forest sample(s) for which to return leaf scale parameters. If not provided, this function returns the leaf scale parameter for every sample of the forest. This function uses 0-indexing, so the first forest sample corresponds to forest_inds = 0, and so on.
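Since R vectors are 1-indexed, the 0-based convention is easy to get wrong. As a small illustration in base R (no model fit required), the sketch below reuses the values printed in the Examples section and shows how 1-based positions map to 0-based indices. The variable names are illustrative only.

```r
# Leaf scales for all 10 retained samples, copied from the Examples output.
all_scales <- c(0.008128752, 0.008018772, 0.005952380, 0.006414746,
                0.007315390, 0.007005638, 0.007437530, 0.008789202,
                0.009111788, 0.008385264)

# 1-based R positions of the samples of interest...
r_positions <- c(2, 4, 6)
# ...become 0-based forest_inds by subtracting one.
forest_inds <- r_positions - 1

# Subsetting the full vector at the 1-based positions gives the same values
# that the Examples section reports for forest_inds = c(1, 3, 5).
selected <- all_scales[forest_inds + 1]
selected
```

In other words, passing `forest_inds = c(1, 3, 5)` selects the second, fourth, and sixth retained samples.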

Value

Vector of size length(forest_inds) with the leaf scale parameter for each requested forest sample. If forest_inds is not provided, the vector has one entry per retained forest sample.

Examples

X <- matrix(runif(10*100), ncol = 10)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
bart_model <- bart(X, y, num_gfr=0, num_mcmc=10)
computeForestLeafVariances(bart_model, "mean")
#>  [1] 0.008128752 0.008018772 0.005952380 0.006414746 0.007315390 0.007005638
#>  [7] 0.007437530 0.008789202 0.009111788 0.008385264
computeForestLeafVariances(bart_model, "mean", 0)
#> [1] 0.008128752
computeForestLeafVariances(bart_model, "mean", c(1,3,5))
#> [1] 0.008018772 0.006414746 0.007005638