
stochtree Models Wrapped as sklearn Estimators

stochtree.StochTreeBARTRegressor

Bases: RegressorMixin, BaseEstimator

A scikit-learn-compatible estimator that implements a BART regression model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_gfr | int | The number of grow-from-root (GFR) iterations of the BART model to run. | 10 |
| num_burnin | int | The number of MCMC iterations of the BART model that will be discarded as "burn-in" samples. | 0 |
| num_trees | int | The number of retained MCMC iterations of the BART model to run. | 100 |
| general_params | dict | General parameters for the BART model. | None |
| mean_forest_params | dict | Parameters for the mean forest. | None |
| variance_forest_params | dict | Parameters for the variance forest. | None |
| rfx_params | dict | Parameters for the random effects. | None |

Attributes:

| Name | Type | Description |
|---|---|---|
| X_ | ndarray, shape (n_samples, n_features) | The covariates (or features) used to define tree partitions. |
| y_ | ndarray, shape (n_samples,) | The outcome variable (or labels) used to evaluate tree partitions. |
| leaf_regression_basis_ | ndarray, shape (n_samples, n_bases) | The basis functions used for the leaf regression model, if requested. |
| rfx_group_ids_ | ndarray, shape (n_samples,) | The group IDs for random effects, if requested. |
| rfx_basis_ | ndarray, shape (n_samples, n_rfx_bases) | The basis functions used for random effects, if requested. |
| n_features_in_ | int | Number of features seen during `fit`. |
| feature_names_in_ | ndarray of shape (`n_features_in_`,) | Names of features seen during `fit`. Defined only when `X` has feature names that are all strings. |

Examples:

>>> from sklearn.datasets import load_diabetes
>>> from stochtree import StochTreeBARTRegressor
>>> data = load_diabetes()
>>> X = data.data
>>> y = data.target
>>> reg = StochTreeBARTRegressor()
>>> reg.fit(X, y)
>>> reg.predict(X)

fit(X, y, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Fit a BART regressor by sampling from its posterior.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to train a BART forest. | required |
| y | array-like, shape (n_samples,) or (n_samples, n_outputs) | The continuous outcomes used to train a BART forest. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| self | object | Returns self. |
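The optional inputs to `fit` must line up row-for-row with `X`. As a rough illustration, the NumPy sketch below builds synthetic arrays with the shapes listed in the parameter table. The helper name `make_example_inputs` and the simulated data (one leaf-regression basis column and a random intercept per group) are assumptions for illustration only; nothing here is part of stochtree.

```python
import numpy as np


def make_example_inputs(n_samples=200, n_features=5, n_groups=4, seed=0):
    """Hypothetical helper: synthetic data shaped like the fit() inputs."""
    rng = np.random.default_rng(seed)
    # Covariates used to define tree partitions: (n_samples, n_features)
    X = rng.uniform(size=(n_samples, n_features))
    # leaf_regression_basis: (n_samples, n_bases); a single basis column here
    leaf_basis = rng.uniform(size=(n_samples, 1))
    # rfx_group_ids: one integer group label per row, shape (n_samples,)
    group_ids = rng.integers(0, n_groups, size=n_samples)
    # rfx_basis: (n_samples, n_rfx_bases); a column of ones gives each
    # group its own random intercept
    rfx_basis = np.ones((n_samples, 1))
    # Simulated outcome: nonlinear signal scaled by the leaf basis,
    # plus a group-level offset and Gaussian noise
    group_offsets = rng.normal(size=n_groups)
    y = (
        np.sin(np.pi * X[:, 0]) * leaf_basis[:, 0]
        + group_offsets[group_ids]
        + rng.normal(scale=0.1, size=n_samples)
    )
    return X, y, leaf_basis, group_ids, rfx_basis
```

These arrays would then be passed by keyword, for example `reg.fit(X, y, leaf_regression_basis=leaf_basis, rfx_group_ids=group_ids, rfx_basis=rfx_basis)`. Since the model was fit with these inputs, `predict` and `score` will presumably need matching arrays built for the new rows as well.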

predict(X, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Predict the outcome based on the provided test data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to predict from a BART forest. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| y | ndarray, shape (n_samples,) | Returns an array of predicted target values. |

score(X, y, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Compute and return the R^2 of a BART regression model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to evaluate the fitted BART model. | required |
| y | array-like, shape (n_samples,) or (n_samples, n_outputs) | The continuous outcomes against which the predictions are scored. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| score | float | R^2 of `self.predict(X, leaf_regression_basis, rfx_group_ids, rfx_basis)` with respect to `y`. |
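For reference, R^2 is the coefficient of determination, R^2 = 1 - SS_res / SS_tot, where SS_res is the sum of squared differences between `y` and the predictions and SS_tot is the sum of squared deviations of `y` from its mean. The pure-Python sketch below computes it directly. It assumes `score` follows the same definition as scikit-learn's `r2_score`; the function name `r_squared` is illustrative.

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    mean_y = sum(y_true) / len(y_true)
    # Residual sum of squares: squared prediction errors
    ss_res = sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred))
    # Total sum of squares: squared deviations of y from its mean
    ss_tot = sum((yt - mean_y) ** 2 for yt in y_true)
    return 1.0 - ss_res / ss_tot


# Predicting the mean of y for every sample yields R^2 = 0
print(r_squared([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))  # 0.0
```

A perfect fit gives 1.0, while predictions worse than the constant mean give a negative value, so R^2 is not bounded below by zero.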

__getstate__()

Prepare the estimator for pickling.

We convert the BART model to its JSON representation.

__setstate__(state)

Restore the estimator from a pickled state.

We reconstruct a BART model object from its JSON representation.
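The same serialize-to-JSON pattern can be sketched without stochtree. In the toy class below, a `threading.Lock` stands in for a model object that cannot be pickled directly; this stand-in, the class name `JsonBackedEstimator`, and its attributes are hypothetical and chosen only to illustrate the mechanism. `__getstate__` swaps the live object for a JSON string, and `__setstate__` rebuilds it, so `pickle` can round-trip the estimator.

```python
import json
import pickle
import threading


class JsonBackedEstimator:
    """Hypothetical estimator whose fitted model holds an unpicklable handle."""

    def __init__(self, num_trees=100):
        self.num_trees = num_trees
        self.model_ = None  # stands in for a natively backed model object

    def fit(self, coefs):
        # A lock cannot be pickled, mimicking a model object that must be
        # converted to a portable representation first
        self.model_ = {"coefs": list(coefs), "lock": threading.Lock()}
        return self

    def __getstate__(self):
        # Copy so the live estimator is left untouched
        state = self.__dict__.copy()
        if self.model_ is not None:
            # Replace the live model with its JSON representation
            state["model_"] = json.dumps({"coefs": self.model_["coefs"]})
        return state

    def __setstate__(self, state):
        model_json = state.get("model_")
        if model_json is not None:
            # Rebuild the live model object from its JSON representation
            payload = json.loads(model_json)
            state["model_"] = {"coefs": payload["coefs"], "lock": threading.Lock()}
        self.__dict__.update(state)


est = JsonBackedEstimator(num_trees=50).fit([0.5, -1.25])
restored = pickle.loads(pickle.dumps(est))
print(restored.num_trees, restored.model_["coefs"])  # 50 [0.5, -1.25]
```

For the documented estimators, this is what allows a fitted estimator to be passed to `pickle.dumps`, or to tools built on pickle such as `joblib.dump`, like any other scikit-learn estimator.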

stochtree.StochTreeBARTBinaryClassifier

Bases: ClassifierMixin, BaseEstimator

A scikit-learn-compatible estimator that implements a binary probit BART classifier.
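In a probit model, the probability of the positive class is the standard normal CDF of a latent function. One way to see this: if y = 1 whenever f(x) + ε > 0 with ε ~ N(0, 1), then P(y = 1 | x) = Φ(f(x)). The standard-library simulation below checks that identity for a single, made-up value of f(x). In a BART probit model f(x) would be the sum-of-trees fit, but this sketch does not call stochtree, and the helper names are hypothetical.

```python
import math
import random


def normal_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def simulate_probit_rate(latent_mean, n_draws=100_000, seed=1):
    """Fraction of draws where latent_mean + N(0, 1) noise exceeds zero."""
    rng = random.Random(seed)
    hits = sum(1 for _ in range(n_draws) if latent_mean + rng.gauss(0.0, 1.0) > 0.0)
    return hits / n_draws


f_x = 0.4  # hypothetical value of the latent function at some x
print(round(normal_cdf(f_x), 3))  # 0.655
print(round(simulate_probit_rate(f_x), 3))  # approximately the same value
```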

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_gfr | int | The number of grow-from-root (GFR) iterations of the BART model to run. | 10 |
| num_burnin | int | The number of MCMC iterations of the BART model that will be discarded as "burn-in" samples. | 0 |
| num_trees | int | The number of retained MCMC iterations of the BART model to run. | 100 |
| general_params | dict | General parameters for the BART model. | None |
| mean_forest_params | dict | Parameters for the mean forest. | None |
| variance_forest_params | dict | Parameters for the variance forest. | None |
| rfx_params | dict | Parameters for the random effects. | None |

Attributes:

| Name | Type | Description |
|---|---|---|
| X_ | ndarray, shape (n_samples, n_features) | The covariates (or features) used to define tree partitions. |
| y_ | ndarray, shape (n_samples,) | The outcome variable (or labels) used to evaluate tree partitions. |
| leaf_regression_basis_ | ndarray, shape (n_samples, n_bases) | The basis functions used for the leaf regression model, if requested. |
| rfx_group_ids_ | ndarray, shape (n_samples,) | The group IDs for random effects, if requested. |
| rfx_basis_ | ndarray, shape (n_samples, n_rfx_bases) | The basis functions used for random effects, if requested. |
| n_features_in_ | int | Number of features seen during `fit`. |
| feature_names_in_ | ndarray of shape (`n_features_in_`,) | Names of features seen during `fit`. Defined only when `X` has feature names that are all strings. |

Examples:

>>> from sklearn.datasets import load_breast_cancer
>>> from stochtree import StochTreeBARTBinaryClassifier
>>> data = load_breast_cancer()
>>> X = data.data
>>> y = data.target
>>> clf = StochTreeBARTBinaryClassifier()
>>> clf.fit(X, y)
>>> clf.predict(X)

fit(X, y, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Fit a BART classifier by sampling from its posterior.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to train a BART forest. | required |
| y | array-like, shape (n_samples,) or (n_samples, n_outputs) | The binary class labels used to train a BART forest. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| self | object | Returns self. |

decision_function(X, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Evaluate the (linear-scale) decision function for the given input samples.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to predict from a BART forest. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| y | ndarray, shape (n_samples,) | Returns an array of decision function values on the linear scale. |
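Because the classifier is described as a probit model, decision-function values live on the latent, linear scale. The sketch below shows how such values could map to probabilities and class labels. It assumes the standard probit link Φ and a 0.5 probability threshold (equivalently, a decision value above 0), which is the usual convention but is not confirmed on this page; `summarize_decisions` is a hypothetical helper.

```python
import math


def summarize_decisions(decision_values):
    """Map latent probit-scale scores to (probability, label) pairs."""
    results = []
    for f in decision_values:
        # Probit link: P(y = 1) = Phi(f), the standard normal CDF
        prob = 0.5 * (1.0 + math.erf(f / math.sqrt(2.0)))
        # Phi(f) > 0.5 exactly when f > 0, so either test gives the label
        label = 1 if f > 0.0 else 0
        results.append((prob, label))
    return results


for prob, label in summarize_decisions([-1.0, 0.0, 1.5]):
    print(f"{prob:.3f} {label}")
# prints: 0.159 0, then 0.500 0, then 0.933 1
```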

predict(X, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Predict the target classes for the given input samples.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to predict from a BART forest. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| y | ndarray, shape (n_samples,) | Returns an array of predicted class labels. |

predict_proba(X, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Predict the target probabilities for the given input samples.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to predict from a BART forest. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| y | ndarray, shape (n_samples,) | Returns an array of predicted class probabilities. |
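scikit-learn's convention is for `predict_proba` to return one column per class, with shape (n_samples, n_classes). If this method returns a one-dimensional array of positive-class probabilities, as the listed shape suggests, and a downstream tool expects the two-column layout, the array can be expanded as in the NumPy sketch below. `to_two_column_proba` is a hypothetical helper, and it assumes the input values are P(y = 1) with classes ordered as [0, 1].

```python
import numpy as np


def to_two_column_proba(p_positive):
    """Stack P(y=0) and P(y=1) into an (n_samples, 2) array."""
    p = np.asarray(p_positive, dtype=float).ravel()
    if np.any((p < 0.0) | (p > 1.0)):
        raise ValueError("probabilities must lie in [0, 1]")
    # Assumes classes are ordered [0, 1]: negative-class column first
    return np.column_stack([1.0 - p, p])


proba = to_two_column_proba([0.1, 0.5, 0.9])
print(proba.shape)  # (3, 2)
```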

score(X, y, leaf_regression_basis=None, rfx_group_ids=None, rfx_basis=None)

Compute and return the score of a BART classifier on the given data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like or sparse matrix | The covariates used to evaluate the fitted BART model. | required |
| y | array-like, shape (n_samples,) or (n_samples, n_outputs) | The binary class labels against which the predictions are scored. | required |
| leaf_regression_basis | array-like, shape (n_samples, n_bases), optional | The basis functions to use for the leaf regression model, if requested. | None |
| rfx_group_ids | array-like, shape (n_samples,), optional | The group IDs for random effects, if requested. | None |
| rfx_basis | array-like, shape (n_samples, n_rfx_bases), optional | The basis functions to use for random effects, if requested. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| score | float | R^2 of `self.predict(X, leaf_regression_basis, rfx_group_ids, rfx_basis)` with respect to `y`. |

__getstate__()

Prepare the estimator for pickling.

We convert the BART model to its JSON representation.

__setstate__(state)

Restore the estimator from a pickled state.

We reconstruct a BART model object from its JSON representation.