
Supervised learning

High-level functionality for training supervised Bayesian tree ensembles (BART, XBART)

bart()
Run the BART algorithm for supervised learning.
predict(<bartmodel>)
Predict from a sampled BART model on new data
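BART represents the regression function as a sum of many small regression trees, and a sampled model keeps one such ensemble per posterior draw. Predicting on new data evaluates every draw. The Python sketch below illustrates that idea using a hypothetical stump-based tree format. It is not stochtree's internal representation and does not reproduce what predict() returns.

```python
from statistics import mean

# Hypothetical stump format (illustration only): (split_feature, threshold, left_value, right_value).
Stump = tuple[int, float, float, float]

def predict_tree(stump: Stump, x: list[float]) -> float:
    """Return the value of the leaf that observation x falls into."""
    feature, threshold, left, right = stump
    return left if x[feature] <= threshold else right

def predict_draw(forest: list[Stump], x: list[float]) -> float:
    """One posterior draw of f(x): the sum of the outputs of every tree in the ensemble."""
    return sum(predict_tree(tree, x) for tree in forest)

def predict_posterior(draws: list[list[Stump]], x: list[float]) -> tuple[list[float], float]:
    """Evaluate each sampled ensemble at x and return the per-draw predictions and their mean."""
    per_draw = [predict_draw(forest, x) for forest in draws]
    return per_draw, mean(per_draw)

# Usage: two posterior draws, each an ensemble of two stumps.
draws = [
    [(0, 0.5, -1.0, 1.0), (1, 0.0, 0.2, 0.4)],
    [(0, 0.5, -0.8, 1.2), (1, 0.0, 0.0, 0.6)],
]
per_draw, posterior_mean = predict_posterior(draws, [0.7, 1.0])
```

Keeping the per-draw values, rather than only their mean, is what makes posterior intervals available downstream.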

Causal inference

High-level functionality for estimating causal effects using Bayesian tree ensembles (BCF, XBCF)

bcf()
Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation.
predict(<bcfmodel>)
Predict from a sampled BCF model on new data
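For orientation, BCF (Hahn, Murray and Carvalho, 2020) assigns separate tree ensembles to the prognostic function and the treatment effect function. The model below assumes a binary treatment z_i and uses an estimated propensity score as an extra covariate in the prognostic term. The exact parameterization used by bcf() may differ. Under the usual unconfoundedness assumption, tau(x) is the conditional average treatment effect.

```latex
y_i = \mu\big(x_i, \hat{\pi}(x_i)\big) + \tau(x_i)\, z_i + \varepsilon_i,
\qquad \varepsilon_i \sim \mathcal{N}(0, \sigma^2),
\qquad \tau(x) = \mathbb{E}\left[\, Y_i(1) - Y_i(0) \mid X_i = x \,\right]
```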

Low-level functionality

Serialization

Classes and functions for converting sampling artifacts to JSON and saving to disk

CppJson
Class that wraps a C++ JSON object used to serialize and deserialize sampling artifacts
createCppJson()
Create a new (empty) C++ JSON object
createCppJsonFile()
Create a C++ JSON object from a JSON file
createCppJsonString()
Create a C++ JSON object from a JSON string
loadForestContainerJson()
Load a container of forest samples from JSON
loadForestContainerCombinedJson()
Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container
loadForestContainerCombinedJsonString()
Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container
loadVectorJson()
Load a vector from JSON
loadScalarJson()
Load a scalar from JSON
loadRandomEffectSamplesJson()
Load a container of random effect samples from JSON
loadRandomEffectSamplesCombinedJson()
Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container
loadRandomEffectSamplesCombinedJsonString()
Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container
saveBARTModelToJson()
Convert the persistent aspects of a BART model to (in-memory) JSON
saveBARTModelToJsonFile()
Convert the persistent aspects of a BART model to JSON and save it to a file
saveBARTModelToJsonString()
Convert the persistent aspects of a BART model to an (in-memory) JSON string
createBARTModelFromJson()
Convert an (in-memory) JSON representation of a BART model to a BART model object which can be used for prediction, etc.
createBARTModelFromJsonFile()
Convert a JSON file containing sample information on a trained BART model to a BART model object which can be used for prediction, etc.
createBARTModelFromJsonString()
Convert a JSON string containing sample information on a trained BART model to a BART model object which can be used for prediction, etc.
createBARTModelFromCombinedJson()
Convert a list of (in-memory) JSON representations of BART models to a single combined BART model object which can be used for prediction, etc.
createBARTModelFromCombinedJsonString()
Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object which can be used for prediction, etc.
saveBCFModelToJson()
Convert the persistent aspects of a BCF model to (in-memory) JSON
saveBCFModelToJsonFile()
Convert the persistent aspects of a BCF model to JSON and save it to a file
saveBCFModelToJsonString()
Convert the persistent aspects of a BCF model to an (in-memory) JSON string
createBCFModelFromJsonFile()
Convert a JSON file containing sample information on a trained BCF model to a BCF model object which can be used for prediction, etc.
createBCFModelFromJsonString()
Convert a JSON string containing sample information on a trained BCF model to a BCF model object which can be used for prediction, etc.
createBCFModelFromJson()
Convert an (in-memory) JSON representation of a BCF model to a BCF model object which can be used for prediction, etc.
createBCFModelFromCombinedJson()
Convert a list of (in-memory) JSON representations of BCF models to a single combined BCF model object which can be used for prediction, etc.
createBCFModelFromCombinedJsonString()
Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object which can be used for prediction, etc.
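The combined-JSON loaders merge draws from several models saved with the same hierarchy / schema, for example independent MCMC chains. The Python sketch below shows the idea on a hypothetical flat layout: a scalar num_samples plus list fields holding one entry per draw. stochtree's actual JSON layout is different, and the function name is illustrative only.

```python
import json

def combine_json_models(json_strings: list[str]) -> dict:
    """Merge several serialized models that share one schema into a single container.

    Hypothetical layout: a scalar "num_samples" plus list-valued fields that hold
    one entry per posterior draw (not stochtree's actual JSON format). Only the
    top-level keys are compared when checking that the schemas match.
    """
    models = [json.loads(s) for s in json_strings]
    if not models:
        raise ValueError("need at least one model")
    schema = set(models[0])
    for model in models[1:]:
        if set(model) != schema:
            raise ValueError("models do not share the same hierarchy / schema")
    combined = {"num_samples": sum(m["num_samples"] for m in models)}
    for key in sorted(schema - {"num_samples"}):
        # Concatenate draws in input order: all of model 0, then model 1, and so on.
        combined[key] = [draw for m in models for draw in m[key]]
    return combined

# Usage: two "chains" serialized separately, then combined.
chain_a = json.dumps({"num_samples": 2, "sigma2_samples": [1.0, 1.1], "forests": ["f0", "f1"]})
chain_b = json.dumps({"num_samples": 1, "sigma2_samples": [0.9], "forests": ["f2"]})
combined = combine_json_models([chain_a, chain_b])
```

Rejecting mismatched schemas up front avoids silently pairing draws from models that were configured differently.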

Data

Classes and functions for preparing data for sampling algorithms

ForestDataset
Dataset used to sample a forest
createForestDataset()
Create a forest dataset object
Outcome
Outcome / partial residual used to sample an additive model.
createOutcome()
Create an outcome object
RandomEffectsDataset
Dataset used to sample a random effects model
createRandomEffectsDataset()
Create a random effects dataset object
preprocessTrainData()
Preprocess training covariates. Data frames are preprocessed based on their column types; matrices are passed through on the assumption that all columns are numeric.
preprocessPredictionData()
Preprocess covariates for prediction in the same way as the training data. Data frames are preprocessed based on their column types; matrices are passed through on the assumption that all columns are numeric.
convertPreprocessorToJson()
Convert the persistent aspects of a covariate preprocessor to an (in-memory) C++ JSON object
savePreprocessorToJsonString()
Convert the persistent aspects of a covariate preprocessor to an (in-memory) JSON string
createPreprocessorFromJson()
Reload a covariate preprocessor object from an (in-memory) C++ JSON object containing a serialized preprocessor
createPreprocessorFromJsonString()
Reload a covariate preprocessor object from a JSON string containing a serialized preprocessor
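In an additive model, tree ensembles are typically sampled by backfitting. Each component is updated against a partial residual: the outcome minus the current fit of every other component. The minimal Python sketch below shows that bookkeeping, assuming one list of fitted values per component. The helper names are hypothetical, and the sketch does not reflect how the Outcome class is implemented internally.

```python
def full_residual(y, fits):
    """Residual y - sum_k f_k(x), given one list of fitted values per component."""
    return [y_i - sum(fit[i] for fit in fits) for i, y_i in enumerate(y)]

def partial_residual(residual, fit_j):
    """Add component j back in, giving y - sum_{k != j} f_k(x): the target for re-sampling j."""
    return [r + f for r, f in zip(residual, fit_j)]

def swap_component(residual, old_fit_j, new_fit_j):
    """After component j is re-sampled, replace its old contribution with the new one."""
    return [r + old - new for r, old, new in zip(residual, old_fit_j, new_fit_j)]

# Usage: two components (e.g. two forests) fitted to two observations.
y = [3.0, 5.0]
fits = [[1.0, 2.0], [0.5, 1.0]]
r = full_residual(y, fits)                 # [1.5, 2.0]
target = partial_residual(r, fits[0])      # y - fits[1]: [2.5, 4.0]
new_fit_0 = [1.5, 2.5]                     # stand-in for a freshly sampled component 0
r = swap_component(r, fits[0], new_fit_0)  # [1.0, 1.5]
```

Updating the residual incrementally avoids re-summing every component after each update.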

Forest

Classes and functions for constructing and persisting forests

Forest
Class that stores a single ensemble of decision trees (often treated as the "active forest")
createForest()
Create a forest
ForestModel
Class that defines and samples a forest model
createForestModel()
Create a forest model object
ForestSamples
Class that stores draws from a random ensemble of decision trees
createForestSamples()
Create a container of forest samples
ForestModelConfig
Object used to get / set parameters and other model configuration options for a forest model in the "low-level" stochtree interface
createForestModelConfig()
Create a forest model config object
GlobalModelConfig
Object used to get / set global parameters and other global model configuration options in the "low-level" stochtree interface
createGlobalModelConfig()
Create a global model config object
CppRNG
Class that wraps a C++ random number generator (for reproducibility)
createCppRNG()
Create an R object that wraps a C++ random number generator
calibrateInverseGammaErrorVariance()
Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al. (2022)
computeForestMaxLeafIndex()
Compute and return the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.
computeForestLeafIndices()
Compute vector of forest leaf indices
computeForestLeafVariances()
Compute vector of forest leaf scale parameters
resetActiveForest()
Reset an active forest, either from a specific forest in a ForestContainer or to an ensemble of single-node (i.e. root) trees
resetForestModel()
Re-initialize a forest model (tracking data structures) from a specific forest in a ForestContainer

Random Effects

Classes and functions for constructing and persisting random effects terms

RandomEffectSamples
Class that wraps the "persistent" aspects of a C++ random effects model: draws of the parameters and a map from the original group labels to the 0-indexed label numbers used to place group samples in memory (i.e., the first label is stored in column 0 of the sample matrix, the second label is stored in column 1, etc.)
createRandomEffectSamples()
Create a RandomEffectSamples object
RandomEffectsModel
The core "model" class for sampling random effects.
createRandomEffectsModel()
Create a RandomEffectsModel object
RandomEffectsTracker
Class that defines a "tracker" for random effects models, most notably storing the data indices available in each group for quicker posterior computation and sampling of random effects terms.
createRandomEffectsTracker()
Create a RandomEffectsTracker object
getRandomEffectSamples()
Generic function for extracting random effect samples from a model object (BCF, BART, etc.)
getRandomEffectSamples(<bartmodel>)
Extract raw sample values for each of the random effect parameter terms.
getRandomEffectSamples(<bcfmodel>)
Extract raw sample values for each of the random effect parameter terms.
sampleGlobalErrorVarianceOneIteration()
Sample one iteration of the (inverse gamma) global variance model
sampleLeafVarianceOneIteration()
Sample one iteration of the leaf parameter variance model (only supported for a univariate leaf basis, including the constant-leaf model)
resetRandomEffectsModel()
Reset a RandomEffectsModel object based on the parameters indexed by sample_num in a RandomEffectSamples object
resetRandomEffectsTracker()
Reset a RandomEffectsTracker object based on the parameters indexed by sample_num in a RandomEffectSamples object
rootResetRandomEffectsModel()
Reset a RandomEffectsModel object to its "default" state
rootResetRandomEffectsTracker()
Reset a RandomEffectsTracker object to its "default" state
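In the standard setup, both one-iteration variance samplers listed above are conjugate inverse gamma updates. Take a prior IG(a, b) on a variance and values assumed i.i.d. N(0, variance): residuals for the global error variance, or leaf parameters for the leaf variance. The full conditional is then IG(a + n/2, b + sum of squares / 2). The Python sketch below works under that assumption. stochtree's own parameterization and arguments may differ.

```python
import random

def inverse_gamma_posterior(values, shape, scale):
    """Posterior (shape, scale) for a variance with prior IG(shape, scale),
    given values assumed i.i.d. N(0, variance)."""
    n = len(values)
    sum_sq = sum(v * v for v in values)
    return shape + n / 2.0, scale + sum_sq / 2.0

def sample_variance_one_iteration(values, shape, scale, rng):
    """Draw one variance from the conjugate full conditional.

    If G ~ Gamma(shape=a, rate=b) then 1/G ~ IG(a, b). random.gammavariate is
    parameterized by (shape, scale), so pass scale = 1 / rate.
    """
    post_shape, post_scale = inverse_gamma_posterior(values, shape, scale)
    return 1.0 / rng.gammavariate(post_shape, 1.0 / post_scale)

# Usage: pass residuals (global error variance) or leaf values (leaf variance).
rng = random.Random(2024)
residuals = [1.0, -1.0, 2.0]
sigma2 = sample_variance_one_iteration(residuals, shape=1.5, scale=0.5, rng=rng)
```

Passing an explicit random.Random instance keeps draws reproducible, in the same spirit as the CppRNG wrapper.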

Package info

High-level package details

stochtree stochtree-package
stochtree: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference