
A forest sampler features two types of state: ephemeral and persistent. Persistent state includes objects like ForestSamples and RandomEffectSamples, which constitute part of the final sampled model. Ephemeral state supports the sampling computations but is not retained after the sampler finishes.

The two primary pieces of forest-related ephemeral state are the Forest and ForestModel classes. A Forest holds the current state of a forest, and a ForestModel holds the tracking data structures that correspond to it.

In a linear sampling loop, this ephemeral state is updated with each iteration of the sampler, and any retained forests are copied to a ForestSamples object. In multi-chain settings, however, the state of a forest must typically be "reset" at the beginning of each new chain. These functions enable this by synchronizing the state of a Forest and ForestModel with a specific element of a ForestSamples object, or by resetting both to their default (root) state.

resetActiveForest resets a Forest object, either to a copy of a specific forest in a ForestSamples object or to an ensemble of single-node (i.e. root) trees. resetForestModel re-initializes a forest model's tracking data structures from a Forest object (typically the active forest just reset by resetActiveForest) and updates the residual to match.
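The multi-chain pattern described above can be sketched in R as follows. This is a minimal sketch, not stochtree's own multi-chain implementation: it uses only the constructors and methods that appear in the Examples section of this page, and the specific choices (two grow-from-root warm-start draws via gfr = TRUE, three chains, five MCMC iterations per chain, and a final chain that starts from root trees) are illustrative assumptions.

```r
library(stochtree)

# Data and hyperparameters mirror the Examples section (fewer trees for speed)
set.seed(1)
n <- 100
p <- 10
num_trees <- 20
X <- matrix(runif(n * p), ncol = p)
y <- -5 + 10 * (X[, 1] > 0.5) + rnorm(n)
forest_dataset <- createForestDataset(X)
outcome <- createOutcome(y)
rng <- createCppRNG(1234)
global_model_config <- createGlobalModelConfig(global_error_variance = 1.0)
forest_model_config <- createForestModelConfig(
  feature_types = as.integer(rep(0, p)), num_trees = num_trees,
  num_observations = n, num_features = p, alpha = 0.95, beta = 2.0,
  min_samples_leaf = 2, max_depth = 10, variable_weights = rep(1 / p, p),
  cutpoint_grid_size = 100, leaf_model_type = 0,
  leaf_model_scale = as.matrix(1.0)
)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
active_forest <- createForest(num_trees, 1, TRUE, FALSE)
forest_samples <- createForestSamples(num_trees, 1, TRUE, FALSE)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)

num_gfr <- 2     # warm-start draws, stored at indices 0 and 1 of forest_samples
num_chains <- 3  # chains 1-2 warm-start from GFR draws, chain 3 starts from root
num_mcmc <- 5    # MCMC iterations per chain
iterations_run <- 0

# Warm-start stage: keep every grow-from-root forest
for (i in seq_len(num_gfr)) {
  forest_model$sample_one_iteration(
    forest_dataset, outcome, forest_samples, active_forest,
    rng, forest_model_config, global_model_config,
    keep_forest = TRUE, gfr = TRUE
  )
}

for (chain in seq_len(num_chains)) {
  if (chain <= num_gfr) {
    # Restart from a different retained GFR forest (zero-based index);
    # MCMC draws are appended later, so indices 0..num_gfr-1 stay stable
    resetActiveForest(active_forest, forest_samples, num_gfr - chain)
  } else {
    # No forest_samples supplied: reset to single-node (root) trees
    resetActiveForest(active_forest)
  }
  # Rebuild tracking data structures and update the residual from the active forest
  resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)
  for (i in seq_len(num_mcmc)) {
    forest_model$sample_one_iteration(
      forest_dataset, outcome, forest_samples, active_forest,
      rng, forest_model_config, global_model_config,
      keep_forest = TRUE, gfr = FALSE
    )
    iterations_run <- iterations_run + 1
  }
}
```

The order of the two calls matters: resetActiveForest runs first because resetForestModel rebuilds the tracking data structures (and the residual) from whatever the active forest holds at that moment.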

These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. They perform minimal input validation and error checking, so users are responsible for providing correct inputs. For tutorials on proper usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/

Usage

resetActiveForest(active_forest, forest_samples = NULL, forest_num = NULL)

resetForestModel(forest_model, forest, dataset, residual, is_mean_model)

Arguments

active_forest

Active forest to be reset (modified in place)

forest_samples

(Optional) Container of forest samples from which to re-initialize active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees.

forest_num

(Optional) Zero-based index of the forest in forest_samples from which to re-initialize the active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees.

forest_model

Forest model with tracking data structures

forest

Forest from which to re-initialize forest model

dataset

Training dataset object

residual

Residual (an Outcome object), which is also updated in place

is_mean_model

Whether the model being updated is a conditional mean model

Value

Neither function returns a value; both modify the relevant Forest or ForestModel object in place (resetForestModel also updates the residual).

Examples

library(stochtree)

n <- 100
p <- 10
num_trees <- 100
leaf_dimension <- 1
is_leaf_constant <- TRUE
is_exponentiated <- FALSE
alpha <- 0.95
beta <- 2.0
min_samples_leaf <- 2
max_depth <- 10
feature_types <- as.integer(rep(0, p))
leaf_model <- 0
sigma2 <- 1.0
leaf_scale <- as.matrix(1.0)
variable_weights <- rep(1/p, p)
a_forest <- 1
b_forest <- 1
cutpoint_grid_size <- 100
X <- matrix(runif(n*p), ncol = p)
forest_dataset <- createForestDataset(X)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(n)
outcome <- createOutcome(y)
rng <- createCppRNG(1234)
global_model_config <- createGlobalModelConfig(global_error_variance=sigma2)
forest_model_config <- createForestModelConfig(feature_types=feature_types,
                                               num_trees=num_trees, num_observations=n,
                                               num_features=p, alpha=alpha, beta=beta,
                                               min_samples_leaf=min_samples_leaf,
                                               max_depth=max_depth,
                                               variable_weights=variable_weights,
                                               cutpoint_grid_size=cutpoint_grid_size,
                                               leaf_model_type=leaf_model,
                                               leaf_model_scale=leaf_scale)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
forest_samples <- createForestSamples(num_trees, leaf_dimension,
                                      is_leaf_constant, is_exponentiated)
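# Initialize the active forest and residual before sampling
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
# Run one MCMC iteration and retain the sampled forest in forest_samples
forest_model$sample_one_iteration(
    forest_dataset, outcome, forest_samples, active_forest,
    rng, forest_model_config, global_model_config,
    keep_forest = TRUE, gfr = FALSE
)
# Reset the active forest to the first retained forest (index 0)
resetActiveForest(active_forest, forest_samples, 0)
# Rebuild the forest model's tracking data structures (and residual) from it
resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)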