StochTree 0.0.1
Loading...
Searching...
No Matches
tree_sampler.h
1
2#ifndef STOCHTREE_TREE_SAMPLER_H_
3#define STOCHTREE_TREE_SAMPLER_H_
4
5#include <stochtree/container.h>
6#include <stochtree/cutpoint_candidates.h>
7#include <stochtree/data.h>
8#include <stochtree/ensemble.h>
9#include <stochtree/leaf_model.h>
10#include <stochtree/partition_tracker.h>
11#include <stochtree/prior.h>
12
13#include <cmath>
14#include <map>
15#include <memory>
16#include <random>
17#include <set>
18#include <string>
19#include <type_traits>
20#include <variant>
21#include <vector>
22
23namespace StochTree {
24
50static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, int tree_num, int leaf_split, int feature_split, double& var_min, double& var_max) {
51 var_min = std::numeric_limits<double>::max();
52 var_max = std::numeric_limits<double>::min();
53 double feature_value;
54
55 std::vector<data_size_t>::iterator node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
56 std::vector<data_size_t>::iterator node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
57
58 for (auto i = node_begin_iter; i != node_end_iter; i++) {
59 auto idx = *i;
60 feature_value = dataset.CovariateValue(idx, feature_split);
61 if (feature_value < var_min) {
62 var_min = feature_value;
63 } else if (feature_value > var_max) {
64 var_max = feature_value;
65 }
66 }
67}
68
80static inline bool NodesNonConstantAfterSplit(ForestDataset& dataset, ForestTracker& tracker, TreeSplit& split, int tree_num, int leaf_split, int feature_split) {
81 int p = dataset.GetCovariates().cols();
82 data_size_t idx;
83 double feature_value;
84 double split_feature_value;
85 double var_max_left;
86 double var_min_left;
87 double var_max_right;
88 double var_min_right;
89
90 for (int j = 0; j < p; j++) {
91 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
92 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
93 var_max_left = std::numeric_limits<double>::min();
94 var_min_left = std::numeric_limits<double>::max();
95 var_max_right = std::numeric_limits<double>::min();
96 var_min_right = std::numeric_limits<double>::max();
97
98 for (auto i = node_begin_iter; i != node_end_iter; i++) {
99 auto idx = *i;
100 split_feature_value = dataset.CovariateValue(idx, feature_split);
101 feature_value = dataset.CovariateValue(idx, j);
102 if (split.SplitTrue(split_feature_value)) {
103 if (var_max_left < feature_value) {
104 var_max_left = feature_value;
105 } else if (var_min_left > feature_value) {
106 var_min_left = feature_value;
107 }
108 } else {
109 if (var_max_right < feature_value) {
110 var_max_right = feature_value;
111 } else if (var_min_right > feature_value) {
112 var_min_right = feature_value;
113 }
114 }
115 }
116 if ((var_max_left > var_min_left) && (var_max_right > var_min_right)) {
117 return true;
118 }
119 }
120 return false;
121}
122
123static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracker, int tree_num, int node_id) {
124 int p = dataset.GetCovariates().cols();
125 data_size_t idx;
126 double feature_value;
127 double var_max;
128 double var_min;
129
130 for (int j = 0; j < p; j++) {
131 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
132 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
133 var_max = std::numeric_limits<double>::min();
134 var_min = std::numeric_limits<double>::max();
135
136 for (auto i = node_begin_iter; i != node_end_iter; i++) {
137 auto idx = *i;
138 feature_value = dataset.CovariateValue(idx, j);
139 if (var_max < feature_value) {
140 var_max = feature_value;
141 } else if (var_min > feature_value) {
142 var_min = feature_value;
143 }
144 }
145 if (var_max > var_min) {
146 return true;
147 }
148 }
149 return false;
150}
151
152static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree,
153 int tree_num, int leaf_node, int feature_split, bool keep_sorted = false) {
154 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
155 if (tree->OutputDimension() > 1) {
156 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
157 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_values, temp_leaf_values);
158 } else {
159 double temp_leaf_value = 0.;
160 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_value, temp_leaf_value);
161 }
162 int left_node = tree->LeftChild(leaf_node);
163 int right_node = tree->RightChild(leaf_node);
164
165 // Update the ForestTracker
166 tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted);
167}
168
169static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree,
170 int tree_num, int leaf_node, int left_node, int right_node, bool keep_sorted = false) {
171 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
172 if (tree->OutputDimension() > 1) {
173 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
174 tree->CollapseToLeaf(leaf_node, temp_leaf_values);
175 } else {
176 double temp_leaf_value = 0.;
177 tree->CollapseToLeaf(leaf_node, temp_leaf_value);
178 }
179
180 // Update the ForestTracker
181 tracker.RemoveSplit(dataset.GetCovariates(), tree, tree_num, leaf_node, left_node, right_node, keep_sorted);
182}
183
184static inline double ComputeMeanOutcome(ColumnVector& residual) {
185 int n = residual.NumRows();
186 double sum_y = 0.;
187 double y;
188 for (data_size_t i = 0; i < n; i++) {
189 y = residual.GetElement(i);
190 sum_y += y;
191 }
192 return sum_y / static_cast<double>(n);
193}
194
195static inline double ComputeVarianceOutcome(ColumnVector& residual) {
196 int n = residual.NumRows();
197 double sum_y = 0.;
198 double sum_y_sq = 0.;
199 double y;
200 for (data_size_t i = 0; i < n; i++) {
201 y = residual.GetElement(i);
202 sum_y += y;
203 sum_y_sq += y * y;
204 }
205 return sum_y_sq / static_cast<double>(n) - (sum_y * sum_y) / (static_cast<double>(n) * static_cast<double>(n));
206}
207
208static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual,
209 TreeEnsemble* forest, bool requires_basis, std::function<double(double, double)> op) {
210 data_size_t n = dataset.GetCovariates().rows();
211 double tree_pred = 0.;
212 double pred_value = 0.;
213 double new_resid = 0.;
214 int32_t leaf_pred;
215 for (data_size_t i = 0; i < n; i++) {
216 for (int j = 0; j < forest->NumTrees(); j++) {
217 Tree* tree = forest->GetTree(j);
218 leaf_pred = tracker.GetNodeId(i, j);
219 if (requires_basis) {
220 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
221 } else {
222 tree_pred += tree->PredictFromNode(leaf_pred);
223 }
224 tracker.SetTreeSamplePrediction(i, j, tree_pred);
225 pred_value += tree_pred;
226 }
227
228 // Run op (either plus or minus) on the residual and the new prediction
229 new_resid = op(residual.GetElement(i), pred_value);
230 residual.SetElement(i, new_resid);
231 }
232 tracker.SyncPredictions();
233}
234
235static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
236 bool requires_basis, std::function<double(double, double)> op) {
237 data_size_t n = dataset.GetCovariates().rows();
238 double tree_pred = 0.;
239 double pred_value = 0.;
240 double new_resid = 0.;
241 int32_t leaf_pred;
242 for (data_size_t i = 0; i < n; i++) {
243 for (int j = 0; j < forest->NumTrees(); j++) {
244 Tree* tree = forest->GetTree(j);
245 leaf_pred = tracker.GetNodeId(i, j);
246 if (requires_basis) {
247 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
248 } else {
249 tree_pred += tree->PredictFromNode(leaf_pred);
250 }
251 pred_value += tree_pred;
252 }
253
254 // Run op (either plus or minus) on the residual and the new prediction
255 new_resid = op(residual.GetElement(i), pred_value);
256 residual.SetElement(i, new_resid);
257 }
258}
259
260static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
261 bool requires_basis, std::function<double(double, double)> op) {
262 data_size_t n = dataset.GetCovariates().rows();
263 double tree_pred = 0.;
264 double pred_value = 0.;
265 double new_resid = 0.;
266 int32_t leaf_pred;
267 for (data_size_t i = 0; i < n; i++) {
268 for (int j = 0; j < forest->NumTrees(); j++) {
269 Tree* tree = forest->GetTree(j);
270 leaf_pred = tracker.GetNodeId(i, j);
271 if (requires_basis) {
272 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
273 } else {
274 tree_pred += tree->PredictFromNode(leaf_pred);
275 }
276 tracker.SetTreeSamplePrediction(i, j, tree_pred);
277 pred_value += tree_pred;
278 }
279
280 // Run op (either plus or minus) on the residual and the new prediction
281 new_resid = op(residual.GetElement(i), pred_value);
282 residual.SetElement(i, new_resid);
283 }
284 tracker.SyncPredictions();
285}
286
287static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector& residual) {
288 data_size_t n = residual.NumRows();
289 double pred_value;
290 double prev_outcome;
291 double new_resid;
292 for (data_size_t i = 0; i < n; i++) {
293 prev_outcome = residual.GetElement(i);
294 pred_value = tracker.GetSamplePrediction(i);
295 // Run op (either plus or minus) on the residual and the new prediction
296 new_resid = prev_outcome - pred_value;
297 residual.SetElement(i, new_resid);
298 }
299}
300
301static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
302 bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
303 data_size_t n = dataset.GetCovariates().rows();
304 double pred_value;
305 int32_t leaf_pred;
306 double new_resid;
307 double pred_delta;
308 for (data_size_t i = 0; i < n; i++) {
309 if (tree_new) {
310 // If the tree has been newly sampled or adjusted, we must rerun the prediction
311 // method and update the SamplePredMapper stored in tracker
312 leaf_pred = tracker.GetNodeId(i, tree_num);
313 if (requires_basis) {
314 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
315 } else {
316 pred_value = tree->PredictFromNode(leaf_pred);
317 }
318 pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num);
319 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
320 tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta);
321 } else {
322 // If the tree has not yet been modified via a sampling step,
323 // we can query its prediction directly from the SamplePredMapper stored in tracker
324 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
325 }
326 // Run op (either plus or minus) on the residual and the new prediction
327 new_resid = op(residual.GetElement(i), pred_value);
328 residual.SetElement(i, new_resid);
329 }
330}
331
332static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest) {
333 CHECK(dataset.HasBasis());
334 data_size_t n = dataset.GetCovariates().rows();
335 int num_trees = forest->NumTrees();
336 double prev_tree_pred;
337 double new_tree_pred;
338 int32_t leaf_pred;
339 double new_resid;
340 for (int tree_num = 0; tree_num < num_trees; tree_num++) {
341 Tree* tree = forest->GetTree(tree_num);
342 for (data_size_t i = 0; i < n; i++) {
343 // Add back the currently stored tree prediction
344 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
345 new_resid = residual.GetElement(i) + prev_tree_pred;
346
347 // Compute new prediction based on updated basis
348 leaf_pred = tracker.GetNodeId(i, tree_num);
349 new_tree_pred = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
350
351 // Cache the new prediction in the tracker
352 tracker.SetTreeSamplePrediction(i, tree_num, new_tree_pred);
353
354 // Subtract out the updated tree prediction
355 new_resid -= new_tree_pred;
356
357 // Propagate the change back to the residual
358 residual.SetElement(i, new_resid);
359 }
360 }
361 tracker.SyncPredictions();
362}
363
364static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree,
365 int tree_num, bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
366 data_size_t n = dataset.GetCovariates().rows();
367 double pred_value;
368 int32_t leaf_pred;
369 double new_weight;
370 double pred_delta;
371 double prev_tree_pred;
372 double prev_pred;
373 for (data_size_t i = 0; i < n; i++) {
374 if (tree_new) {
375 // If the tree has been newly sampled or adjusted, we must rerun the prediction
376 // method and update the SamplePredMapper stored in tracker
377 leaf_pred = tracker.GetNodeId(i, tree_num);
378 if (requires_basis) {
379 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
380 } else {
381 pred_value = tree->PredictFromNode(leaf_pred);
382 }
383 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
384 prev_pred = tracker.GetSamplePrediction(i);
385 pred_delta = pred_value - prev_tree_pred;
386 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
387 tracker.SetSamplePrediction(i, prev_pred + pred_delta);
388 new_weight = std::log(dataset.VarWeightValue(i)) + pred_value;
389 dataset.SetVarWeightValue(i, new_weight, true);
390 } else {
391 // If the tree has not yet been modified via a sampling step,
392 // we can query its prediction directly from the SamplePredMapper stored in tracker
393 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
394 new_weight = std::log(dataset.VarWeightValue(i)) - pred_value;
395 dataset.SetVarWeightValue(i, new_weight, true);
396 }
397 }
398}
399
400template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
401static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(
402 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
403 TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance,
404 LeafSuffStatConstructorArgs&... leaf_suff_stat_args
405) {
406 // Initialize sufficient statistics
407 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
408 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
409 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
410
411 // Accumulate sufficient statistics
412 AccumulateSuffStatProposed<LeafSuffStat>(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
413 residual, global_variance, split, tree_num, leaf_num, split_feature);
414 data_size_t left_n = left_suff_stat.n;
415 data_size_t right_n = right_suff_stat.n;
416
417 // Evaluate split
418 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
419 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
420
421 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
422}
423
424template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
425static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(
426 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
427 double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id,
428 LeafSuffStatConstructorArgs&... leaf_suff_stat_args
429) {
430 // Initialize sufficient statistics
431 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
432 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
433 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
434
435 // Accumulate sufficient statistics
436 AccumulateSuffStatExisting<LeafSuffStat>(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
437 residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id);
438 data_size_t left_n = left_suff_stat.n;
439 data_size_t right_n = right_suff_stat.n;
440
441 // Evaluate split
442 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
443 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
444
445 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
446}
447
448template <typename LeafModel>
449static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
450 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
451 if (backfitting) {
452 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), false);
453 } else {
454 // TODO: think about a generic way to store "state" corresponding to the other models?
455 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), false);
456 }
457}
458
459template <typename LeafModel>
460static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
461 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
462 if (backfitting) {
463 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), true);
464 } else {
465 // TODO: think about a generic way to store "state" corresponding to the other models?
466 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), true);
467 }
468}
469
470template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
471static inline void EvaluateAllPossibleSplits(
472 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id,
473 std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
474 data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
475 std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
476) {
477 // Initialize sufficient statistics
478 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
479 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
480 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
481
482 // Accumulate aggregate sufficient statistic for the node to be split
483 AccumulateSingleNodeSuffStat<LeafSuffStat, false>(node_suff_stat, dataset, tracker, residual, tree_num, split_node_id);
484
485 // Compute the "no split" log marginal likelihood
486 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
487
488 // Unpack data
489 Eigen::MatrixXd covariates = dataset.GetCovariates();
490 Eigen::VectorXd outcome = residual.GetData();
491 Eigen::VectorXd var_weights;
492 bool has_weights = dataset.HasVarWeights();
493 if (has_weights) var_weights = dataset.GetVarWeights();
494
495 // Minimum size of newly created leaf nodes (used to rule out invalid splits)
496 int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();
497
498 // Compute sufficient statistics for each possible split
499 data_size_t num_cutpoints = 0;
500 bool valid_split = false;
501 data_size_t node_row_iter;
502 data_size_t current_bin_begin, current_bin_size, next_bin_begin;
503 data_size_t feature_sort_idx;
504 data_size_t row_iter_idx;
505 double outcome_val, outcome_val_sq;
506 FeatureType feature_type;
507 double feature_value = 0.0;
508 double cutoff_value = 0.0;
509 double log_split_eval = 0.0;
510 double split_log_ml;
511 for (int j = 0; j < covariates.cols(); j++) {
512
513 if (std::abs(variable_weights.at(j)) > kEpsilon) {
514 // Enumerate cutpoint strides
515 cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), split_node_id, node_begin, node_end, j, feature_types);
516
517 // Reset sufficient statistics
518 left_suff_stat.ResetSuffStat();
519 right_suff_stat.ResetSuffStat();
520
521 // Iterate through possible cutpoints
522 int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j);
523 feature_type = feature_types[j];
524 // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins
525 for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) {
526 current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j);
527 current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j);
528 next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j);
529
530 // Accumulate sufficient statistics for the left node
531 AccumulateCutpointBinSuffStat<LeafSuffStat>(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual,
532 global_variance, tree_num, split_node_id, j, cutpoint_idx);
533
534 // Compute the corresponding right node sufficient statistics
535 right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat);
536
537 // Store the bin index as the "cutpoint value" - we can use this to query the actual split
538 // value or the set of split categories later on once a split is chose
539 cutoff_value = cutpoint_idx;
540
541 // Only include cutpoint for consideration if it defines a valid split in the training data
542 valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) &&
543 right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
544 if (valid_split) {
545 num_cutpoints++;
546 // Add to split rule vector
547 cutpoint_feature_types.push_back(feature_type);
548 cutpoint_features.push_back(j);
549 cutpoint_values.push_back(cutoff_value);
550 // Add the log marginal likelihood of the split to the split eval vector
551 split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
552 log_cutpoint_evaluations.push_back(split_log_ml);
553 }
554 }
555 }
556
557 }
558
559 // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
560 cutpoint_features.push_back(-1);
561 cutpoint_values.push_back(std::numeric_limits<double>::max());
562 cutpoint_feature_types.push_back(FeatureType::kNumeric);
563 log_cutpoint_evaluations.push_back(no_split_log_ml);
564
565 // Update valid cutpoint count
566 valid_cutpoint_count = num_cutpoints;
567}
568
569template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
570static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
571 std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end,
572 std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values,
573 std::vector<FeatureType>& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector<double>& variable_weights,
574 std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
575 // Evaluate all possible cutpoints according to the leaf node model,
576 // recording their log-likelihood and other split information in a series of vectors.
577 // The last element of these vectors concerns the "no-split" option.
578 EvaluateAllPossibleSplits<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
579 dataset, tracker, residual, tree_prior, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations,
580 cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container,
581 node_begin, node_end, variable_weights, feature_types, leaf_suff_stat_args...
582 );
583
584 // Compute an adjustment to reflect the no split prior probability and the number of cutpoints
585 double bart_prior_no_split_adj;
586 double alpha = tree_prior.GetAlpha();
587 double beta = tree_prior.GetBeta();
588 int node_depth = tree->GetDepth(node_id);
589 if (valid_cutpoint_count == 0) {
590 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0);
591 } else {
592 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count);
593 }
594 log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj;
595}
596
597template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
598static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
599 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size,
600 std::unordered_map<int, std::pair<data_size_t, data_size_t>>& node_index_map, std::deque<node_t>& split_queue,
601 int node_id, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
602 std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
603 // Leaf depth
604 int leaf_depth = tree->GetDepth(node_id);
605
606 // Maximum leaf depth
607 int32_t max_depth = tree_prior.GetMaxDepth();
608
609 if ((max_depth == -1) || (leaf_depth < max_depth)) {
610
611 // Cutpoint enumeration
612 std::vector<double> log_cutpoint_evaluations;
613 std::vector<int> cutpoint_features;
614 std::vector<double> cutpoint_values;
615 std::vector<FeatureType> cutpoint_feature_types;
616 StochTree::data_size_t valid_cutpoint_count;
617 CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
618 EvaluateCutpoints<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
619 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance,
620 cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features,
621 cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types,
622 cutpoint_grid_container, leaf_suff_stat_args...
623 );
624 // TODO: maybe add some checks here?
625
626 // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood
627 double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end());
628 std::vector<double> cutpoint_evaluations(log_cutpoint_evaluations.size());
629 for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){
630 cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll);
631 }
632
633 // Sample the split (including a "no split" option)
634 std::discrete_distribution<data_size_t> split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end());
635 data_size_t split_chosen = split_dist(gen);
636
637 if (split_chosen == valid_cutpoint_count){
638 // "No split" sampled, don't split or add any nodes to split queue
639 return;
640 } else {
641 // Split sampled
642 int feature_split = cutpoint_features[split_chosen];
643 FeatureType feature_type = cutpoint_feature_types[split_chosen];
644 double split_value = cutpoint_values[split_chosen];
645 // Perform all of the relevant "split" operations in the model, tree and training dataset
646
647 // Compute node sample size
648 data_size_t node_n = node_end - node_begin;
649
650 // Actual numeric cutpoint used for ordered categorical and numeric features
651 double split_value_numeric;
652 TreeSplit tree_split;
653
654 // We will use these later in the model expansion
655 data_size_t left_n = 0;
656 data_size_t right_n = 0;
657 data_size_t sort_idx;
658 double feature_value;
659 bool split_true;
660
661 if (feature_type == FeatureType::kUnorderedCategorical) {
662 // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split
663 int num_categories;
664 std::vector<std::uint32_t> categories = cutpoint_grid_container.CutpointVector(static_cast<std::uint32_t>(split_value), feature_split);
665 tree_split = TreeSplit(categories);
666 } else if (feature_type == FeatureType::kOrderedCategorical) {
667 // Convert the bin split to an actual split value
668 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
669 tree_split = TreeSplit(split_value_numeric);
670 } else if (feature_type == FeatureType::kNumeric) {
671 // Convert the bin split to an actual split value
672 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
673 tree_split = TreeSplit(split_value_numeric);
674 } else {
675 Log::Fatal("Invalid split type");
676 }
677
678 // Add split to tree and trackers
679 AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true);
680
681 // Determine the number of observation in the newly created left node
682 int left_node = tree->LeftChild(node_id);
683 int right_node = tree->RightChild(node_id);
684 auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split);
685 auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split);
686 for (auto i = left_begin_iter; i < left_end_iter; i++) {
687 left_n += 1;
688 }
689
690 // Add the begin and end indices for the new left and right nodes to node_index_map
691 node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)});
692 node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)});
693
694 // Add the left and right nodes to the split tracker
695 split_queue.push_front(right_node);
696 split_queue.push_front(left_node);
697 }
698 }
699}
700
701template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
702static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
703 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
704 int tree_num, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
705 LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
706 int root_id = Tree::kRoot;
707 int curr_node_id;
708 data_size_t curr_node_begin;
709 data_size_t curr_node_end;
710 data_size_t n = dataset.GetCovariates().rows();
711 // Mapping from node id to start and end points of sorted indices
712 std::unordered_map<int, std::pair<data_size_t, data_size_t>> node_index_map;
713 node_index_map.insert({root_id, std::make_pair(0, n)});
714 std::pair<data_size_t, data_size_t> begin_end;
715 // Add root node to the split queue
716 std::deque<node_t> split_queue;
717 split_queue.push_back(Tree::kRoot);
718 // Run the "GrowFromRoot" procedure using a stack in place of recursion
719 while (!split_queue.empty()) {
720 // Remove the next node from the queue
721 curr_node_id = split_queue.front();
722 split_queue.pop_front();
723 // Determine the beginning and ending indices of the left and right nodes
724 begin_end = node_index_map[curr_node_id];
725 curr_node_begin = begin_end.first;
726 curr_node_end = begin_end.second;
727 // Draw a split rule at random
728 SampleSplitRule<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
729 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size,
730 node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types,
731 leaf_suff_stat_args...);
732 }
733}
734
763template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
764static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
765 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
766 double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
767 bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
768
769 // Run the GFR algorithm for each tree
770 int num_trees = forests.NumTrees();
771 for (int i = 0; i < num_trees; i++) {
772 // Adjust any model state needed to run a tree sampler
773 // For models that involve Bayesian backfitting, this amounts to adding tree i's
774 // predictions back to the residual (thus, training a model on the "partial residual")
775 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
776 Tree* tree = active_forest.GetTree(i);
777 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
778
779 // Reset the tree and sample trackers
780 active_forest.ResetInitTree(i);
781 tracker.ResetRoot(dataset.GetCovariates(), feature_types, i);
782 tree = active_forest.GetTree(i);
783
784 // Sample tree i
785 GFRSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
786 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
787 variable_weights, i, global_variance, feature_types, cutpoint_grid_size,
788 leaf_suff_stat_args...
789 );
790
791 // Sample leaf parameters for tree i
792 tree = active_forest.GetTree(i);
793 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
794
795 // Adjust any model state needed to run a tree sampler
796 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
797 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
798 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
799 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
800 }
801
802 if (keep_forest) {
803 forests.AddSample(active_forest);
804 }
805}
806
807template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
808static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
809 TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector<double>& variable_weights,
810 double global_variance, double prob_grow_old, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
811 // Extract dataset information
812 data_size_t n = dataset.GetCovariates().rows();
813
814 // Choose a leaf node at random
815 int num_leaves = tree->NumLeaves();
816 std::vector<int> leaves = tree->GetLeaves();
817 std::vector<double> leaf_weights(num_leaves);
818 std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0/num_leaves);
819 std::discrete_distribution<> leaf_dist(leaf_weights.begin(), leaf_weights.end());
820 int leaf_chosen = leaves[leaf_dist(gen)];
821 int leaf_depth = tree->GetDepth(leaf_chosen);
822
823 // Maximum leaf depth
824 int32_t max_depth = tree_prior.GetMaxDepth();
825
826 // Terminate early if cannot be split
827 bool accept;
828 if ((leaf_depth >= max_depth) && (max_depth != -1)) {
829 accept = false;
830 } else {
831
832 // Select a split variable at random
833 int p = dataset.GetCovariates().cols();
834 CHECK_EQ(variable_weights.size(), p);
835 std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end());
836 int var_chosen = var_dist(gen);
837
838 // Determine the range of possible cutpoints
839 // TODO: specialize this for binary / ordered categorical / unordered categorical variables
840 double var_min, var_max;
841 VarSplitRange(tracker, dataset, tree_num, leaf_chosen, var_chosen, var_min, var_max);
842 if (var_max <= var_min) {
843 return;
844 }
845
846 // Split based on var_min to var_max in a given node
847 std::uniform_real_distribution<double> split_point_dist(var_min, var_max);
848 double split_point_chosen = split_point_dist(gen);
849
850 // Create a split object
851 TreeSplit split = TreeSplit(split_point_chosen);
852
853 // Compute the marginal likelihood of split and no split, given the leaf prior
854 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateProposedSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
855 dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, leaf_suff_stat_args...
856 );
857 double split_log_marginal_likelihood = std::get<0>(split_eval);
858 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
859 int32_t left_n = std::get<2>(split_eval);
860 int32_t right_n = std::get<3>(split_eval);
861
862 // Reject the split if either of the left and right nodes are smaller than tree_prior.GetMinSamplesLeaf()
863 bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf();
864 bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf();
865 if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) {
866
867 // Determine probability of growing the split node and its two new left and right nodes
868 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
869 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
870 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
871
872 // Determine whether a "grow" move is possible from the newly formed tree
873 // in order to compute the probability of choosing "prune" from the new tree
874 // (which is always possible by construction)
875 bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
876 bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
877 bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
878 double prob_prune_new;
879 if (non_constant && (min_samples_left_check || min_samples_right_check)) {
880 prob_prune_new = 0.5;
881 } else {
882 prob_prune_new = 1.0;
883 }
884
885 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
886 int num_leaf_parents = tree->NumLeafParents();
887 double p_leaf = 1/static_cast<double>(num_leaves);
888 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);
889
890 // Compute the final MH ratio
891 double log_mh_ratio = (
892 std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
893 std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
894 );
895 // Threshold at 0
896 if (log_mh_ratio > 0) {
897 log_mh_ratio = 0;
898 }
899
900 // Draw a uniform random variable and accept/reject the proposal on this basis
901 std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
902 double log_acceptance_prob = std::log(mh_accept(gen));
903 if (log_acceptance_prob <= log_mh_ratio) {
904 accept = true;
905 AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false);
906 } else {
907 accept = false;
908 }
909
910 } else {
911 accept = false;
912 }
913 }
914}
915
916template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
917static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
918 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
919 // Choose a "leaf parent" node at random
920 int num_leaves = tree->NumLeaves();
921 int num_leaf_parents = tree->NumLeafParents();
922 std::vector<int> leaf_parents = tree->GetLeafParents();
923 std::vector<double> leaf_parent_weights(num_leaf_parents);
924 std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0/num_leaf_parents);
925 std::discrete_distribution<> leaf_parent_dist(leaf_parent_weights.begin(), leaf_parent_weights.end());
926 int leaf_parent_chosen = leaf_parents[leaf_parent_dist(gen)];
927 int leaf_parent_depth = tree->GetDepth(leaf_parent_chosen);
928 int left_node = tree->LeftChild(leaf_parent_chosen);
929 int right_node = tree->RightChild(leaf_parent_chosen);
930 int feature_split = tree->SplitIndex(leaf_parent_chosen);
931
932 // Compute the marginal likelihood for the leaf parent and its left and right nodes
933 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateExistingSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
934 dataset, tracker, residual, leaf_model, global_variance, tree_num, leaf_parent_chosen, left_node, right_node, leaf_suff_stat_args...
935 );
936 double split_log_marginal_likelihood = std::get<0>(split_eval);
937 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
938 int32_t left_n = std::get<2>(split_eval);
939 int32_t right_n = std::get<3>(split_eval);
940
941 // Determine probability of growing the split node and its two new left and right nodes
942 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth, -tree_prior.GetBeta());
943 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
944 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
945
946 // Determine whether a "prune" move is possible from the new tree,
947 // in order to compute the probability of choosing "grow" from the new tree
948 // (which is always possible by construction)
949 bool non_root_tree = tree->NumNodes() > 1;
950 double prob_grow_new;
951 if (non_root_tree) {
952 prob_grow_new = 0.5;
953 } else {
954 prob_grow_new = 1.0;
955 }
956
957 // Determine whether a "grow" move was possible from the old tree,
958 // in order to compute the probability of choosing "prune" from the old tree
959 bool non_constant_left = NodeNonConstant(dataset, tracker, tree_num, left_node);
960 bool non_constant_right = NodeNonConstant(dataset, tracker, tree_num, right_node);
961 double prob_prune_old;
962 if (non_constant_left && non_constant_right) {
963 prob_prune_old = 0.5;
964 } else {
965 prob_prune_old = 1.0;
966 }
967
968 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
969 double p_leaf = 1/static_cast<double>(num_leaves-1);
970 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents);
971
972 // Compute the final MH ratio
973 double log_mh_ratio = (
974 std::log(1-pg) - std::log(pg) - std::log(1-pgl) - std::log(1-pgr) + std::log(prob_prune_old) +
975 std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood
976 );
977 // Threshold at 0
978 if (log_mh_ratio > 0) {
979 log_mh_ratio = 0;
980 }
981
982 // Draw a uniform random variable and accept/reject the proposal on this basis
983 bool accept;
984 std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
985 double log_acceptance_prob = std::log(mh_accept(gen));
986 if (log_acceptance_prob <= log_mh_ratio) {
987 accept = true;
988 RemoveSplitFromModel(tracker, dataset, tree_prior, gen, tree, tree_num, leaf_parent_chosen, left_node, right_node, false);
989 } else {
990 accept = false;
991 }
992}
993
994template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
995static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
996 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
997 int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
998 // Determine whether it is possible to grow any of the leaves
999 bool grow_possible = false;
1000 std::vector<int> leaves = tree->GetLeaves();
1001 for (auto& leaf: leaves) {
1002 if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) {
1003 grow_possible = true;
1004 break;
1005 }
1006 }
1007
1008 // Determine whether it is possible to prune the tree
1009 bool prune_possible = false;
1010 if (tree->NumValidNodes() > 1) {
1011 prune_possible = true;
1012 }
1013
1014 // Determine the relative probability of grow vs prune (0 = grow, 1 = prune)
1015 double prob_grow;
1016 std::vector<double> step_probs(2);
1017 if (grow_possible && prune_possible) {
1018 step_probs = {0.5, 0.5};
1019 prob_grow = 0.5;
1020 } else if (!grow_possible && prune_possible) {
1021 step_probs = {0.0, 1.0};
1022 prob_grow = 0.0;
1023 } else if (grow_possible && !prune_possible) {
1024 step_probs = {1.0, 0.0};
1025 prob_grow = 1.0;
1026 } else {
1027 Log::Fatal("In this tree, neither grow nor prune is possible");
1028 }
1029 std::discrete_distribution<> step_dist(step_probs.begin(), step_probs.end());
1030
1031 // Draw a split rule at random
1032 data_size_t step_chosen = step_dist(gen);
1033 bool accept;
1034
1035 if (step_chosen == 0) {
1036 MCMCGrowTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1037 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, leaf_suff_stat_args...
1038 );
1039 } else {
1040 MCMCPruneTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1041 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, leaf_suff_stat_args...
1042 );
1043 }
1044}
1045
1072template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
1073static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
1074 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
1075 double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
1076 // Run the MCMC algorithm for each tree
1077 int num_trees = forests.NumTrees();
1078 for (int i = 0; i < num_trees; i++) {
1079 // Adjust any model state needed to run a tree sampler
1080 // For models that involve Bayesian backfitting, this amounts to adding tree i's
1081 // predictions back to the residual (thus, training a model on the "partial residual")
1082 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1083 Tree* tree = active_forest.GetTree(i);
1084 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1085
1086 // Sample tree i
1087 tree = active_forest.GetTree(i);
1088 MCMCSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1089 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i,
1090 global_variance, leaf_suff_stat_args...
1091 );
1092
1093 // Sample leaf parameters for tree i
1094 tree = active_forest.GetTree(i);
1095 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
1096
1097 // Adjust any model state needed to run a tree sampler
1098 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
1099 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
1100 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1101 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1102 }
1103
1104 if (keep_forest) {
1105 forests.AddSample(active_forest);
1106 }
1107}
1108
// end of sampling_group
1110
1111} // namespace StochTree
1112
1113#endif // STOCHTREE_TREE_SAMPLER_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:194
Container of TreeEnsemble forest objects. This is the primary (in-memory) storage interface for multi...
Definition container.h:28
void AddSample(TreeEnsemble &forest)
Add a new forest to the container by copying forest.
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:272
Eigen::MatrixXd & GetCovariates()
Return a reference to the raw Eigen::MatrixXd storing the covariate data.
Definition data.h:384
double CovariateValue(data_size_t row, int col)
Returns a dataset's covariate value stored at (row, col)
Definition data.h:365
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:50
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Tree * GetTree(int i)
Return a pointer to a tree in the forest.
Definition ensemble.h:92
void ResetInitTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:121
Definition prior.h:43
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:927
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:959
Decision tree data structure.
Definition tree.h:69
static void MCMCSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Runs one iteration of the MCMC sampler for a tree ensemble model, which consists of two steps for eve...
Definition tree_sampler.h:1073
static void VarSplitRange(ForestTracker &tracker, ForestDataset &dataset, int tree_num, int leaf_split, int feature_split, double &var_min, double &var_max)
Compute the range of available split values for a continuous variable, given the current structure o...
Definition tree_sampler.h:50
static bool NodesNonConstantAfterSplit(ForestDataset &dataset, ForestTracker &tracker, TreeSplit &split, int tree_num, int leaf_split, int feature_split)
Determines whether a proposed split creates two leaf nodes with constant values for every feature (th...
Definition tree_sampler.h:80
static void GFRSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, double global_variance, std::vector< FeatureType > &feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Definition tree_sampler.h:764
Definition category_tracker.h:40