StochTree 0.1.1
stochtree C++ Documentation

Getting Started

stochtree can be built and run as a standalone C++ program directly from source using cmake.

Cloning the Repository

To clone the repository, you must have git installed, which you can do by following these instructions.

Once git is available at the command line, navigate to the folder that will store this project (in bash / zsh, this is done by running cd followed by the path to the directory). Then, clone the stochtree repo as a subfolder by running

git clone --recursive https://github.com/StochasticTree/stochtree.git

NOTE: this project incorporates several dependencies as git submodules, which is why the --recursive flag is necessary (some systems may perform a recursive clone without this flag, but --recursive ensures this behavior on all platforms). If you have already cloned the repo without the --recursive flag, you can retrieve the submodules recursively by running git submodule update --init --recursive in the main repo directory.

Key Components

The stochtree C++ core consists of thousands of lines of C++ code, but it can be organized and understood through several components (see topics for more detail):

  • Trees: the most important "primitive" of decision tree algorithms is the decision tree itself, which in stochtree is defined by the Tree class as well as a series of static helper functions for prediction.
  • Forest: individual trees are combined into a forest, or ensemble, which in stochtree is defined by the TreeEnsemble class; a container of forests is defined by the ForestContainer class.
  • Dataset: data can be loaded from a variety of sources into a stochtree data layer.
  • Leaf Model: stochtree's data structures are generalized to support a wide range of models, which are defined via specialized classes in the leaf model layer.
  • Sampler: helper functions that sample forests from training data comprise the sampling layer of stochtree.
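To make the layering of trees, forests, and forest containers concrete, here is a toy model of how the pieces relate. The types ToyTree, ToyEnsemble, and ToyForestContainer are hypothetical, heavily simplified stand-ins, not stochtree's Tree, TreeEnsemble, or ForestContainer APIs. The sketch assumes a sum-of-trees ensemble (as in BART-style models) and one stored forest per retained sampler draw.

```cpp
#include <vector>

// Hypothetical stand-ins for the components above; NOT stochtree's actual classes.

// A depth-one "tree": a single split on one feature and two leaf values.
struct ToyTree {
  int split_feature;
  double split_value;
  double left_leaf;
  double right_leaf;

  double Predict(const std::vector<double>& x) const {
    return x[split_feature] <= split_value ? left_leaf : right_leaf;
  }
};

// A forest (ensemble) predicts by summing its trees' outputs.
struct ToyEnsemble {
  std::vector<ToyTree> trees;

  double Predict(const std::vector<double>& x) const {
    double total = 0.0;
    for (const ToyTree& tree : trees) total += tree.Predict(x);
    return total;
  }
};

// A container holds one forest per retained sampler iteration.
struct ToyForestContainer {
  std::vector<ToyEnsemble> forests;

  // Returns one prediction per stored forest (i.e. per draw).
  std::vector<double> PredictAll(const std::vector<double>& x) const {
    std::vector<double> out;
    out.reserve(forests.size());
    for (const ToyEnsemble& forest : forests) out.push_back(forest.Predict(x));
    return out;
  }
};
```

For example, a container holding two forests returns a vector of two predictions for a single input row, one per stored forest.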

Extending stochtree

Custom Leaf Models

The leaf model documentation details the key components of new decision tree models: custom LeafModel and SuffStat classes that implement a model's log marginal likelihood and posterior computations.
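As an illustration of the SuffStat / LeafModel split, the sketch below works through the simplest case: a constant leaf with Gaussian outcomes, y_i ~ N(mu, sigma2), and prior mu ~ N(0, tau). Integrating out mu gives y ~ N(0, sigma2 I + tau 11^T), so the log marginal likelihood and the posterior of mu depend on the data only through n, the sum of y, and the sum of y^2. The class and method names (ToyGaussianSuffStat, ToyGaussianLeafModel, LogMarginalLikelihood, and so on) are hypothetical and are not stochtree's actual interface; the library's own classes may also drop constant terms.

```cpp
#include <cmath>

// Hypothetical sketch (not stochtree's API) of a constant-leaf Gaussian model:
// y_i ~ N(mu, sigma2) within a leaf, with prior mu ~ N(0, tau).

// Sufficient statistics: count, sum, and sum of squares of a leaf's outcomes.
struct ToyGaussianSuffStat {
  double n = 0.0;
  double sum_y = 0.0;
  double sum_y_sq = 0.0;

  void Increment(double y) {
    n += 1.0;
    sum_y += y;
    sum_y_sq += y * y;
  }
};

class ToyGaussianLeafModel {
 public:
  explicit ToyGaussianLeafModel(double tau) : tau_(tau) {}

  // log N(y; 0, sigma2 * I + tau * 1 1^T), computed from the sufficient statistics.
  double LogMarginalLikelihood(const ToyGaussianSuffStat& s, double sigma2) const {
    const double kPi = 3.14159265358979323846;
    const double denom = sigma2 + s.n * tau_;
    return -0.5 * s.n * std::log(2.0 * kPi * sigma2)
           - 0.5 * std::log(denom / sigma2)  // = -0.5 * log(1 + n * tau / sigma2)
           - s.sum_y_sq / (2.0 * sigma2)
           + tau_ * s.sum_y * s.sum_y / (2.0 * sigma2 * denom);
  }

  // Posterior of mu is Gaussian with this mean and variance.
  double PosteriorMean(const ToyGaussianSuffStat& s, double sigma2) const {
    return tau_ * s.sum_y / (sigma2 + s.n * tau_);
  }
  double PosteriorVariance(const ToyGaussianSuffStat& s, double sigma2) const {
    return tau_ * sigma2 / (sigma2 + s.n * tau_);
  }

 private:
  double tau_;  // prior variance of the leaf parameter
};
```

Because the model only consumes the sufficient statistics, comparing a candidate split reduces to accumulating statistics on each side and comparing log marginal likelihoods.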

Adding a new leaf model consists largely of implementing new versions of each of these classes that follow the API of the existing classes. Once these classes exist, they must be registered in several places.

Suppose, for the sake of illustration, that the newest custom leaf model is a multinomial logit model.

First, add an entry for the new model type to the ModelType enumeration:

enum ModelType {
  kConstantLeafGaussian,
  kUnivariateRegressionLeafGaussian,
  kMultivariateRegressionLeafGaussian,
  kLogLinearVariance,
  kMultinomialLogit,
};
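Every new enumerator has to be handled wherever the model type is dispatched on. One way to let the compiler help with that, sketched here with a hypothetical ModelTypeName helper that is not part of stochtree, is to dispatch with a switch that has no default case: GCC and Clang then warn via -Wswitch (enabled by -Wall) about any enumerator without a matching case, such as a kMultinomialLogit that was added to the enum but not handled.

```cpp
#include <stdexcept>
#include <string>

// Copy of the enumeration above, including the new entry.
enum ModelType {
  kConstantLeafGaussian,
  kUnivariateRegressionLeafGaussian,
  kMultivariateRegressionLeafGaussian,
  kLogLinearVariance,
  kMultinomialLogit,
};

// Hypothetical helper: with no default case, -Wswitch flags missing enumerators.
std::string ModelTypeName(ModelType model_type) {
  switch (model_type) {
    case kConstantLeafGaussian: return "constant leaf Gaussian";
    case kUnivariateRegressionLeafGaussian: return "univariate regression leaf Gaussian";
    case kMultivariateRegressionLeafGaussian: return "multivariate regression leaf Gaussian";
    case kLogLinearVariance: return "log-linear variance";
    case kMultinomialLogit: return "multinomial logit";
  }
  // Reached only if model_type holds a value that is not one of the enumerators.
  throw std::invalid_argument("Unknown model type");
}
```

An if / else-if chain, like the factory functions in this section, gets no such diagnostic, so a forgotten branch only surfaces at run time in the final else.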

Next, add the new classes to the std::variant types that bundle the related SuffStat and LeafModel classes:

using SuffStatVariant = std::variant<GaussianConstantSuffStat,
                                     GaussianUnivariateRegressionSuffStat,
                                     GaussianMultivariateRegressionSuffStat,
                                     LogLinearVarianceSuffStat,
                                     MultinomialLogitSuffStat>;
using LeafModelVariant = std::variant<GaussianConstantLeafModel,
                                      GaussianUnivariateRegressionLeafModel,
                                      GaussianMultivariateRegressionLeafModel,
                                      LogLinearVarianceLeafModel,
                                      MultinomialLogitLeafModel>;
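Why the new classes must both appear in the variant and follow the existing API can be seen in a small self-contained example. Generic code reaches the active alternative through std::visit, so every alternative must provide whatever methods the visitor calls, or the visit fails to compile. The types ToyCountSuffStat and ToySumSuffStat and their Increment / Name methods are hypothetical stand-ins, not stochtree classes.

```cpp
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for two sufficient-statistic classes sharing one API.
struct ToyCountSuffStat {
  int n = 0;
  void Increment(double) { n += 1; }
  std::string Name() const { return "count"; }
};

struct ToySumSuffStat {
  int n = 0;
  double sum = 0.0;
  void Increment(double y) { n += 1; sum += y; }
  std::string Name() const { return "sum"; }
};

// Analogous to SuffStatVariant: holds exactly one of the registered alternatives.
using ToySuffStatVariant = std::variant<ToyCountSuffStat, ToySumSuffStat>;

// Generic code: the lambda is instantiated for every alternative, so each one
// must expose Increment(double), whichever alternative is active at run time.
void AccumulateAll(ToySuffStatVariant& stat, const std::vector<double>& ys) {
  std::visit([&ys](auto& s) {
    for (double y : ys) s.Increment(y);
  }, stat);
}

std::string ActiveName(const ToySuffStatVariant& stat) {
  return std::visit([](const auto& s) { return s.Name(); }, stat);
}
```

A class left out of the variant cannot be stored in it at all, and a class in the variant that lacks a method the visitors use causes a compile error, which is what "tracking the API of the existing classes" buys.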

Finally, update the suffStatFactory and leafModelFactory functions, adding a branch that constructs each of the new objects:

static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
  if (model_type == kConstantLeafGaussian) {
    return createSuffStat<GaussianConstantSuffStat>();
  } else if (model_type == kUnivariateRegressionLeafGaussian) {
    return createSuffStat<GaussianUnivariateRegressionSuffStat>();
  } else if (model_type == kMultivariateRegressionLeafGaussian) {
    return createSuffStat<GaussianMultivariateRegressionSuffStat, int>(basis_dim);
  } else if (model_type == kLogLinearVariance) {
    return createSuffStat<LogLinearVarianceSuffStat>();
  } else if (model_type == kMultinomialLogit) {
    return createSuffStat<MultinomialLogitSuffStat>();
  } else {
    Log::Fatal("Incompatible model type provided to suff stat factory");
  }
}
static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
  if (model_type == kConstantLeafGaussian) {
    return createLeafModel<GaussianConstantLeafModel, double>(tau);
  } else if (model_type == kUnivariateRegressionLeafGaussian) {
    return createLeafModel<GaussianUnivariateRegressionLeafModel, double>(tau);
  } else if (model_type == kMultivariateRegressionLeafGaussian) {
    return createLeafModel<GaussianMultivariateRegressionLeafModel, Eigen::MatrixXd>(Sigma0);
  } else if (model_type == kLogLinearVariance) {
    return createLeafModel<LogLinearVarianceLeafModel, double, double>(a, b);
  } else if (model_type == kMultinomialLogit) {
    return createLeafModel<MultinomialLogitLeafModel>();
  } else {
    Log::Fatal("Incompatible model type provided to leaf model factory");
  }
}
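The createSuffStat and createLeafModel helpers called by the factories are not shown here. Judging only from the call sites (an explicit class type followed by zero or more constructor-argument types), a plausible shape is a variadic template that forwards its arguments to the chosen class's constructor and returns the result as the variant. The sketch below assumes exactly that; all Toy* names and CreateToySuffStat are hypothetical and are not stochtree's definitions.

```cpp
#include <stdexcept>
#include <variant>

// Hypothetical model types and sufficient-statistic classes for illustration.
enum ToyModelType { kToyConstant, kToyMultivariate, kToyMultinomialLogit };

struct ToyConstantSuffStat {};
struct ToyMultivariateSuffStat {
  explicit ToyMultivariateSuffStat(int basis_dim) : dim(basis_dim) {}
  int dim;
};
struct ToyMultinomialLogitSuffStat {};

using ToyFactoryVariant =
    std::variant<ToyConstantSuffStat, ToyMultivariateSuffStat, ToyMultinomialLogitSuffStat>;

// Assumed helper shape: construct SuffStatType from any (possibly empty) argument
// list and convert it to the variant on return.
template <typename SuffStatType, typename... Args>
ToyFactoryVariant CreateToySuffStat(Args... args) {
  return SuffStatType(args...);
}

ToyFactoryVariant ToySuffStatFactory(ToyModelType model_type, int basis_dim = 0) {
  if (model_type == kToyConstant) {
    return CreateToySuffStat<ToyConstantSuffStat>();
  } else if (model_type == kToyMultivariate) {
    return CreateToySuffStat<ToyMultivariateSuffStat, int>(basis_dim);
  } else if (model_type == kToyMultinomialLogit) {
    // The newly registered branch: no constructor arguments needed.
    return CreateToySuffStat<ToyMultinomialLogitSuffStat>();
  }
  // Stands in for Log::Fatal; throwing also covers the non-void return path.
  throw std::invalid_argument("Incompatible model type provided to suff stat factory");
}
```

With an empty argument pack, CreateToySuffStat<ToyMultinomialLogitSuffStat>() simply value-initializes the class, which mirrors the createSuffStat<MultinomialLogitSuffStat>() and createLeafModel<MultinomialLogitLeafModel>() calls in the new branches.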