
Data API

stochtree.data.Dataset()

Wrapper around a C++ class that stores all of the non-outcome data used in stochtree. This includes:

  1. Features used for partitioning (also referred to as "covariates" in many places in these docs).
  2. Basis vectors used to define non-constant leaf models (leaf regression). These are optional and can be added via the add_basis method.
  3. Variance weights used to define heteroskedastic or otherwise weighted models. These are optional and can be added via the add_variance_weights method.
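
The three components above can be prepared as plain numpy arrays before they are added to a Dataset. The sketch below uses a hypothetical helper, check_dataset_inputs, which is not part of stochtree. It assumes covariates form an n x p matrix, the basis has one row per observation, and the variance weights are a length-n vector. These shape conventions are inferred from the descriptions on this page rather than taken from stochtree's source.

```python
import numpy as np

# Hypothetical helper (NOT part of stochtree): checks that the arrays intended for
# add_covariates, add_basis, and add_variance_weights agree on the number of rows.
def check_dataset_inputs(covariates, basis=None, variance_weights=None):
    covariates = np.asarray(covariates, dtype=np.float64)
    if covariates.ndim != 2:
        raise ValueError("covariates should be a 2-D array (n x p)")
    n = covariates.shape[0]
    if basis is not None:
        basis = np.asarray(basis, dtype=np.float64)
        if basis.ndim == 1:
            # Assumption of this sketch: a 1-D basis is treated as a single column.
            basis = basis.reshape(-1, 1)
        if basis.shape[0] != n:
            raise ValueError("basis must have one row per observation")
    if variance_weights is not None:
        variance_weights = np.asarray(variance_weights, dtype=np.float64)
        if variance_weights.shape != (n,):
            raise ValueError("variance_weights must be a 1-D array of length n")
    return covariates, basis, variance_weights

rng = np.random.default_rng(0)
n, p = 100, 5
X = rng.uniform(size=(n, p))   # features used for partitioning
W = rng.normal(size=(n, 2))    # optional leaf-regression basis
weights = np.ones(n)           # optional variance weights (all equal here)
X_chk, W_chk, w_chk = check_dataset_inputs(X, W, weights)
```

Validating shapes up front in Python gives a clearer error message than a failure raised later from the C++ layer.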

add_covariates(covariates)

Add covariates to a dataset.

Parameters:

  covariates (array, required): Numpy array of covariates. If the data contain categorical, string, time series, or other such columns in a dataframe, first preprocess them with the CovariateTransformer.
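
Numeric features that are stored separately can be combined into a single covariate matrix with numpy before calling add_covariates. In the sketch below, the 2-D float64 layout with one row per observation is an assumption about what add_covariates expects, and the feature names and values are made up.

```python
import numpy as np

# Two numeric features stored as separate 1-D arrays (illustrative data).
age = np.array([34, 51, 27, 45])
income = np.array([52.0, 81.5, 39.2, 60.0])

# Stack them column-wise into an n x p float64 matrix, one row per observation.
# (Assumed layout for add_covariates; not confirmed by this page.)
covariates = np.column_stack([age, income]).astype(np.float64)

# A single feature needs an explicit column dimension: shape (n,) -> (n, 1).
single_feature = income.reshape(-1, 1)
```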

add_basis(basis)

Add a basis matrix to a dataset.

Parameters:

  basis (array, required): Numpy array of basis vectors.

update_basis(basis)

Update the basis matrix in a dataset. This allows users to build an ensemble whose leaves regress on a basis that is updated as the sampler runs.

Parameters:

  basis (array, required): Numpy array of basis vectors.
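
One situation where a leaf basis changes during sampling is the adaptive coding of a binary treatment used in Bayesian causal forests. Each observation's basis value is b0 if it is untreated and b1 if it is treated, and (b0, b1) are redrawn at every iteration. The sketch below only computes the new basis in numpy. The (b0, b1) values are made up, the n x 1 basis shape is an assumption, and the update_basis call appears only as a comment because it requires a stochtree Dataset.

```python
import numpy as np

def adaptive_coding_basis(treatment, b0, b1):
    """Basis for a binary treatment: b0 where untreated, b1 where treated (n x 1)."""
    treatment = np.asarray(treatment, dtype=np.float64)
    return ((1.0 - treatment) * b0 + treatment * b1).reshape(-1, 1)

Z = np.array([0, 1, 1, 0, 1])

# Illustrative (b0, b1) values at two successive sampler iterations.
for b0, b1 in [(-0.5, 0.5), (-0.4, 0.7)]:
    new_basis = adaptive_coding_basis(Z, b0, b1)
    # With a stochtree Dataset `dataset` that already has a basis, one would
    # call dataset.update_basis(new_basis) here before sampling the forest.
```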

add_variance_weights(variance_weights)

Add variance weights to a dataset.

Parameters:

  variance_weights (array, required): Univariate numpy array of variance weights.
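
Variance weights are described above as a univariate array, so weights held as an (n, 1) column may need to be flattened first. The sketch below assigns one weight per observation based on a made-up group label. The weight values are arbitrary, and this page does not say whether a larger weight inflates or shrinks an observation's variance.

```python
import numpy as np

# Illustrative group labels, e.g. observations from two measurement sources.
group = np.array([0, 0, 1, 1, 1])

# One weight per observation based on its group (values chosen arbitrarily).
weights = np.where(group == 1, 2.0, 1.0)

# Weights stored as an (n, 1) column can be flattened to the univariate (n,) shape.
column_weights = weights.reshape(-1, 1)
flat_weights = np.ravel(column_weights)
```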

num_observations()

Query the number of observations in a dataset.

Returns:

  int: Number of observations in the dataset.

num_covariates()

Query the number of covariates (features) in a dataset.

Returns:

  int: Number of covariates in the dataset.

num_basis()

Query the dimension of the basis in a dataset.

Returns:

  int: Dimension of the basis vector in the dataset, or 0 if the dataset does not have a basis.

has_basis()

Check whether a dataset has a basis (used for leaf regression).

Returns:

  bool: True if the dataset has a basis, False otherwise.

has_variance_weights()

Check whether a dataset has variance weights.

Returns:

  bool: True if the dataset has variance weights, False otherwise.

stochtree.data.Residual(residual)

Wrapper around a C++ class that stores residual data used in stochtree. This object is part of the sampler's running model "state": its contents always hold either a full or a partial residual, depending on the state of the sampler.

Typically this object is initialized with the original outcome and then "residualized" by subtracting the initial prediction of every tree in every forest term (as well as the predictions of any other model term).
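
The residualization step can be illustrated with plain numpy arithmetic. The sketch assumes a single forest whose trees each start as one leaf with value mean(y) / num_trees, so that the forest initially predicts mean(y). That initialization is chosen for illustration only and is not necessarily what stochtree does. The resulting vector is the kind of array that would be passed to Residual.

```python
import numpy as np

y = np.array([3.0, 5.0, 4.0, 8.0])
num_trees = 4

# Assumed initialization (illustration only): every tree is a single leaf with
# value mean(y) / num_trees, so the forest's initial prediction is mean(y).
init_leaf_value = y.mean() / num_trees
initial_tree_preds = np.full((num_trees, y.size), init_leaf_value)

# "Residualize": subtract the initial prediction of every tree from the outcome.
residual = y - initial_tree_preds.sum(axis=0)
```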

Parameters:

  residual (array, required): Univariate numpy array of residual values.

get_residual()

Extract the current values of the residual as a numpy array.

Returns:

  array: Current values of the residual (which may be net of any forest or other model terms).

update_data(new_vector)

Update the current state of the outcome (i.e., partial residual) data by replacing each element with the corresponding element of new_vector.

Parameters:

  new_vector (array, required): Univariate numpy array of new residual values.
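
Because update_data overwrites the residual as a whole, it suits the backfitting pattern in which trees are resampled one at a time. First, a tree's current predictions are added back to obtain the partial residual the tree is fit to. Then its new predictions are subtracted. The numpy sketch below shows that arithmetic with made-up predictions. The helper backfit_step is hypothetical and does not describe stochtree's internal implementation, which may update the residual in place in C++.

```python
import numpy as np

def backfit_step(residual, old_tree_pred, new_tree_pred):
    """Hypothetical helper: swap one tree's contribution in a full residual.

    residual      : current full residual (outcome minus all model predictions)
    old_tree_pred : this tree's predictions before it is resampled
    new_tree_pred : this tree's predictions after it is resampled
    """
    # Add the tree's old predictions back to get the partial residual it is fit to.
    partial_residual = residual + old_tree_pred
    # (A sampler would draw a new tree against partial_residual at this point.)
    # Subtract the new predictions to restore a full residual.
    return partial_residual, partial_residual - new_tree_pred

residual = np.array([-2.0, 0.0, -1.0, 3.0])
old_pred = np.full(4, 1.25)
new_pred = np.array([1.0, 1.5, 1.0, 2.0])
partial, updated = backfit_step(residual, old_pred, new_pred)
# With a stochtree Residual object `resid`, the updated vector could be
# written back with resid.update_data(updated).
```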