A forest sampler maintains two types of state: ephemeral and persistent. Persistent state includes objects such as ForestSamples and RandomEffectSamples, which form part of the final sampled model. Ephemeral state supports the sampling computations but is not retained after the sampler finishes.
The two primary forest-based pieces of ephemeral state are the Forest and ForestModel classes, which represent the current state of a forest and its corresponding tracking data structures.
In a linear sampling loop, this ephemeral state is updated at each iteration of the sampler, and any retained forests are copied to a ForestSamples object. In multi-chain settings, however, the state of a forest must typically be "reset" at the beginning of a new chain. These functions enable this process by synchronizing the state of a Forest and ForestModel with a corresponding element of a ForestSamples object, or by resetting both to their default (root) state.
resetActiveForest resets a Forest object, either from a specific forest in a ForestSamples object or to an ensemble of single-node (i.e. root) trees.
resetForestModel re-initializes a forest model (tracking data structures) from a specific forest in a ForestSamples object.
These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking are performed; users are responsible for providing correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
Usage
resetActiveForest(active_forest, forest_samples = NULL, forest_num = NULL)
resetForestModel(forest_model, forest, dataset, residual, is_mean_model)
Arguments
- active_forest
Current active forest
- forest_samples
(Optional) Container of forest samples from which to re-initialize active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees.
- forest_num
(Optional) Index of the forest in forest_samples from which to initialize active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees.
- forest_model
Forest model with tracking data structures
- forest
Forest from which to re-initialize forest model
- dataset
Training dataset object
- residual
Residual object, which is also updated during the reset
- is_mean_model
Whether the model being updated is a conditional mean model
Value
Neither function returns a value; both operate in place on the relevant Forest or ForestModel object.
Examples
n <- 100
p <- 10
num_trees <- 100
leaf_dimension <- 1
is_leaf_constant <- TRUE
is_exponentiated <- FALSE
alpha <- 0.95
beta <- 2.0
min_samples_leaf <- 2
max_depth <- 10
feature_types <- as.integer(rep(0, p))
leaf_model <- 0
sigma2 <- 1.0
leaf_scale <- as.matrix(1.0)
variable_weights <- rep(1/p, p)
a_forest <- 1
b_forest <- 1
cutpoint_grid_size <- 100
X <- matrix(runif(n*p), ncol = p)
forest_dataset <- createForestDataset(X)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(n)
outcome <- createOutcome(y)
rng <- createCppRNG(1234)
global_model_config <- createGlobalModelConfig(global_error_variance=sigma2)
forest_model_config <- createForestModelConfig(feature_types=feature_types,
                                               num_trees=num_trees, num_observations=n,
                                               num_features=p, alpha=alpha, beta=beta,
                                               min_samples_leaf=min_samples_leaf,
                                               max_depth=max_depth,
                                               variable_weights=variable_weights,
                                               cutpoint_grid_size=cutpoint_grid_size,
                                               leaf_model_type=leaf_model,
                                               leaf_model_scale=leaf_scale)
# Ephemeral state: tracking data structures and the active forest
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
# Persistent state: container for retained forests
forest_samples <- createForestSamples(num_trees, leaf_dimension,
                                      is_leaf_constant, is_exponentiated)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
# Run one sampler iteration and retain the resulting forest in forest_samples
forest_model$sample_one_iteration(
    forest_dataset, outcome, forest_samples, active_forest,
    rng, forest_model_config, global_model_config,
    keep_forest = TRUE, gfr = FALSE
)
# Reset the active forest to the retained forest at index 0,
# then re-sync the tracking data structures and residual with it
resetActiveForest(active_forest, forest_samples, 0)
resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)
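
The multi-chain use case described at the top of this page can be sketched as follows. This is a minimal illustration rather than the package's own multi-chain implementation: it uses only the calls shown in the example above, and the names num_chains and num_mcmc, the reduced problem sizes, and the choice to start each chain from a different grow-from-root (gfr = TRUE) forest are all assumptions made for the sketch.

```r
library(stochtree)

# Setup mirrors the example above, with smaller sizes to keep the run short
n <- 100
p <- 5
num_trees <- 20
X <- matrix(runif(n * p), ncol = p)
y <- -5 + 10 * (X[, 1] > 0.5) + rnorm(n)
forest_dataset <- createForestDataset(X)
outcome <- createOutcome(y)
rng <- createCppRNG(1234)
global_model_config <- createGlobalModelConfig(global_error_variance = 1.0)
forest_model_config <- createForestModelConfig(
    feature_types = as.integer(rep(0, p)), num_trees = num_trees,
    num_observations = n, num_features = p, alpha = 0.95, beta = 2.0,
    min_samples_leaf = 2, max_depth = 10, variable_weights = rep(1 / p, p),
    cutpoint_grid_size = 100, leaf_model_type = 0,
    leaf_model_scale = as.matrix(1.0)
)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
active_forest <- createForest(num_trees, 1, TRUE, FALSE)
forest_samples <- createForestSamples(num_trees, 1, TRUE, FALSE)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)

# Hypothetical settings for this sketch (not part of the stochtree API)
num_chains <- 2
num_mcmc <- 5

# Warm-start phase: retain one forest per chain in forest_samples
for (i in seq_len(num_chains)) {
    forest_model$sample_one_iteration(
        forest_dataset, outcome, forest_samples, active_forest,
        rng, forest_model_config, global_model_config,
        keep_forest = TRUE, gfr = TRUE
    )
}

for (chain in seq_len(num_chains)) {
    # Start each chain from a different retained forest
    # (forest_num is passed 0-based here, as in the example above)
    resetActiveForest(active_forest, forest_samples, chain - 1)
    # Re-sync the tracking data structures and residual with the reset forest
    resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)
    for (i in seq_len(num_mcmc)) {
        forest_model$sample_one_iteration(
            forest_dataset, outcome, forest_samples, active_forest,
            rng, forest_model_config, global_model_config,
            keep_forest = TRUE, gfr = FALSE
        )
    }
}
```

Calling resetForestModel immediately after resetActiveForest keeps the two pieces of ephemeral state consistent with each other before sampling resumes. Because every retained forest is appended to the same forest_samples container, the draws from all chains end up stored after the warm-start forests.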