Getting Started
stochtree can be built and run as a standalone C++ program directly from source using cmake.
Cloning the Repository
To clone the repository, you must have git installed, which you can do following these instructions.
Once git is available at the command line, navigate to the folder that will store this project (in bash / zsh, this is done by running cd followed by the path to the directory). Then, clone the stochtree repo as a subfolder by running:
git clone --recursive https://github.com/StochasticTree/stochtree.git
NOTE: this project incorporates several dependencies as git submodules, which is why the --recursive flag is necessary (some systems may perform a recursive clone without this flag, but --recursive ensures this behavior on all platforms). If you have already cloned the repo without the --recursive flag, you can retrieve the submodules recursively by running git submodule update --init --recursive in the main repo directory.
Key Components
The stochtree C++ core consists of thousands of lines of C++ code, but it can be organized and understood through several components (see topics for more detail):
- Trees: the most important "primitive" of decision tree algorithms is the decision tree itself, which in stochtree is defined by a Tree class as well as a series of static helper functions for prediction.
- Forest: individual trees are combined into a forest, or ensemble, which in stochtree is defined by the TreeEnsemble class. A container of forests is defined by the ForestContainer class.
- Dataset: data can be loaded from a variety of sources into a stochtree data layer.
- Leaf Model: stochtree's data structures are generalized to support a wide range of models, which are defined via specialized classes in the leaf model layer.
- Sampler: helper functions that sample forests from training data comprise the sampling layer of stochtree.
Extending stochtree
Custom Leaf Models
The leaf model documentation details the key components of new decision tree models: custom LeafModel and SuffStat classes that implement a model's log marginal likelihood and posterior computations.
Adding a new leaf model will consist largely of implementing new versions of each of these classes which track the API of the existing classes. Once these classes exist, they need to be reflected in several places.
Suppose, for the sake of illustration, that the newest custom leaf model is a multinomial logit model.
First, add an entry to the ModelType enumeration for this new model type:
enum ModelType {
  kConstantLeafGaussian,
  kUnivariateRegressionLeafGaussian,
  kMultivariateRegressionLeafGaussian,
  kLogLinearVariance,
  kMultinomialLogit,
};
Next, add entries to the std::variants that bundle related SuffStat and LeafModel classes:
using SuffStatVariant = std::variant<GaussianConstantSuffStat,
                                     GaussianUnivariateRegressionSuffStat,
                                     GaussianMultivariateRegressionSuffStat,
                                     LogLinearVarianceSuffStat,
                                     MultinomialLogitSuffStat>;
using LeafModelVariant = std::variant<GaussianConstantLeafModel,
                                      GaussianUnivariateRegressionLeafModel,
                                      GaussianMultivariateRegressionLeafModel,
                                      LogLinearVarianceLeafModel,
                                      MultinomialLogitLeafModel>;
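To illustrate why each new class has to appear in the variant's type list, here is a minimal, self-contained sketch of one common way to consume a std::variant: dispatching through std::visit. The types ToyGaussianSuffStat and ToyMultinomialSuffStat, their members, and the helper functions are hypothetical stand-ins invented for this example. They are not stochtree's classes, and the library's own dispatch code may look different.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical stand-ins for SuffStat classes (illustration only).
struct ToyGaussianSuffStat {
  int n = 0;
  double sum_y = 0.0;
  void Increment(double y) { ++n; sum_y += y; }
  double Mean() const { assert(n > 0); return sum_y / n; }
  static std::string Name() { return "gaussian"; }
};

struct ToyMultinomialSuffStat {
  int n = 0;
  void Increment(double /*y*/) { ++n; }
  static std::string Name() { return "multinomial"; }
};

// Supporting a new model means appending its type to this list.
using ToySuffStatVariant =
    std::variant<ToyGaussianSuffStat, ToyMultinomialSuffStat>;

// std::visit invokes the generic lambda on whichever alternative is held,
// so a single call site serves every type with a matching interface.
inline void IncrementAny(ToySuffStatVariant& stat, double y) {
  std::visit([y](auto& s) { s.Increment(y); }, stat);
}

inline int CountAny(const ToySuffStatVariant& stat) {
  return std::visit([](const auto& s) { return s.n; }, stat);
}

inline std::string NameAny(const ToySuffStatVariant& stat) {
  return std::visit([](const auto& s) { return s.Name(); }, stat);
}
```

Because the lambdas are generic, a type that is missing from the variant can never be held or visited, and a type that is present but lacks the expected member functions fails at compile time. This is one reason new classes should track the API of the existing ones.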
Finally, update the suffStatFactory and leafModelFactory functions to add a logic branch registering these new objects:
static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
  if (model_type == kConstantLeafGaussian) {
    return createSuffStat<GaussianConstantSuffStat>();
  } else if (model_type == kUnivariateRegressionLeafGaussian) {
    return createSuffStat<GaussianUnivariateRegressionSuffStat>();
  } else if (model_type == kMultivariateRegressionLeafGaussian) {
    return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
  } else if (model_type == kLogLinearVariance) {
    return createSuffStat<LogLinearVarianceSuffStat>();
  } else if (model_type == kMultinomialLogit) {
    return createSuffStat<MultinomialLogitSuffStat>();
  } else {
    Log::Fatal("Incompatible model type provided to suff stat factory");
  }
}
static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
  if (model_type == kConstantLeafGaussian) {
    return createLeafModel<GaussianConstantLeafModel, double>(tau);
  } else if (model_type == kUnivariateRegressionLeafGaussian) {
    return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
  } else if (model_type == kMultivariateRegressionLeafGaussian) {
    return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
  } else if (model_type == kLogLinearVariance) {
    return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
  } else if (model_type == kMultinomialLogit) {
    return createLeafModel<MultinomialLogitLeafModel>();
  } else {
    Log::Fatal("Incompatible model type provided to leaf model factory");
  }
}
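The factory pattern used above can be tried out in isolation with the following rough sketch, which maps an enumeration value to a variant alternative through a templated helper. Every name in it (ToyModelType, ToyLeafVariant, createToyLeaf, toyLeafFactory, and the toy structs) is hypothetical. The helper's implementation is an assumption about how such a function might be written, not a copy of stochtree's createLeafModel, and std::invalid_argument stands in for the Log::Fatal call.

```cpp
#include <cassert>
#include <stdexcept>
#include <utility>
#include <variant>

// Hypothetical toy leaf models (illustration only).
struct ToyConstantLeaf { double tau; };
struct ToyVarianceLeaf { double a; double b; };
struct ToyLogitLeaf {};

enum ToyModelType { kToyConstant, kToyVariance, kToyLogit };

using ToyLeafVariant =
    std::variant<ToyConstantLeaf, ToyVarianceLeaf, ToyLogitLeaf>;

// Build a ModelT from the forwarded arguments and wrap it in the variant.
template <typename ModelT, typename... Args>
ToyLeafVariant createToyLeaf(Args&&... args) {
  return ModelT{std::forward<Args>(args)...};
}

// One branch per model type; a new model adds exactly one branch here.
inline ToyLeafVariant toyLeafFactory(ToyModelType type, double tau,
                                     double a, double b) {
  if (type == kToyConstant) {
    assert(tau > 0.0);  // toy precondition, assumed for this sketch only
    return createToyLeaf<ToyConstantLeaf>(tau);
  } else if (type == kToyVariance) {
    return createToyLeaf<ToyVarianceLeaf>(a, b);
  } else if (type == kToyLogit) {
    return createToyLeaf<ToyLogitLeaf>();
  } else {
    // Stand-in for a fatal log call: no variant can be returned here.
    throw std::invalid_argument("Incompatible model type provided to toy leaf factory");
  }
}
```

A caller can then check which alternative came back with std::holds_alternative<ToyVarianceLeaf>(model) or read it with std::get. Forgetting the new branch would send the new enumeration value to the final else, which is why the factory update is the last required step.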