StochTree 0.1.1
Loading...
Searching...
No Matches
tree_sampler.h
1
2#ifndef STOCHTREE_TREE_SAMPLER_H_
3#define STOCHTREE_TREE_SAMPLER_H_
4
5#include <stochtree/container.h>
6#include <stochtree/cutpoint_candidates.h>
7#include <stochtree/data.h>
8#include <stochtree/discrete_sampler.h>
9#include <stochtree/ensemble.h>
10#include <stochtree/leaf_model.h>
11#include <stochtree/openmp_utils.h>
12#include <stochtree/partition_tracker.h>
13#include <stochtree/prior.h>
14
15#include <cmath>
16#include <numeric>
17#include <random>
18#include <vector>
19
20namespace StochTree {
21
47static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, int tree_num, int leaf_split, int feature_split, double& var_min, double& var_max) {
48 var_min = std::numeric_limits<double>::max();
49 var_max = std::numeric_limits<double>::min();
50 double feature_value;
51
52 std::vector<data_size_t>::iterator node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
53 std::vector<data_size_t>::iterator node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
54
55 for (auto i = node_begin_iter; i != node_end_iter; i++) {
56 auto idx = *i;
57 feature_value = dataset.CovariateValue(idx, feature_split);
58 if (feature_value < var_min) {
59 var_min = feature_value;
60 } else if (feature_value > var_max) {
61 var_max = feature_value;
62 }
63 }
64}
65
77static inline bool NodesNonConstantAfterSplit(ForestDataset& dataset, ForestTracker& tracker, TreeSplit& split, int tree_num, int leaf_split, int feature_split) {
78 int p = dataset.GetCovariates().cols();
79 data_size_t idx;
80 double feature_value;
81 double split_feature_value;
82 double var_max_left;
83 double var_min_left;
84 double var_max_right;
85 double var_min_right;
86
87 for (int j = 0; j < p; j++) {
88 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
89 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
90 var_max_left = std::numeric_limits<double>::min();
91 var_min_left = std::numeric_limits<double>::max();
92 var_max_right = std::numeric_limits<double>::min();
93 var_min_right = std::numeric_limits<double>::max();
94
95 for (auto i = node_begin_iter; i != node_end_iter; i++) {
96 auto idx = *i;
97 split_feature_value = dataset.CovariateValue(idx, feature_split);
98 feature_value = dataset.CovariateValue(idx, j);
99 if (split.SplitTrue(split_feature_value)) {
100 if (var_max_left < feature_value) {
101 var_max_left = feature_value;
102 } else if (var_min_left > feature_value) {
103 var_min_left = feature_value;
104 }
105 } else {
106 if (var_max_right < feature_value) {
107 var_max_right = feature_value;
108 } else if (var_min_right > feature_value) {
109 var_min_right = feature_value;
110 }
111 }
112 }
113 if ((var_max_left > var_min_left) && (var_max_right > var_min_right)) {
114 return true;
115 }
116 }
117 return false;
118}
119
120static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracker, int tree_num, int node_id) {
121 int p = dataset.GetCovariates().cols();
122 data_size_t idx;
123 double feature_value;
124 double var_max;
125 double var_min;
126
127 for (int j = 0; j < p; j++) {
128 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
129 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
130 var_max = std::numeric_limits<double>::min();
131 var_min = std::numeric_limits<double>::max();
132
133 for (auto i = node_begin_iter; i != node_end_iter; i++) {
134 auto idx = *i;
135 feature_value = dataset.CovariateValue(idx, j);
136 if (var_max < feature_value) {
137 var_max = feature_value;
138 } else if (var_min > feature_value) {
139 var_min = feature_value;
140 }
141 }
142 if (var_max > var_min) {
143 return true;
144 }
145 }
146 return false;
147}
148
149static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree,
150 int tree_num, int leaf_node, int feature_split, bool keep_sorted = false, int num_threads = -1) {
151 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
152 if (tree->OutputDimension() > 1) {
153 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
154 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_values, temp_leaf_values);
155 } else {
156 double temp_leaf_value = 0.;
157 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_value, temp_leaf_value);
158 }
159 int left_node = tree->LeftChild(leaf_node);
160 int right_node = tree->RightChild(leaf_node);
161
162 // Update the ForestTracker
163 tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted, num_threads);
164}
165
166static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree,
167 int tree_num, int leaf_node, int left_node, int right_node, bool keep_sorted = false) {
168 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
169 if (tree->OutputDimension() > 1) {
170 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
171 tree->CollapseToLeaf(leaf_node, temp_leaf_values);
172 } else {
173 double temp_leaf_value = 0.;
174 tree->CollapseToLeaf(leaf_node, temp_leaf_value);
175 }
176
177 // Update the ForestTracker
178 tracker.RemoveSplit(dataset.GetCovariates(), tree, tree_num, leaf_node, left_node, right_node, keep_sorted);
179}
180
181static inline double ComputeMeanOutcome(ColumnVector& residual) {
182 int n = residual.NumRows();
183 double sum_y = 0.;
184 double y;
185 for (data_size_t i = 0; i < n; i++) {
186 y = residual.GetElement(i);
187 sum_y += y;
188 }
189 return sum_y / static_cast<double>(n);
190}
191
192static inline double ComputeVarianceOutcome(ColumnVector& residual) {
193 int n = residual.NumRows();
194 double sum_y = 0.;
195 double sum_y_sq = 0.;
196 double y;
197 for (data_size_t i = 0; i < n; i++) {
198 y = residual.GetElement(i);
199 sum_y += y;
200 sum_y_sq += y * y;
201 }
202 return sum_y_sq / static_cast<double>(n) - (sum_y * sum_y) / (static_cast<double>(n) * static_cast<double>(n));
203}
204
205static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual,
206 TreeEnsemble* forest, bool requires_basis, std::function<double(double, double)> op) {
207 data_size_t n = dataset.GetCovariates().rows();
208 double tree_pred = 0.;
209 double pred_value = 0.;
210 double new_resid = 0.;
211 int32_t leaf_pred;
212 for (data_size_t i = 0; i < n; i++) {
213 for (int j = 0; j < forest->NumTrees(); j++) {
214 Tree* tree = forest->GetTree(j);
215 leaf_pred = tracker.GetNodeId(i, j);
216 if (requires_basis) {
217 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
218 } else {
219 tree_pred += tree->PredictFromNode(leaf_pred);
220 }
221 tracker.SetTreeSamplePrediction(i, j, tree_pred);
222 pred_value += tree_pred;
223 }
224
225 // Run op (either plus or minus) on the residual and the new prediction
226 new_resid = op(residual.GetElement(i), pred_value);
227 residual.SetElement(i, new_resid);
228 }
229 tracker.SyncPredictions();
230}
231
232static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
233 bool requires_basis, std::function<double(double, double)> op) {
234 data_size_t n = dataset.GetCovariates().rows();
235 double tree_pred = 0.;
236 double pred_value = 0.;
237 double new_resid = 0.;
238 int32_t leaf_pred;
239 for (data_size_t i = 0; i < n; i++) {
240 for (int j = 0; j < forest->NumTrees(); j++) {
241 Tree* tree = forest->GetTree(j);
242 leaf_pred = tracker.GetNodeId(i, j);
243 if (requires_basis) {
244 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
245 } else {
246 tree_pred += tree->PredictFromNode(leaf_pred);
247 }
248 pred_value += tree_pred;
249 }
250
251 // Run op (either plus or minus) on the residual and the new prediction
252 new_resid = op(residual.GetElement(i), pred_value);
253 residual.SetElement(i, new_resid);
254 }
255}
256
257static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
258 bool requires_basis, std::function<double(double, double)> op) {
259 data_size_t n = dataset.GetCovariates().rows();
260 double tree_pred = 0.;
261 double pred_value = 0.;
262 double new_resid = 0.;
263 int32_t leaf_pred;
264 for (data_size_t i = 0; i < n; i++) {
265 for (int j = 0; j < forest->NumTrees(); j++) {
266 Tree* tree = forest->GetTree(j);
267 leaf_pred = tracker.GetNodeId(i, j);
268 if (requires_basis) {
269 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
270 } else {
271 tree_pred += tree->PredictFromNode(leaf_pred);
272 }
273 tracker.SetTreeSamplePrediction(i, j, tree_pred);
274 pred_value += tree_pred;
275 }
276
277 // Run op (either plus or minus) on the residual and the new prediction
278 new_resid = op(residual.GetElement(i), pred_value);
279 residual.SetElement(i, new_resid);
280 }
281 tracker.SyncPredictions();
282}
283
284static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector& residual) {
285 data_size_t n = residual.NumRows();
286 double pred_value;
287 double prev_outcome;
288 double new_resid;
289 for (data_size_t i = 0; i < n; i++) {
290 prev_outcome = residual.GetElement(i);
291 pred_value = tracker.GetSamplePrediction(i);
292 // Run op (either plus or minus) on the residual and the new prediction
293 new_resid = prev_outcome - pred_value;
294 residual.SetElement(i, new_resid);
295 }
296}
297
298static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
299 bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
300 data_size_t n = dataset.GetCovariates().rows();
301 double pred_value;
302 int32_t leaf_pred;
303 double new_resid;
304 double pred_delta;
305 for (data_size_t i = 0; i < n; i++) {
306 if (tree_new) {
307 // If the tree has been newly sampled or adjusted, we must rerun the prediction
308 // method and update the SamplePredMapper stored in tracker
309 leaf_pred = tracker.GetNodeId(i, tree_num);
310 if (requires_basis) {
311 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
312 } else {
313 pred_value = tree->PredictFromNode(leaf_pred);
314 }
315 pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num);
316 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
317 tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta);
318 } else {
319 // If the tree has not yet been modified via a sampling step,
320 // we can query its prediction directly from the SamplePredMapper stored in tracker
321 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
322 }
323 // Run op (either plus or minus) on the residual and the new prediction
324 new_resid = op(residual.GetElement(i), pred_value);
325 residual.SetElement(i, new_resid);
326 }
327}
328
329static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest) {
330 CHECK(dataset.HasBasis());
331 data_size_t n = dataset.GetCovariates().rows();
332 int num_trees = forest->NumTrees();
333 double prev_tree_pred;
334 double new_tree_pred;
335 int32_t leaf_pred;
336 double new_resid;
337 for (int tree_num = 0; tree_num < num_trees; tree_num++) {
338 Tree* tree = forest->GetTree(tree_num);
339 for (data_size_t i = 0; i < n; i++) {
340 // Add back the currently stored tree prediction
341 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
342 new_resid = residual.GetElement(i) + prev_tree_pred;
343
344 // Compute new prediction based on updated basis
345 leaf_pred = tracker.GetNodeId(i, tree_num);
346 new_tree_pred = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
347
348 // Cache the new prediction in the tracker
349 tracker.SetTreeSamplePrediction(i, tree_num, new_tree_pred);
350
351 // Subtract out the updated tree prediction
352 new_resid -= new_tree_pred;
353
354 // Propagate the change back to the residual
355 residual.SetElement(i, new_resid);
356 }
357 }
358 tracker.SyncPredictions();
359}
360
361static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree,
362 int tree_num, bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
363 data_size_t n = dataset.GetCovariates().rows();
364 double pred_value;
365 int32_t leaf_pred;
366 double new_weight;
367 double pred_delta;
368 double prev_tree_pred;
369 double prev_pred;
370 for (data_size_t i = 0; i < n; i++) {
371 if (tree_new) {
372 // If the tree has been newly sampled or adjusted, we must rerun the prediction
373 // method and update the SamplePredMapper stored in tracker
374 leaf_pred = tracker.GetNodeId(i, tree_num);
375 if (requires_basis) {
376 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
377 } else {
378 pred_value = tree->PredictFromNode(leaf_pred);
379 }
380 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
381 prev_pred = tracker.GetSamplePrediction(i);
382 pred_delta = pred_value - prev_tree_pred;
383 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
384 tracker.SetSamplePrediction(i, prev_pred + pred_delta);
385 new_weight = std::log(dataset.VarWeightValue(i)) + pred_value;
386 dataset.SetVarWeightValue(i, new_weight, true);
387 } else {
388 // If the tree has not yet been modified via a sampling step,
389 // we can query its prediction directly from the SamplePredMapper stored in tracker
390 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
391 new_weight = std::log(dataset.VarWeightValue(i)) - pred_value;
392 dataset.SetVarWeightValue(i, new_weight, true);
393 }
394 }
395}
396
397template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
398static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(
399 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
400 TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance,
401 int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
402) {
403 // Initialize sufficient statistics
404 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
405 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
406 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
407
408 // Accumulate sufficient statistics
409 AccumulateSuffStatProposed<LeafSuffStat, LeafSuffStatConstructorArgs...>(
410 node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
411 residual, global_variance, split, tree_num, leaf_num, split_feature, num_threads,
412 leaf_suff_stat_args...
413 );
414 data_size_t left_n = left_suff_stat.n;
415 data_size_t right_n = right_suff_stat.n;
416
417 // Evaluate split
418 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
419 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
420
421 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
422}
423
424template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
425static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(
426 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
427 double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id,
428 LeafSuffStatConstructorArgs&... leaf_suff_stat_args
429) {
430 // Initialize sufficient statistics
431 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
432 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
433 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
434
435 // Accumulate sufficient statistics
436 AccumulateSuffStatExisting<LeafSuffStat>(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
437 residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id);
438 data_size_t left_n = left_suff_stat.n;
439 data_size_t right_n = right_suff_stat.n;
440
441 // Evaluate split
442 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
443 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
444
445 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
446}
447
448template <typename LeafModel>
449static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
450 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
451 if (backfitting) {
452 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), false);
453 } else {
454 // TODO: think about a generic way to store "state" corresponding to the other models?
455 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), false);
456 }
457}
458
459template <typename LeafModel>
460static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
461 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
462 if (backfitting) {
463 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), true);
464 } else {
465 // TODO: think about a generic way to store "state" corresponding to the other models?
466 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), true);
467 }
468}
469
470template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
471static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
472 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size,
473 std::unordered_map<int, std::pair<data_size_t, data_size_t>>& node_index_map, std::deque<node_t>& split_queue,
474 int node_id, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
475 std::vector<FeatureType>& feature_types, std::vector<bool> feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
476 // Leaf depth
477 int leaf_depth = tree->GetDepth(node_id);
478
479 // Maximum leaf depth
480 int32_t max_depth = tree_prior.GetMaxDepth();
481
482 if ((max_depth == -1) || (leaf_depth < max_depth)) {
483
484 // Vector of vectors to store results for each feature
485 int p = dataset.NumCovariates();
486 std::vector<std::vector<double>> feature_log_cutpoint_evaluations(p+1);
487 std::vector<std::vector<double>> feature_cutpoint_values(p+1);
488 std::vector<int> feature_cutpoint_counts(p+1, 0);
489 StochTree::data_size_t valid_cutpoint_count;
490
491 // Evaluate all possible cutpoints according to the leaf node model,
492 // recording their log-likelihood and other split information in a series of vectors.
493
494 // Initialize node sufficient statistics
495 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
496
497 // Accumulate aggregate sufficient statistic for the node to be split
498 AccumulateSingleNodeSuffStat<LeafSuffStat, false>(node_suff_stat, dataset, tracker, residual, tree_num, node_id);
499
500 // Compute the "no split" log marginal likelihood
501 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
502
503 // Unpack data
504 Eigen::MatrixXd& covariates = dataset.GetCovariates();
505 Eigen::VectorXd& outcome = residual.GetData();
506 Eigen::VectorXd var_weights;
507 bool has_weights = dataset.HasVarWeights();
508 if (has_weights) var_weights = dataset.GetVarWeights();
509
510 // Minimum size of newly created leaf nodes (used to rule out invalid splits)
511 int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();
512
513 // Compute sufficient statistics for each possible split
514 data_size_t num_cutpoints = 0;
515 if (num_threads == -1) {
516 num_threads = GetOptimalThreadCount(static_cast<int>(covariates.cols() * covariates.rows()));
517 }
518
519 // Initialize cutpoint grid container
520 CutpointGridContainer cutpoint_grid_container(covariates, outcome, cutpoint_grid_size);
521
522 // Evaluate all possible splits for each feature in parallel
523 StochTree::ParallelFor(0, covariates.cols(), num_threads, [&](int j) {
524 if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) {
525 // Enumerate cutpoint strides
526 cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types);
527
528 // Left and right node sufficient statistics
529 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
530 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
531
532 // Iterate through possible cutpoints
533 int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j);
534 FeatureType feature_type = feature_types[j];
535 // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins
536 for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) {
537 data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j);
538 data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j);
539 data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j);
540
541 // Accumulate sufficient statistics for the left node
542 AccumulateCutpointBinSuffStat<LeafSuffStat>(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual,
543 global_variance, tree_num, node_id, j, cutpoint_idx);
544
545 // Compute the corresponding right node sufficient statistics
546 right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat);
547
548 // Store the bin index as the "cutpoint value" - we can use this to query the actual split
549 // value or the set of split categories later on once a split is chose
550 double cutoff_value = cutpoint_idx;
551
552 // Only include cutpoint for consideration if it defines a valid split in the training data
553 bool valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) &&
554 right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
555 if (valid_split) {
556 feature_cutpoint_counts[j]++;
557 // Add to split rule vector
558 feature_cutpoint_values[j].push_back(cutoff_value);
559 // Add the log marginal likelihood of the split to the split eval vector
560 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
561 feature_log_cutpoint_evaluations[j].push_back(split_log_ml);
562 }
563 }
564 }
565 });
566
567 // Compute total number of cutpoints
568 valid_cutpoint_count = std::accumulate(feature_cutpoint_counts.begin(), feature_cutpoint_counts.end(), 0);
569
570 // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
571 feature_log_cutpoint_evaluations[covariates.cols()].push_back(no_split_log_ml);
572
573 // Compute an adjustment to reflect the no split prior probability and the number of cutpoints
574 double bart_prior_no_split_adj;
575 double alpha = tree_prior.GetAlpha();
576 double beta = tree_prior.GetBeta();
577 int node_depth = tree->GetDepth(node_id);
578 if (valid_cutpoint_count == 0) {
579 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0);
580 } else {
581 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count);
582 }
583 feature_log_cutpoint_evaluations[covariates.cols()][0] += bart_prior_no_split_adj;
584
585
586 // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood
587 double largest_ml = -std::numeric_limits<double>::infinity();
588 for (int j = 0; j < p + 1; j++) {
589 if (feature_log_cutpoint_evaluations[j].size() > 0) {
590 double feature_max_ml = *std::max_element(feature_log_cutpoint_evaluations[j].begin(), feature_log_cutpoint_evaluations[j].end());;
591 largest_ml = std::max(largest_ml, feature_max_ml);
592 }
593 }
594 std::vector<std::vector<double>> feature_cutpoint_evaluations(p+1);
595 for (int j = 0; j < p + 1; j++) {
596 if (feature_log_cutpoint_evaluations[j].size() > 0) {
597 feature_cutpoint_evaluations[j].resize(feature_log_cutpoint_evaluations[j].size());
598 for (int i = 0; i < feature_log_cutpoint_evaluations[j].size(); i++) {
599 feature_cutpoint_evaluations[j][i] = std::exp(feature_log_cutpoint_evaluations[j][i] - largest_ml);
600 }
601 }
602 }
603
604 // Compute sum of marginal likelihoods for each feature
605 std::vector<double> feature_total_cutpoint_evaluations(p+1, 0.0);
606 for (int j = 0; j < p + 1; j++) {
607 if (feature_log_cutpoint_evaluations[j].size() > 0) {
608 feature_total_cutpoint_evaluations[j] = std::accumulate(feature_cutpoint_evaluations[j].begin(), feature_cutpoint_evaluations[j].end(), 0.0);
609 } else {
610 feature_total_cutpoint_evaluations[j] = 0.0;
611 }
612 }
613
614 // First, sample a feature according to feature_total_cutpoint_evaluations
615 std::discrete_distribution<data_size_t> feature_dist(feature_total_cutpoint_evaluations.begin(), feature_total_cutpoint_evaluations.end());
616 int feature_chosen = feature_dist(gen);
617
618 // Then, sample a cutpoint according to feature_cutpoint_evaluations[feature_chosen]
619 std::discrete_distribution<data_size_t> cutpoint_dist(feature_cutpoint_evaluations[feature_chosen].begin(), feature_cutpoint_evaluations[feature_chosen].end());
620 data_size_t cutpoint_chosen = cutpoint_dist(gen);
621
622 if (feature_chosen == p){
623 // "No split" sampled, don't split or add any nodes to split queue
624 return;
625 } else {
626 // Split sampled
627 int feature_split = feature_chosen;
628 FeatureType feature_type = feature_types[feature_split];
629 double split_value = feature_cutpoint_values[feature_split][cutpoint_chosen];
630 // Perform all of the relevant "split" operations in the model, tree and training dataset
631
632 // Compute node sample size
633 data_size_t node_n = node_end - node_begin;
634
635 // Actual numeric cutpoint used for ordered categorical and numeric features
636 double split_value_numeric;
637 TreeSplit tree_split;
638
639 // We will use these later in the model expansion
640 data_size_t left_n = 0;
641 data_size_t right_n = 0;
642 data_size_t sort_idx;
643 double feature_value;
644 bool split_true;
645
646 if (feature_type == FeatureType::kUnorderedCategorical) {
647 // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split
648 int num_categories;
649 std::vector<std::uint32_t> categories = cutpoint_grid_container.CutpointVector(static_cast<std::uint32_t>(split_value), feature_split);
650 tree_split = TreeSplit(categories);
651 } else if (feature_type == FeatureType::kOrderedCategorical) {
652 // Convert the bin split to an actual split value
653 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
654 tree_split = TreeSplit(split_value_numeric);
655 } else if (feature_type == FeatureType::kNumeric) {
656 // Convert the bin split to an actual split value
657 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
658 tree_split = TreeSplit(split_value_numeric);
659 } else {
660 Log::Fatal("Invalid split type");
661 }
662
663 // Add split to tree and trackers
664 AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads);
665
666 // Determine the number of observation in the newly created left node
667 int left_node = tree->LeftChild(node_id);
668 int right_node = tree->RightChild(node_id);
669 auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split);
670 auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split);
671 for (auto i = left_begin_iter; i < left_end_iter; i++) {
672 left_n += 1;
673 }
674
675 // Add the begin and end indices for the new left and right nodes to node_index_map
676 node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)});
677 node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)});
678
679 // Add the left and right nodes to the split tracker
680 split_queue.push_front(right_node);
681 split_queue.push_front(left_node);
682 }
683 }
684}
685
686template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
687static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
688 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
689 int tree_num, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
690 int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
691 int root_id = Tree::kRoot;
692 int curr_node_id;
693 data_size_t curr_node_begin;
694 data_size_t curr_node_end;
695 data_size_t n = dataset.GetCovariates().rows();
696 int p = dataset.GetCovariates().cols();
697
698 // Subsample features (if requested)
699 std::vector<bool> feature_subset(p, true);
700 if (num_features_subsample < p) {
701 // Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
702 int number_nonzero_weights = 0;
703 for (int j = 0; j < p; j++) {
704 if (std::abs(variable_weights.at(j)) > kEpsilon) {
705 number_nonzero_weights++;
706 }
707 }
708 if (number_nonzero_weights > num_features_subsample) {
709 // Sample with replacement according to variable_weights
710 std::vector<int> feature_indices(p);
711 std::iota(feature_indices.begin(), feature_indices.end(), 0);
712 std::vector<int> features_selected(num_features_subsample);
713 sample_without_replacement<int, double>(
714 features_selected.data(), variable_weights.data(), feature_indices.data(),
715 p, num_features_subsample, gen
716 );
717 for (int i = 0; i < p; i++) {
718 feature_subset.at(i) = false;
719 }
720 for (const auto& feat : features_selected) {
721 feature_subset.at(feat) = true;
722 }
723 }
724 }
725
726 // Mapping from node id to start and end points of sorted indices
727 std::unordered_map<int, std::pair<data_size_t, data_size_t>> node_index_map;
728 node_index_map.insert({root_id, std::make_pair(0, n)});
729 std::pair<data_size_t, data_size_t> begin_end;
730 // Add root node to the split queue
731 std::deque<node_t> split_queue;
732 split_queue.push_back(Tree::kRoot);
733 // Run the "GrowFromRoot" procedure using a stack in place of recursion
734 while (!split_queue.empty()) {
735 // Remove the next node from the queue
736 curr_node_id = split_queue.front();
737 split_queue.pop_front();
738 // Determine the beginning and ending indices of the left and right nodes
739 begin_end = node_index_map[curr_node_id];
740 curr_node_begin = begin_end.first;
741 curr_node_end = begin_end.second;
742 // Draw a split rule at random
743 SampleSplitRule<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
744 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size,
745 node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types,
746 feature_subset, num_threads, leaf_suff_stat_args...
747 );
748 }
749}
750
781template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
782static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
783 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
784 std::vector<int>& sweep_update_indices, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
785 bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
786 // Run the GFR algorithm for each tree
787 int num_trees = forests.NumTrees();
788 for (const int& i : sweep_update_indices) {
789 // Adjust any model state needed to run a tree sampler
790 // For models that involve Bayesian backfitting, this amounts to adding tree i's
791 // predictions back to the residual (thus, training a model on the "partial residual")
792 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
793 Tree* tree = active_forest.GetTree(i);
794 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
795
796 // Reset the tree and sample trackers
797 active_forest.ResetInitTree(i);
798 tracker.ResetRoot(dataset.GetCovariates(), feature_types, i);
799 tree = active_forest.GetTree(i);
800
801 // Sample tree i
802 GFRSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
803 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
804 variable_weights, i, global_variance, feature_types, cutpoint_grid_size,
805 num_features_subsample, num_threads, leaf_suff_stat_args...
806 );
807
808 // Sample leaf parameters for tree i
809 tree = active_forest.GetTree(i);
810 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
811
812 // Adjust any model state needed to run a tree sampler
813 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
814 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
815 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
816 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
817 }
818
819 if (keep_forest) {
820 forests.AddSample(active_forest);
821 }
822}
823
824template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
825static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
826 TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector<double>& variable_weights,
827 double global_variance, double prob_grow_old, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
828 // Extract dataset information
829 data_size_t n = dataset.GetCovariates().rows();
830
831 // Choose a leaf node at random
832 int num_leaves = tree->NumLeaves();
833 std::vector<int> leaves = tree->GetLeaves();
834 std::vector<double> leaf_weights(num_leaves);
835 std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0/num_leaves);
836 std::discrete_distribution<> leaf_dist(leaf_weights.begin(), leaf_weights.end());
837 int leaf_chosen = leaves[leaf_dist(gen)];
838 int leaf_depth = tree->GetDepth(leaf_chosen);
839
840 // Maximum leaf depth
841 int32_t max_depth = tree_prior.GetMaxDepth();
842
843 // Terminate early if cannot be split
844 bool accept;
845 if ((leaf_depth >= max_depth) && (max_depth != -1)) {
846 accept = false;
847 } else {
848
849 // Select a split variable at random
850 int p = dataset.GetCovariates().cols();
851 CHECK_EQ(variable_weights.size(), p);
852 std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end());
853 int var_chosen = var_dist(gen);
854
855 // Determine the range of possible cutpoints
856 // TODO: specialize this for binary / ordered categorical / unordered categorical variables
857 double var_min, var_max;
858 VarSplitRange(tracker, dataset, tree_num, leaf_chosen, var_chosen, var_min, var_max);
859 if (var_max <= var_min) {
860 return;
861 }
862
863 // Split based on var_min to var_max in a given node
864 std::uniform_real_distribution<double> split_point_dist(var_min, var_max);
865 double split_point_chosen = split_point_dist(gen);
866
867 // Create a split object
868 TreeSplit split = TreeSplit(split_point_chosen);
869
870 // Compute the marginal likelihood of split and no split, given the leaf prior
871 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateProposedSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
872 dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, num_threads, leaf_suff_stat_args...
873 );
874 double split_log_marginal_likelihood = std::get<0>(split_eval);
875 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
876 int32_t left_n = std::get<2>(split_eval);
877 int32_t right_n = std::get<3>(split_eval);
878
879 // Reject the split if either of the left and right nodes are smaller than tree_prior.GetMinSamplesLeaf()
880 bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf();
881 bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf();
882 if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) {
883
884 // Determine probability of growing the split node and its two new left and right nodes
885 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
886 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
887 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
888
889 // Determine whether a "grow" move is possible from the newly formed tree
890 // in order to compute the probability of choosing "prune" from the new tree
891 // (which is always possible by construction)
892 bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
893 bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
894 bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
895 double prob_prune_new;
896 if (non_constant && (min_samples_left_check || min_samples_right_check)) {
897 prob_prune_new = 0.5;
898 } else {
899 prob_prune_new = 1.0;
900 }
901
902 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
903 int num_leaf_parents = tree->NumLeafParents();
904 double p_leaf = 1/static_cast<double>(num_leaves);
905 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);
906
907 // Compute the final MH ratio
908 double log_mh_ratio = (
909 std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
910 std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
911 );
912 // Threshold at 0
913 if (log_mh_ratio > 0) {
914 log_mh_ratio = 0;
915 }
916
917 // Draw a uniform random variable and accept/reject the proposal on this basis
918 std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
919 double log_acceptance_prob = std::log(mh_accept(gen));
920 if (log_acceptance_prob <= log_mh_ratio) {
921 accept = true;
922 AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false, num_threads);
923 } else {
924 accept = false;
925 }
926
927 } else {
928 accept = false;
929 }
930 }
931}
932
933template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
934static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
935 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
936 // Choose a "leaf parent" node at random
937 int num_leaves = tree->NumLeaves();
938 int num_leaf_parents = tree->NumLeafParents();
939 std::vector<int> leaf_parents = tree->GetLeafParents();
940 std::vector<double> leaf_parent_weights(num_leaf_parents);
941 std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0/num_leaf_parents);
942 std::discrete_distribution<> leaf_parent_dist(leaf_parent_weights.begin(), leaf_parent_weights.end());
943 int leaf_parent_chosen = leaf_parents[leaf_parent_dist(gen)];
944 int leaf_parent_depth = tree->GetDepth(leaf_parent_chosen);
945 int left_node = tree->LeftChild(leaf_parent_chosen);
946 int right_node = tree->RightChild(leaf_parent_chosen);
947 int feature_split = tree->SplitIndex(leaf_parent_chosen);
948
949 // Compute the marginal likelihood for the leaf parent and its left and right nodes
950 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateExistingSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
951 dataset, tracker, residual, leaf_model, global_variance, tree_num, leaf_parent_chosen, left_node, right_node, leaf_suff_stat_args...
952 );
953 double split_log_marginal_likelihood = std::get<0>(split_eval);
954 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
955 int32_t left_n = std::get<2>(split_eval);
956 int32_t right_n = std::get<3>(split_eval);
957
958 // Determine probability of growing the split node and its two new left and right nodes
959 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth, -tree_prior.GetBeta());
960 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
961 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
962
963 // Determine whether a "prune" move is possible from the new tree,
964 // in order to compute the probability of choosing "grow" from the new tree
965 // (which is always possible by construction)
966 bool non_root_tree = tree->NumNodes() > 1;
967 double prob_grow_new;
968 if (non_root_tree) {
969 prob_grow_new = 0.5;
970 } else {
971 prob_grow_new = 1.0;
972 }
973
974 // Determine whether a "grow" move was possible from the old tree,
975 // in order to compute the probability of choosing "prune" from the old tree
976 bool non_constant_left = NodeNonConstant(dataset, tracker, tree_num, left_node);
977 bool non_constant_right = NodeNonConstant(dataset, tracker, tree_num, right_node);
978 double prob_prune_old;
979 if (non_constant_left && non_constant_right) {
980 prob_prune_old = 0.5;
981 } else {
982 prob_prune_old = 1.0;
983 }
984
985 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
986 double p_leaf = 1/static_cast<double>(num_leaves-1);
987 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents);
988
989 // Compute the final MH ratio
990 double log_mh_ratio = (
991 std::log(1-pg) - std::log(pg) - std::log(1-pgl) - std::log(1-pgr) + std::log(prob_prune_old) +
992 std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood
993 );
994 // Threshold at 0
995 if (log_mh_ratio > 0) {
996 log_mh_ratio = 0;
997 }
998
999 // Draw a uniform random variable and accept/reject the proposal on this basis
1000 bool accept;
1001 std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
1002 double log_acceptance_prob = std::log(mh_accept(gen));
1003 if (log_acceptance_prob <= log_mh_ratio) {
1004 accept = true;
1005 RemoveSplitFromModel(tracker, dataset, tree_prior, gen, tree, tree_num, leaf_parent_chosen, left_node, right_node, false);
1006 } else {
1007 accept = false;
1008 }
1009}
1010
1011template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
1012static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
1013 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
1014 int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
1015 // Determine whether it is possible to grow any of the leaves
1016 bool grow_possible = false;
1017 std::vector<int> leaves = tree->GetLeaves();
1018 for (auto& leaf: leaves) {
1019 if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) {
1020 grow_possible = true;
1021 break;
1022 }
1023 }
1024
1025 // Determine whether it is possible to prune the tree
1026 bool prune_possible = false;
1027 if (tree->NumValidNodes() > 1) {
1028 prune_possible = true;
1029 }
1030
1031 // Determine the relative probability of grow vs prune (0 = grow, 1 = prune)
1032 double prob_grow;
1033 std::vector<double> step_probs(2);
1034 if (grow_possible && prune_possible) {
1035 step_probs = {0.5, 0.5};
1036 prob_grow = 0.5;
1037 } else if (!grow_possible && prune_possible) {
1038 step_probs = {0.0, 1.0};
1039 prob_grow = 0.0;
1040 } else if (grow_possible && !prune_possible) {
1041 step_probs = {1.0, 0.0};
1042 prob_grow = 1.0;
1043 } else {
1044 Log::Fatal("In this tree, neither grow nor prune is possible");
1045 }
1046 std::discrete_distribution<> step_dist(step_probs.begin(), step_probs.end());
1047
1048 // Draw a split rule at random
1049 data_size_t step_chosen = step_dist(gen);
1050 bool accept;
1051
1052 if (step_chosen == 0) {
1053 MCMCGrowTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1054 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, num_threads, leaf_suff_stat_args...
1055 );
1056 } else {
1057 MCMCPruneTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1058 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, num_threads, leaf_suff_stat_args...
1059 );
1060 }
1061}
1062
1090template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
1091static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
1092 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
1093 std::vector<int>& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads,
1094 LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
1095 // Run the MCMC algorithm for each tree
1096 int num_trees = forests.NumTrees();
1097 for (const int& i : sweep_update_indices) {
1098 // Adjust any model state needed to run a tree sampler
1099 // For models that involve Bayesian backfitting, this amounts to adding tree i's
1100 // predictions back to the residual (thus, training a model on the "partial residual")
1101 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1102 Tree* tree = active_forest.GetTree(i);
1103 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1104
1105 // Sample tree i
1106 tree = active_forest.GetTree(i);
1107 MCMCSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1108 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i,
1109 global_variance, num_threads, leaf_suff_stat_args...
1110 );
1111
1112 // Sample leaf parameters for tree i
1113 tree = active_forest.GetTree(i);
1114 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
1115
1116 // Adjust any model state needed to run a tree sampler
1117 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
1118 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
1119 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1120 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1121 }
1122
1123 if (keep_forest) {
1124 forests.AddSample(active_forest);
1125 }
1126}
1127
// end of sampling_group
1129
1130} // namespace StochTree
1131
1132#endif // STOCHTREE_TREE_SAMPLER_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:193
Container of TreeEnsemble forest objects. This is the primary (in-memory) storage interface for multi...
Definition container.h:23
void AddSample(TreeEnsemble &forest)
Add a new forest to the container by copying forest.
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:271
Eigen::MatrixXd & GetCovariates()
Return a reference to the raw Eigen::MatrixXd storing the covariate data.
Definition data.h:383
double CovariateValue(data_size_t row, int col)
Returns a dataset's covariate value stored at (row, col)
Definition data.h:364
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:47
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Tree * GetTree(int i)
Return a pointer to a tree in the forest.
Definition ensemble.h:134
void ResetInitTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:163
Definition prior.h:43
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:958
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:990
Decision tree data structure.
Definition tree.h:66
static void MCMCSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, std::vector< int > &sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Runs one iteration of the MCMC sampler for a tree ensemble model, which consists of two steps for eve...
Definition tree_sampler.h:1091
static void VarSplitRange(ForestTracker &tracker, ForestDataset &dataset, int tree_num, int leaf_split, int feature_split, double &var_min, double &var_max)
Compute the range of available split values for a continuous variable, given the current structure o...
Definition tree_sampler.h:47
static bool NodesNonConstantAfterSplit(ForestDataset &dataset, ForestTracker &tracker, TreeSplit &split, int tree_num, int leaf_split, int feature_split)
Determines whether a proposed split creates two leaf nodes with constant values for every feature (th...
Definition tree_sampler.h:77
static void GFRSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, std::vector< int > &sweep_update_indices, double global_variance, std::vector< FeatureType > &feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Definition tree_sampler.h:782
Definition category_tracker.h:36