Sampler API
stochtree.sampler.ForestSampler
Wrapper around many of the core C++ sampling data structures and algorithms.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | | required |
feature_types | array | Array of integer-coded values indicating the column type of each feature in | required |
num_trees | int | Number of trees in the forest model that this sampler class will fit. | required |
num_obs | int | Number of observations / "rows" in | required |
alpha | float | Prior probability of splitting for a tree of depth 0 in a forest model. Tree split prior combines | required |
beta | float | Exponent that decreases split probabilities for nodes of depth > 0 in a forest model. Tree split prior combines | required |
min_samples_leaf | int | Minimum allowable size of a leaf, in terms of training samples, in a forest model. | required |
max_depth | int | Maximum depth of any tree in the ensemble in a forest model. | -1 |
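To illustrate how alpha, beta and max_depth interact, the sketch below assumes the standard BART split prior, in which a node at depth d splits with probability alpha * (1 + d)^(-beta). The helper name `split_probability` is hypothetical and not part of stochtree. Treating a negative `max_depth` as "no depth limit" is also an assumption, suggested by the default of -1.

```python
def split_probability(alpha: float, beta: float, depth: int, max_depth: int = -1) -> float:
    """Prior probability that a node at `depth` splits (hypothetical helper).

    Assumes the standard BART form alpha * (1 + depth) ** (-beta).
    Assumes a negative max_depth means "no depth limit".
    """
    if max_depth >= 0 and depth >= max_depth:
        return 0.0  # node already sits at the maximum depth, so it cannot split
    return alpha * (1.0 + depth) ** (-beta)


# With alpha = 0.95 and beta = 2.0 the split probability decays quickly with depth
probs = [split_probability(0.95, 2.0, d) for d in range(4)]
```

Under these assumptions, larger values of beta make deep trees less likely a priori, while alpha sets the probability that the root splits at all.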
reconstitute_from_forest(forest, dataset, residual, is_mean_model)
Re-initialize a forest sampler's tracking data structures from a specific forest in a ForestContainer
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | | required |
residual | Residual | | required |
forest | Forest | | required |
is_mean_model | bool | Indicator of whether the model being updated is a conditional mean model ( | required |
sample_one_iteration(forest_container, forest, dataset, residual, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, gfr, pre_initialized)
Sample one iteration of a forest using the specified model and tree sampling algorithm
Parameters:
Name | Type | Description | Default |
---|---|---|---|
forest_container | ForestContainer | | required |
forest | Forest | | required |
dataset | Dataset | | required |
residual | Residual | | required |
rng | RNG | | required |
feature_types | array | Array of integer-coded feature types (0 = numeric, 1 = ordered categorical, 2 = unordered categorical) | required |
cutpoint_grid_size | int | Maximum size of a grid of available cutpoints (which thins the number of possible splits, particularly useful in the grow-from-root algorithm) | required |
leaf_model_scale_input | array | Numpy array containing leaf model scale parameter (if the leaf model is univariate, this is essentially a scalar which is used as such in the C++ source, but stored as a numpy array) | required |
variable_weights | array | Numpy array containing sampling probabilities for each feature | required |
a_forest | float | Shape parameter for the inverse gamma outcome model for a heteroskedasticity forest | required |
b_forest | float | Scale parameter for the inverse gamma outcome model for a heteroskedasticity forest | required |
global_variance | float | Current value of the global error variance parameter | required |
leaf_model_int | int | Integer encoding the leaf model type (0 = constant Gaussian leaf mean model, 1 = univariate Gaussian leaf regression mean model, 2 = multivariate Gaussian leaf regression mean model, 3 = univariate Inverse Gamma constant leaf variance model) | required |
keep_forest | bool | Whether or not the resulting forest should be retained in | required |
gfr | bool | Whether or not the "grow-from-root" (GFR) sampler is run (if this is | required |
pre_initialized | bool | Whether or not the forest being sampled has already been initialized | required |
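Several arguments above are plain integer codes or probability vectors. The sketch below shows one way to keep those codes readable on the caller's side. The enum and helper names are hypothetical and not part of stochtree; only the integer values come from the parameter descriptions. Normalizing the weights so they sum to 1 is an assumption about what "sampling probabilities" requires.

```python
from enum import IntEnum


class FeatureType(IntEnum):
    # Hypothetical names; the integer codes follow the feature_types description
    NUMERIC = 0
    ORDERED_CATEGORICAL = 1
    UNORDERED_CATEGORICAL = 2


class LeafModel(IntEnum):
    # Hypothetical names; the integer codes follow the leaf_model_int description
    CONSTANT_GAUSSIAN = 0
    UNIVARIATE_REGRESSION = 1
    MULTIVARIATE_REGRESSION = 2
    INVERSE_GAMMA_VARIANCE = 3


def normalize_weights(weights):
    """Rescale non-negative weights so they sum to 1 (hypothetical helper)."""
    total = float(sum(weights))
    if total <= 0 or any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative with a positive sum")
    return [w / total for w in weights]


# Two numeric features and one unordered categorical feature
feature_types = [
    int(FeatureType.NUMERIC),
    int(FeatureType.NUMERIC),
    int(FeatureType.UNORDERED_CATEGORICAL),
]
# The third feature is twice as likely to be proposed for a split
variable_weights = normalize_weights([1.0, 1.0, 2.0])
leaf_model_int = int(LeafModel.CONSTANT_GAUSSIAN)
```

The resulting lists would still need to be converted to numpy arrays before being passed as feature_types and variable_weights.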
prepare_for_sampler(dataset, residual, forest, leaf_model, initial_values)
Initialize forest and tracking data structures with constant root values before running a sampler
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | | required |
residual | Residual | | required |
forest | Forest | | required |
leaf_model | int | Integer encoding the leaf model type | required |
initial_values | array | Constant root node value(s) at which to initialize forest prediction (internally, it is divided by the number of trees and typically it is 0 for mean models and 1 for variance models). | required |
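The division by the number of trees mentioned for initial_values can be sketched as follows for a mean model, where the forest prediction is the sum of the tree predictions. The helper `per_tree_root_value` is hypothetical; the sketch only covers the additive mean-model case and makes no claim about how variance forests are initialized internally.

```python
def per_tree_root_value(initial_value: float, num_trees: int) -> float:
    """Root value assigned to each tree so that the trees sum to initial_value."""
    if num_trees <= 0:
        raise ValueError("num_trees must be positive")
    return initial_value / num_trees


num_trees = 50
root = per_tree_root_value(2.0, num_trees)

# Every tree starts as a single root node, so the forest predicts the sum of the roots
forest_prediction = sum(root for _ in range(num_trees))
```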
adjust_residual(dataset, residual, forest, requires_basis, add)
Method that "adjusts" the residual used for training tree ensembles by either adding the prediction of each tree to the existing residual or subtracting it from the residual.
This is typically run just once at the beginning of a forest sampling algorithm: after trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | | required |
residual | Residual | | required |
forest | Forest | | required |
requires_basis | bool | Whether or not the forest requires a basis dot product when predicting | required |
add | bool | Whether the predictions of each tree are added (if | required |
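The arithmetic behind this adjustment can be sketched in plain Python. This is an illustration only: the function below is a standalone stand-in that works on lists and returns a new residual, whereas the method documented above operates on stochtree's Dataset, Residual and Forest objects.

```python
def adjust_residual(residual, tree_predictions, add):
    """Return a new residual with every tree's predictions added or subtracted."""
    sign = 1.0 if add else -1.0
    adjusted = list(residual)
    for preds in tree_predictions:  # one list of per-observation predictions per tree
        for i, p in enumerate(preds):
            adjusted[i] += sign * p
    return adjusted


outcome = [1.0, 2.0, 3.0]
# Two trees, each initialized with a constant root prediction of 0.5
root_preds = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]

residual = adjust_residual(outcome, root_preds, add=False)  # subtract root predictions
restored = adjust_residual(residual, root_preds, add=True)  # adding them back undoes it
```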
propagate_basis_update(dataset, residual, forest)
Propagates a basis update through to the (full / partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching on the C++ side), and (c) subtracting the new predictions from the residual.
This is useful when a basis (e.g. for leaf regression) is updated outside of a tree sampler (e.g. with adaptive coding for binary treatment BCF). Once a basis has been updated, the overall "function" represented by a tree model has changed, and this change should be reflected in the residual before the next sampling loop is run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | Stochtree dataset object storing covariates / bases / weights | required |
residual | Residual | Stochtree object storing continuously updated partial / full residual | required |
forest | Forest | Stochtree object storing the "active" forest being sampled | required |
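Steps (a) to (c) can be sketched on plain lists. The sketch assumes a univariate leaf regression in which each tree predicts (leaf coefficient) x (basis value) for every observation, with the tree structure held fixed. The function and variable names are hypothetical and stand in for the C++ tracking structures.

```python
def propagate_basis_update(residual, leaf_coefs, cached_preds, new_basis):
    """Update the residual and the cached predictions in place after a basis change."""
    for t, coefs in enumerate(leaf_coefs):
        for i in range(len(residual)):
            residual[i] += cached_preds[t][i]             # (a) add back the old prediction
            cached_preds[t][i] = coefs[i] * new_basis[i]  # (b) recompute with the new basis
            residual[i] -= cached_preds[t][i]             # (c) subtract the new prediction


outcome = [1.0, 2.0, 3.0]
old_basis = [1.0, 1.0, 1.0]
new_basis = [2.0, 0.0, 1.0]
# Leaf coefficient that applies to each observation, for each of two trees
leaf_coefs = [[0.5, 0.5, 1.0], [0.25, 0.25, 0.25]]

cached_preds = [[c * b for c, b in zip(coefs, old_basis)] for coefs in leaf_coefs]
residual = [y - sum(p[i] for p in cached_preds) for i, y in enumerate(outcome)]

propagate_basis_update(residual, leaf_coefs, cached_preds, new_basis)
```

After the call, the residual equals the outcome minus the sum of the predictions computed with the new basis, which is the state the next sampling loop expects.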
update_alpha(alpha)
Update alpha in the tree prior
Parameters:
Name | Type | Description | Default |
---|---|---|---|
alpha | float | New value of | required |
update_beta(beta)
Update beta in the tree prior
Parameters:
Name | Type | Description | Default |
---|---|---|---|
beta | float | New value of | required |
update_min_samples_leaf(min_samples_leaf)
Update min_samples_leaf in the tree prior
Parameters:
Name | Type | Description | Default |
---|---|---|---|
min_samples_leaf | int | New value of | required |
update_max_depth(max_depth)
Update max_depth in the tree prior
Parameters:
Name | Type | Description | Default |
---|---|---|---|
max_depth | int | New value of | required |
stochtree.sampler.GlobalVarianceModel
Wrapper around methods / functions for sampling a "global" error variance model with inverse gamma prior.
sample_one_iteration(residual, rng, a, b)
Sample one iteration of a global error variance parameter
Parameters:
Name | Type | Description | Default |
---|---|---|---|
residual | Residual | | required |
rng | RNG | | required |
a | float | Shape parameter for the inverse gamma error variance model | required |
b | float | Scale parameter for the inverse gamma error variance model | required |
Returns:
Type | Description |
---|---|
float | One draw from a Gibbs sampler for the error variance model, which depends on the rest of the model only through the "full" residual stored in a |
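A Gibbs draw of this kind can be sketched with the standard library. The sketch assumes the usual conjugate update for a Gaussian likelihood with an inverse gamma prior, IG(a, b) with shape a and scale b, so that the conditional posterior is IG(a + n/2, b + sum(r_i^2)/2). Whether stochtree uses exactly this parameterization is an assumption, and `sample_global_variance` is a hypothetical stand-in for the method documented above.

```python
import random


def sample_global_variance(residual, a, b, rng):
    """One conjugate inverse gamma draw of the error variance (hypothetical helper)."""
    n = len(residual)
    shape = a + n / 2.0
    scale = b + sum(r * r for r in residual) / 2.0
    # If G ~ Gamma(shape, scale=1/scale) then 1/G ~ InverseGamma(shape, scale)
    return 1.0 / rng.gammavariate(shape, 1.0 / scale)


rng = random.Random(1234)
# Simulated "full" residual with true standard deviation 2, i.e. variance 4
residual = [rng.gauss(0.0, 2.0) for _ in range(500)]

draws = [sample_global_variance(residual, a=1.0, b=1.0, rng=rng) for _ in range(2000)]
posterior_mean = sum(draws) / len(draws)
```

In a full sampler, a single draw per iteration would be taken after the forests have updated the residual, rather than many draws from a fixed residual as in this check.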
stochtree.sampler.LeafVarianceModel
Wrapper around methods / functions for sampling a "leaf scale" model for the variance term of a Gaussian leaf model with inverse gamma prior.
sample_one_iteration(forest, rng, a, b)
Sample one iteration of a forest leaf model's variance parameter (assuming a location-scale leaf model, most commonly N(0, tau))
Parameters:
Name | Type | Description | Default |
---|---|---|---|
forest | Forest | | required |
rng | RNG | | required |
a | float | Shape parameter for the inverse gamma leaf scale model | required |
b | float | Scale parameter for the inverse gamma leaf scale model | required |
Returns:
Type | Description |
---|---|
float | One draw from a Gibbs sampler for the leaf scale model, which depends on the rest of the model only through its respective forest. |
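The dependence on the forest alone can be made concrete with a small sketch. It assumes leaf values mu_j ~ N(0, tau) with an inverse gamma prior IG(a, b) on tau (shape a, scale b), which gives the conditional posterior IG(a + L/2, b + sum(mu_j^2)/2) for L leaves across all trees. The exact parameterization used by stochtree is an assumption, and `sample_leaf_scale` is a hypothetical helper that takes a flat list of leaf values instead of a Forest object.

```python
import random


def sample_leaf_scale(leaf_values, a, b, rng):
    """One conjugate inverse gamma draw of tau given all leaf values in a forest."""
    shape = a + len(leaf_values) / 2.0
    scale = b + sum(m * m for m in leaf_values) / 2.0
    # The reciprocal of a Gamma(shape, scale=1/scale) draw is InverseGamma(shape, scale)
    return 1.0 / rng.gammavariate(shape, 1.0 / scale)


# Leaf values collected from every tree in the forest (made-up numbers)
leaf_values = [0.10, -0.20, 0.05, 0.30, -0.15, 0.00]

tau_draw = sample_leaf_scale(leaf_values, a=3.0, b=0.5, rng=random.Random(42))
# Re-seeding the generator reproduces the same draw
tau_again = sample_leaf_scale(leaf_values, a=3.0, b=0.5, rng=random.Random(42))
```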