StochTree Python API Reference
Overview of the stochtree Python library's key classes and functions.
The stochtree interface is divided into two "levels":
- "High level": end-to-end implementations of stochastic tree ensembles for supervised learning (BART / XBART) and causal inference (BCF / XBCF). Both the BART and BCF interfaces mirror the scikit-learn estimator style, except that the .fit() method is replaced by a .sample() method.
- "Low level": we provide Python access to most of the C++ sampling objects and functionality, enabling custom sampling algorithms and the integration of other model terms. This interface is documented here and consists broadly of the following components:
  - Data API: loading and storing the in-memory data needed to train stochtree models.
  - Forest API: creating, storing, and modifying the ensembles of decision trees that underlie all stochtree models.
  - Sampler API: sampling from stochastic tree ensemble models as well as from several supported parametric models.
  - Utilities API: seeding a C++ random number generator, preprocessing data, and serializing models to JSON (files or in-memory strings).
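The sketch below illustrates what replacing .fit() with .sample() means in practice. It is a minimal example in plain Python and is not stochtree code. The class ToySampler, its arguments, and the model it uses are all hypothetical. The model is a conjugate one-feature Bayesian regression with known noise variance. The point is that a sampling estimator keeps many posterior draws instead of a single point estimate, so each prediction comes back as a collection of draws.

```python
import random
import statistics


class ToySampler:
    """Hypothetical sampling estimator (NOT the stochtree API).

    Assumed model, for illustration only:
        y = beta * x + noise,  noise ~ N(0, sigma2) with sigma2 known,
        prior beta ~ N(0, tau2).
    """

    def __init__(self, sigma2=1.0, tau2=10.0):
        self.sigma2 = sigma2
        self.tau2 = tau2
        self.beta_samples = []

    def sample(self, X, y, num_samples=100, seed=None):
        """Draw from the posterior of beta (plays the role of .fit())."""
        rng = random.Random(seed)
        # Conjugate normal update: posterior precision and mean of beta
        precision = sum(x * x for x in X) / self.sigma2 + 1.0 / self.tau2
        mean = sum(x * yi for x, yi in zip(X, y)) / self.sigma2 / precision
        sd = precision ** -0.5
        # Keep every draw rather than collapsing to a point estimate
        self.beta_samples = [rng.gauss(mean, sd) for _ in range(num_samples)]
        return self

    def predict(self, X):
        """Return one row per observation, one prediction per posterior draw."""
        return [[b * x for b in self.beta_samples] for x in X]


# Usage: summarize the predictive draws at a new point
X = [0.5, 1.0, 1.5, 2.0]
y = [1.1, 1.9, 3.2, 3.9]
model = ToySampler().sample(X, y, num_samples=500, seed=1)
draws = model.predict([3.0])[0]
cuts = statistics.quantiles(draws, n=20)  # 5%, 10%, ..., 95% cut points
print(f"mean {statistics.mean(draws):.2f}, 90% interval [{cuts[0]:.2f}, {cuts[-1]:.2f}]")
```

Because the estimator keeps the full set of draws, uncertainty summaries such as the interval above come directly from the stored samples. No separate inference step is needed.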