StochTree 0.4.1.9000
Loading...
Searching...
No Matches
tree_sampler.h
1
2#ifndef STOCHTREE_TREE_SAMPLER_H_
3#define STOCHTREE_TREE_SAMPLER_H_
4
5#include <stochtree/container.h>
6#include <stochtree/cutpoint_candidates.h>
7#include <stochtree/data.h>
8#include <stochtree/discrete_sampler.h>
9#include <stochtree/distributions.h>
10#include <stochtree/ensemble.h>
11#include <stochtree/leaf_model.h>
12#include <stochtree/openmp_utils.h>
13#include <stochtree/partition_tracker.h>
14#include <stochtree/prior.h>
15
16#include <cmath>
17#include <numeric>
18#include <random>
19#include <vector>
20
21namespace StochTree {
22
48static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, int tree_num, int leaf_split, int feature_split, double& var_min, double& var_max) {
49 var_min = std::numeric_limits<double>::max();
50 var_max = std::numeric_limits<double>::min();
51 double feature_value;
52
53 std::vector<data_size_t>::iterator node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
54 std::vector<data_size_t>::iterator node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
55
56 for (auto i = node_begin_iter; i != node_end_iter; i++) {
57 auto idx = *i;
58 feature_value = dataset.CovariateValue(idx, feature_split);
59 if (feature_value < var_min) {
60 var_min = feature_value;
61 } else if (feature_value > var_max) {
62 var_max = feature_value;
63 }
64 }
65}
66
78static inline bool NodesNonConstantAfterSplit(ForestDataset& dataset, ForestTracker& tracker, TreeSplit& split, int tree_num, int leaf_split, int feature_split) {
79 int p = dataset.GetCovariates().cols();
80 data_size_t idx;
81 double feature_value;
82 double split_feature_value;
83 double var_max_left;
84 double var_min_left;
85 double var_max_right;
86 double var_min_right;
87
88 for (int j = 0; j < p; j++) {
89 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
90 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
91 var_max_left = std::numeric_limits<double>::min();
92 var_min_left = std::numeric_limits<double>::max();
93 var_max_right = std::numeric_limits<double>::min();
94 var_min_right = std::numeric_limits<double>::max();
95
96 for (auto i = node_begin_iter; i != node_end_iter; i++) {
97 auto idx = *i;
98 split_feature_value = dataset.CovariateValue(idx, feature_split);
99 feature_value = dataset.CovariateValue(idx, j);
100 if (split.SplitTrue(split_feature_value)) {
101 if (var_max_left < feature_value) {
102 var_max_left = feature_value;
103 } else if (var_min_left > feature_value) {
104 var_min_left = feature_value;
105 }
106 } else {
107 if (var_max_right < feature_value) {
108 var_max_right = feature_value;
109 } else if (var_min_right > feature_value) {
110 var_min_right = feature_value;
111 }
112 }
113 }
114 if ((var_max_left > var_min_left) && (var_max_right > var_min_right)) {
115 return true;
116 }
117 }
118 return false;
119}
120
121static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracker, int tree_num, int node_id) {
122 int p = dataset.GetCovariates().cols();
123 data_size_t idx;
124 double feature_value;
125 double var_max;
126 double var_min;
127
128 for (int j = 0; j < p; j++) {
129 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
130 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
131 var_max = std::numeric_limits<double>::min();
132 var_min = std::numeric_limits<double>::max();
133
134 for (auto i = node_begin_iter; i != node_end_iter; i++) {
135 auto idx = *i;
136 feature_value = dataset.CovariateValue(idx, j);
137 if (var_max < feature_value) {
138 var_max = feature_value;
139 } else if (var_min > feature_value) {
140 var_min = feature_value;
141 }
142 }
143 if (var_max > var_min) {
144 return true;
145 }
146 }
147 return false;
148}
149
150static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree,
151 int tree_num, int leaf_node, int feature_split, bool keep_sorted = false, int num_threads = -1) {
152 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
153 if (tree->OutputDimension() > 1) {
154 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
155 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_values, temp_leaf_values);
156 } else {
157 double temp_leaf_value = 0.;
158 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_value, temp_leaf_value);
159 }
160 int left_node = tree->LeftChild(leaf_node);
161 int right_node = tree->RightChild(leaf_node);
162
163 // Update the ForestTracker
164 tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted, num_threads);
165}
166
167static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree,
168 int tree_num, int leaf_node, int left_node, int right_node, bool keep_sorted = false) {
169 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
170 if (tree->OutputDimension() > 1) {
171 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
172 tree->CollapseToLeaf(leaf_node, temp_leaf_values);
173 } else {
174 double temp_leaf_value = 0.;
175 tree->CollapseToLeaf(leaf_node, temp_leaf_value);
176 }
177
178 // Update the ForestTracker
179 tracker.RemoveSplit(dataset.GetCovariates(), tree, tree_num, leaf_node, left_node, right_node, keep_sorted);
180}
181
182static inline double ComputeMeanOutcome(ColumnVector& residual) {
183 int n = residual.NumRows();
184 double sum_y = 0.;
185 double y;
186 for (data_size_t i = 0; i < n; i++) {
187 y = residual.GetElement(i);
188 sum_y += y;
189 }
190 return sum_y / static_cast<double>(n);
191}
192
193static inline double ComputeVarianceOutcome(ColumnVector& residual) {
194 int n = residual.NumRows();
195 double sum_y = 0.;
196 double sum_y_sq = 0.;
197 double y;
198 for (data_size_t i = 0; i < n; i++) {
199 y = residual.GetElement(i);
200 sum_y += y;
201 sum_y_sq += y * y;
202 }
203 return sum_y_sq / static_cast<double>(n) - (sum_y * sum_y) / (static_cast<double>(n) * static_cast<double>(n));
204}
205
206static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual,
207 TreeEnsemble* forest, bool requires_basis, std::function<double(double, double)> op) {
208 data_size_t n = dataset.GetCovariates().rows();
209 double tree_pred = 0.;
210 double pred_value = 0.;
211 double new_resid = 0.;
212 int32_t leaf_pred;
213 for (data_size_t i = 0; i < n; i++) {
214 for (int j = 0; j < forest->NumTrees(); j++) {
215 Tree* tree = forest->GetTree(j);
216 leaf_pred = tracker.GetNodeId(i, j);
217 if (requires_basis) {
218 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
219 } else {
220 tree_pred += tree->PredictFromNode(leaf_pred);
221 }
222 tracker.SetTreeSamplePrediction(i, j, tree_pred);
223 pred_value += tree_pred;
224 }
225
226 // Run op (either plus or minus) on the residual and the new prediction
227 new_resid = op(residual.GetElement(i), pred_value);
228 residual.SetElement(i, new_resid);
229 }
230 tracker.SyncPredictions();
231}
232
233static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
234 bool requires_basis, std::function<double(double, double)> op) {
235 data_size_t n = dataset.GetCovariates().rows();
236 double tree_pred = 0.;
237 double pred_value = 0.;
238 double new_resid = 0.;
239 int32_t leaf_pred;
240 for (data_size_t i = 0; i < n; i++) {
241 for (int j = 0; j < forest->NumTrees(); j++) {
242 Tree* tree = forest->GetTree(j);
243 leaf_pred = tracker.GetNodeId(i, j);
244 if (requires_basis) {
245 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
246 } else {
247 tree_pred += tree->PredictFromNode(leaf_pred);
248 }
249 pred_value += tree_pred;
250 }
251
252 // Run op (either plus or minus) on the residual and the new prediction
253 new_resid = op(residual.GetElement(i), pred_value);
254 residual.SetElement(i, new_resid);
255 }
256}
257
258static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
259 bool requires_basis, std::function<double(double, double)> op) {
260 data_size_t n = dataset.GetCovariates().rows();
261 double tree_pred = 0.;
262 double pred_value = 0.;
263 double new_resid = 0.;
264 int32_t leaf_pred;
265 for (data_size_t i = 0; i < n; i++) {
266 for (int j = 0; j < forest->NumTrees(); j++) {
267 Tree* tree = forest->GetTree(j);
268 leaf_pred = tracker.GetNodeId(i, j);
269 if (requires_basis) {
270 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
271 } else {
272 tree_pred += tree->PredictFromNode(leaf_pred);
273 }
274 tracker.SetTreeSamplePrediction(i, j, tree_pred);
275 pred_value += tree_pred;
276 }
277
278 // Run op (either plus or minus) on the residual and the new prediction
279 new_resid = op(residual.GetElement(i), pred_value);
280 residual.SetElement(i, new_resid);
281 }
282 tracker.SyncPredictions();
283}
284
285static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector& residual) {
286 data_size_t n = residual.NumRows();
287 double pred_value;
288 double prev_outcome;
289 double new_resid;
290 for (data_size_t i = 0; i < n; i++) {
291 prev_outcome = residual.GetElement(i);
292 pred_value = tracker.GetSamplePrediction(i);
293 // Run op (either plus or minus) on the residual and the new prediction
294 new_resid = prev_outcome - pred_value;
295 residual.SetElement(i, new_resid);
296 }
297}
298
299static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
300 bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
301 data_size_t n = dataset.GetCovariates().rows();
302 double pred_value;
303 int32_t leaf_pred;
304 double new_resid;
305 double pred_delta;
306 for (data_size_t i = 0; i < n; i++) {
307 if (tree_new) {
308 // If the tree has been newly sampled or adjusted, we must rerun the prediction
309 // method and update the SamplePredMapper stored in tracker
310 leaf_pred = tracker.GetNodeId(i, tree_num);
311 if (requires_basis) {
312 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
313 } else {
314 pred_value = tree->PredictFromNode(leaf_pred);
315 }
316 pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num);
317 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
318 tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta);
319 } else {
320 // If the tree has not yet been modified via a sampling step,
321 // we can query its prediction directly from the SamplePredMapper stored in tracker
322 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
323 }
324 // Run op (either plus or minus) on the residual and the new prediction
325 new_resid = op(residual.GetElement(i), pred_value);
326 residual.SetElement(i, new_resid);
327 }
328}
329
330static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest) {
331 CHECK(dataset.HasBasis());
332 data_size_t n = dataset.GetCovariates().rows();
333 int num_trees = forest->NumTrees();
334 double prev_tree_pred;
335 double new_tree_pred;
336 int32_t leaf_pred;
337 double new_resid;
338 for (int tree_num = 0; tree_num < num_trees; tree_num++) {
339 Tree* tree = forest->GetTree(tree_num);
340 for (data_size_t i = 0; i < n; i++) {
341 // Add back the currently stored tree prediction
342 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
343 new_resid = residual.GetElement(i) + prev_tree_pred;
344
345 // Compute new prediction based on updated basis
346 leaf_pred = tracker.GetNodeId(i, tree_num);
347 new_tree_pred = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
348
349 // Cache the new prediction in the tracker
350 tracker.SetTreeSamplePrediction(i, tree_num, new_tree_pred);
351
352 // Subtract out the updated tree prediction
353 new_resid -= new_tree_pred;
354
355 // Propagate the change back to the residual
356 residual.SetElement(i, new_resid);
357 }
358 }
359 tracker.SyncPredictions();
360}
361
362static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree,
363 int tree_num, bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
364 data_size_t n = dataset.GetCovariates().rows();
365 double pred_value;
366 int32_t leaf_pred;
367 double new_weight;
368 double pred_delta;
369 double prev_tree_pred;
370 double prev_pred;
371 for (data_size_t i = 0; i < n; i++) {
372 if (tree_new) {
373 // If the tree has been newly sampled or adjusted, we must rerun the prediction
374 // method and update the SamplePredMapper stored in tracker
375 leaf_pred = tracker.GetNodeId(i, tree_num);
376 if (requires_basis) {
377 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
378 } else {
379 pred_value = tree->PredictFromNode(leaf_pred);
380 }
381 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
382 prev_pred = tracker.GetSamplePrediction(i);
383 pred_delta = pred_value - prev_tree_pred;
384 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
385 tracker.SetSamplePrediction(i, prev_pred + pred_delta);
386 new_weight = std::log(dataset.VarWeightValue(i)) + pred_value;
387 dataset.SetVarWeightValue(i, new_weight, true);
388 } else {
389 // If the tree has not yet been modified via a sampling step,
390 // we can query its prediction directly from the SamplePredMapper stored in tracker
391 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
392 new_weight = std::log(dataset.VarWeightValue(i)) - pred_value;
393 dataset.SetVarWeightValue(i, new_weight, true);
394 }
395 }
396}
397
398static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
399 bool requires_basis, bool tree_new) {
400 data_size_t n = dataset.GetCovariates().rows();
401
402 double pred_value;
403 int32_t leaf_pred;
404 double pred_delta;
405 for (data_size_t i = 0; i < n; i++) {
406 if (tree_new) {
407 // If the tree has been newly sampled or adjusted, we must rerun the prediction
408 // method and update the SamplePredMapper stored in tracker
409 leaf_pred = tracker.GetNodeId(i, tree_num);
410 if (requires_basis) {
411 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
412 } else {
413 pred_value = tree->PredictFromNode(leaf_pred);
414 }
415 pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num);
416 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
417 tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta);
418 // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num)
419 dataset.SetAuxiliaryDataValue(1, i, tracker.GetSamplePrediction(i) - pred_value);
420 } else {
421 // If the tree has not yet been modified via a sampling step,
422 // we can query its prediction directly from the SamplePredMapper stored in tracker
423 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
424 // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num): needed? since tree not changed?
425 double current_lambda_hat = tracker.GetSamplePrediction(i);
426 double lambda_minus = current_lambda_hat - pred_value;
427 dataset.SetAuxiliaryDataValue(1, i, lambda_minus);
428 }
429 }
430}
431
432template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
433static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(
434 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
435 TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance,
436 int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
437) {
438 // Initialize sufficient statistics
439 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
440 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
441 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
442
443 // Accumulate sufficient statistics
444 AccumulateSuffStatProposed<LeafSuffStat, LeafSuffStatConstructorArgs...>(
445 node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
446 residual, global_variance, split, tree_num, leaf_num, split_feature, num_threads,
447 leaf_suff_stat_args...
448 );
449 data_size_t left_n = left_suff_stat.n;
450 data_size_t right_n = right_suff_stat.n;
451
452 // Evaluate split
453 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
454 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
455
456 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
457}
458
459template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
460static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(
461 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
462 double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id,
463 LeafSuffStatConstructorArgs&... leaf_suff_stat_args
464) {
465 // Initialize sufficient statistics
466 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
467 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
468 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
469
470 // Accumulate sufficient statistics
471 AccumulateSuffStatExisting<LeafSuffStat>(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
472 residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id);
473 data_size_t left_n = left_suff_stat.n;
474 data_size_t right_n = right_suff_stat.n;
475
476 // Evaluate split
477 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
478 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
479
480 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
481}
482
483template <typename LeafModel>
484static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
485 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
486 if constexpr (std::is_same_v<LeafModel, CloglogOrdinalLeafModel>) {
487 UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false);
488 } else if (backfitting) {
489 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), false);
490 } else {
491 // TODO: think about a generic way to store "state" corresponding to the other models?
492 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), false);
493 }
494}
495
496template <typename LeafModel>
497static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
498 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
499 if constexpr (std::is_same_v<LeafModel, CloglogOrdinalLeafModel>) {
500 UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true);
501 } else if (backfitting) {
502 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), true);
503 } else {
504 // TODO: think about a generic way to store "state" corresponding to the other models?
505 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), true);
506 }
507}
508
509template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
510static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
511 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size,
512 std::unordered_map<int, std::pair<data_size_t, data_size_t>>& node_index_map, std::deque<node_t>& split_queue,
513 int node_id, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
514 std::vector<FeatureType>& feature_types, std::vector<bool> feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
515 // Leaf depth
516 int leaf_depth = tree->GetDepth(node_id);
517
518 // Maximum leaf depth
519 int32_t max_depth = tree_prior.GetMaxDepth();
520
521 if ((max_depth == -1) || (leaf_depth < max_depth)) {
522
523 // Vector of vectors to store results for each feature
524 int p = dataset.NumCovariates();
525 std::vector<std::vector<double>> feature_log_cutpoint_evaluations(p+1);
526 std::vector<std::vector<double>> feature_cutpoint_values(p+1);
527 std::vector<int> feature_cutpoint_counts(p+1, 0);
528 StochTree::data_size_t valid_cutpoint_count;
529
530 // Evaluate all possible cutpoints according to the leaf node model,
531 // recording their log-likelihood and other split information in a series of vectors.
532
533 // Initialize node sufficient statistics
534 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
535
536 // Accumulate aggregate sufficient statistic for the node to be split
537 AccumulateSingleNodeSuffStat<LeafSuffStat, false>(node_suff_stat, dataset, tracker, residual, tree_num, node_id);
538
539 // Compute the "no split" log marginal likelihood
540 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
541
542 // Unpack data
543 Eigen::MatrixXd& covariates = dataset.GetCovariates();
544 Eigen::VectorXd& outcome = residual.GetData();
545 Eigen::VectorXd var_weights;
546 bool has_weights = dataset.HasVarWeights();
547 if (has_weights) var_weights = dataset.GetVarWeights();
548
549 // Minimum size of newly created leaf nodes (used to rule out invalid splits)
550 int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();
551
552 // Compute sufficient statistics for each possible split
553 data_size_t num_cutpoints = 0;
554 if (num_threads == -1) {
555 num_threads = GetOptimalThreadCount(static_cast<int>(covariates.cols() * covariates.rows()));
556 }
557
558 // Initialize cutpoint grid container
559 CutpointGridContainer cutpoint_grid_container(covariates, outcome, cutpoint_grid_size);
560
561 // Evaluate all possible splits for each feature in parallel
562 StochTree::ParallelFor(0, covariates.cols(), num_threads, [&](int j) {
563 if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) {
564 // Enumerate cutpoint strides
565 cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types);
566
567 // Left and right node sufficient statistics
568 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
569 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
570
571 // Iterate through possible cutpoints
572 int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j);
573 FeatureType feature_type = feature_types[j];
574 // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins
575 for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) {
576 data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j);
577 data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j);
578 data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j);
579
580 // Accumulate sufficient statistics for the left node
581 AccumulateCutpointBinSuffStat<LeafSuffStat>(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual,
582 global_variance, tree_num, node_id, j, cutpoint_idx);
583
584 // Compute the corresponding right node sufficient statistics
585 right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat);
586
587 // Store the bin index as the "cutpoint value" - we can use this to query the actual split
588 // value or the set of split categories later on once a split is chose
589 double cutoff_value = cutpoint_idx;
590
591 // Only include cutpoint for consideration if it defines a valid split in the training data
592 bool valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) &&
593 right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
594 if (valid_split) {
595 feature_cutpoint_counts[j]++;
596 // Add to split rule vector
597 feature_cutpoint_values[j].push_back(cutoff_value);
598 // Add the log marginal likelihood of the split to the split eval vector
599 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
600 feature_log_cutpoint_evaluations[j].push_back(split_log_ml);
601 }
602 }
603 }
604 });
605
606 // Compute total number of cutpoints
607 valid_cutpoint_count = std::accumulate(feature_cutpoint_counts.begin(), feature_cutpoint_counts.end(), 0);
608
609 // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
610 feature_log_cutpoint_evaluations[covariates.cols()].push_back(no_split_log_ml);
611
612 // Compute an adjustment to reflect the no split prior probability and the number of cutpoints
613 double bart_prior_no_split_adj;
614 double alpha = tree_prior.GetAlpha();
615 double beta = tree_prior.GetBeta();
616 int node_depth = tree->GetDepth(node_id);
617 if (valid_cutpoint_count == 0) {
618 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0);
619 } else {
620 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count);
621 }
622 feature_log_cutpoint_evaluations[covariates.cols()][0] += bart_prior_no_split_adj;
623
624
625 // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood
626 double largest_ml = -std::numeric_limits<double>::infinity();
627 for (int j = 0; j < p + 1; j++) {
628 if (feature_log_cutpoint_evaluations[j].size() > 0) {
629 double feature_max_ml = *std::max_element(feature_log_cutpoint_evaluations[j].begin(), feature_log_cutpoint_evaluations[j].end());;
630 largest_ml = std::max(largest_ml, feature_max_ml);
631 }
632 }
633 std::vector<std::vector<double>> feature_cutpoint_evaluations(p+1);
634 for (int j = 0; j < p + 1; j++) {
635 if (feature_log_cutpoint_evaluations[j].size() > 0) {
636 feature_cutpoint_evaluations[j].resize(feature_log_cutpoint_evaluations[j].size());
637 for (int i = 0; i < feature_log_cutpoint_evaluations[j].size(); i++) {
638 feature_cutpoint_evaluations[j][i] = std::exp(feature_log_cutpoint_evaluations[j][i] - largest_ml);
639 }
640 }
641 }
642
643 // Compute sum of marginal likelihoods for each feature
644 std::vector<double> feature_total_cutpoint_evaluations(p+1, 0.0);
645 for (int j = 0; j < p + 1; j++) {
646 if (feature_log_cutpoint_evaluations[j].size() > 0) {
647 feature_total_cutpoint_evaluations[j] = std::accumulate(feature_cutpoint_evaluations[j].begin(), feature_cutpoint_evaluations[j].end(), 0.0);
648 } else {
649 feature_total_cutpoint_evaluations[j] = 0.0;
650 }
651 }
652
653 // First, sample a feature according to feature_total_cutpoint_evaluations
654 int feature_chosen = sample_discrete_stateless(gen, feature_total_cutpoint_evaluations);
655
656 // Then, sample a cutpoint according to feature_cutpoint_evaluations[feature_chosen]
657 int cutpoint_chosen = sample_discrete_stateless(gen, feature_cutpoint_evaluations[feature_chosen], feature_total_cutpoint_evaluations[feature_chosen]);
658
659 if (feature_chosen == p){
660 // "No split" sampled, don't split or add any nodes to split queue
661 return;
662 } else {
663 // Split sampled
664 int feature_split = feature_chosen;
665 FeatureType feature_type = feature_types[feature_split];
666 double split_value = feature_cutpoint_values[feature_split][cutpoint_chosen];
667 // Perform all of the relevant "split" operations in the model, tree and training dataset
668
669 // Compute node sample size
670 data_size_t node_n = node_end - node_begin;
671
672 // Actual numeric cutpoint used for ordered categorical and numeric features
673 double split_value_numeric;
674 TreeSplit tree_split;
675
676 // We will use these later in the model expansion
677 data_size_t left_n = 0;
678 data_size_t right_n = 0;
679 data_size_t sort_idx;
680 double feature_value;
681 bool split_true;
682
683 if (feature_type == FeatureType::kUnorderedCategorical) {
684 // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split
685 int num_categories;
686 std::vector<std::uint32_t> categories = cutpoint_grid_container.CutpointVector(static_cast<std::uint32_t>(split_value), feature_split);
687 tree_split = TreeSplit(categories);
688 } else if (feature_type == FeatureType::kOrderedCategorical) {
689 // Convert the bin split to an actual split value
690 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
691 tree_split = TreeSplit(split_value_numeric);
692 } else if (feature_type == FeatureType::kNumeric) {
693 // Convert the bin split to an actual split value
694 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
695 tree_split = TreeSplit(split_value_numeric);
696 } else {
697 Log::Fatal("Invalid split type");
698 }
699
700 // Add split to tree and trackers
701 AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads);
702
703 // Determine the number of observation in the newly created left node
704 int left_node = tree->LeftChild(node_id);
705 int right_node = tree->RightChild(node_id);
706 auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split);
707 auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split);
708 for (auto i = left_begin_iter; i < left_end_iter; i++) {
709 left_n += 1;
710 }
711
712 // Add the begin and end indices for the new left and right nodes to node_index_map
713 node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)});
714 node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)});
715
716 // Add the left and right nodes to the split tracker
717 split_queue.push_front(right_node);
718 split_queue.push_front(left_node);
719 }
720 }
721}
722
723template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
724static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
725 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
726 int tree_num, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
727 int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
728 int root_id = Tree::kRoot;
729 int curr_node_id;
730 data_size_t curr_node_begin;
731 data_size_t curr_node_end;
732 data_size_t n = dataset.GetCovariates().rows();
733 int p = dataset.GetCovariates().cols();
734
735 // Subsample features (if requested)
736 std::vector<bool> feature_subset(p, true);
737 if (num_features_subsample < p) {
738 // Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
739 int number_nonzero_weights = 0;
740 for (int j = 0; j < p; j++) {
741 if (std::abs(variable_weights.at(j)) > kEpsilon) {
742 number_nonzero_weights++;
743 }
744 }
745 if (number_nonzero_weights > num_features_subsample) {
746 // Sample with replacement according to variable_weights
747 std::vector<int> feature_indices(p);
748 std::iota(feature_indices.begin(), feature_indices.end(), 0);
749 std::vector<int> features_selected(num_features_subsample);
750 sample_without_replacement<int, double>(
751 features_selected.data(), variable_weights.data(), feature_indices.data(),
752 p, num_features_subsample, gen
753 );
754 for (int i = 0; i < p; i++) {
755 feature_subset.at(i) = false;
756 }
757 for (const auto& feat : features_selected) {
758 feature_subset.at(feat) = true;
759 }
760 }
761 }
762
763 // Mapping from node id to start and end points of sorted indices
764 std::unordered_map<int, std::pair<data_size_t, data_size_t>> node_index_map;
765 node_index_map.insert({root_id, std::make_pair(0, n)});
766 std::pair<data_size_t, data_size_t> begin_end;
767 // Add root node to the split queue
768 std::deque<node_t> split_queue;
769 split_queue.push_back(Tree::kRoot);
770 // Run the "GrowFromRoot" procedure using a stack in place of recursion
771 while (!split_queue.empty()) {
772 // Remove the next node from the queue
773 curr_node_id = split_queue.front();
774 split_queue.pop_front();
775 // Determine the beginning and ending indices of the left and right nodes
776 begin_end = node_index_map[curr_node_id];
777 curr_node_begin = begin_end.first;
778 curr_node_end = begin_end.second;
779 // Draw a split rule at random
780 SampleSplitRule<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
781 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size,
782 node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types,
783 feature_subset, num_threads, leaf_suff_stat_args...
784 );
785 }
786}
787
819template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
820static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
821 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
822 std::vector<int>& sweep_update_indices, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
823 bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
824 // Run the GFR algorithm for each tree
825 int num_trees = forests.NumTrees();
826 for (const int& i : sweep_update_indices) {
827 // Adjust any model state needed to run a tree sampler
828 // For models that involve Bayesian backfitting, this amounts to adding tree i's
829 // predictions back to the residual (thus, training a model on the "partial residual")
830 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
831 Tree* tree = active_forest.GetTree(i);
832 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
833
834 // Reset the tree and sample trackers
835 active_forest.ResetInitTree(i);
836 tracker.ResetRoot(dataset.GetCovariates(), feature_types, i);
837 tree = active_forest.GetTree(i);
838
839 // Sample tree i
840 GFRSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
841 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
842 variable_weights, i, global_variance, feature_types, cutpoint_grid_size,
843 num_features_subsample, num_threads, leaf_suff_stat_args...
844 );
845
846 // Sample leaf parameters for tree i
847 tree = active_forest.GetTree(i);
848 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
849
850 // Adjust any model state needed to run a tree sampler
851 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
852 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
853 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
854 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
855 }
856
857 if (keep_forest) {
858 forests.AddSample(active_forest);
859 }
860}
861
862template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
863static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
864 TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector<double>& variable_weights,
865 double global_variance, double prob_grow_old, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
866 // Extract dataset information
867 data_size_t n = dataset.GetCovariates().rows();
868
869 // Choose a leaf node at random
870 int num_leaves = tree->NumLeaves();
871 std::vector<int> leaves = tree->GetLeaves();
872 std::vector<double> leaf_weights(num_leaves);
873 std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0/num_leaves);
874 walker_vose leaf_dist(leaf_weights.begin(), leaf_weights.end());
875 int leaf_chosen = leaves[leaf_dist(gen)];
876 int leaf_depth = tree->GetDepth(leaf_chosen);
877
878 // Maximum leaf depth
879 int32_t max_depth = tree_prior.GetMaxDepth();
880
881 // Terminate early if cannot be split
882 bool accept;
883 if ((leaf_depth >= max_depth) && (max_depth != -1)) {
884 accept = false;
885 } else {
886
887 // Select a split variable at random
888 int p = dataset.GetCovariates().cols();
889 CHECK_EQ(variable_weights.size(), p);
890 walker_vose var_dist(variable_weights.begin(), variable_weights.end());
891 int var_chosen = var_dist(gen);
892
893 // Determine the range of possible cutpoints
894 // TODO: specialize this for binary / ordered categorical / unordered categorical variables
895 double var_min, var_max;
896 VarSplitRange(tracker, dataset, tree_num, leaf_chosen, var_chosen, var_min, var_max);
897 if (var_max <= var_min) {
898 return;
899 }
900
901 // Split based on var_min to var_max in a given node
902 double split_point_chosen = standard_uniform_draw_53bit(gen) * (var_max - var_min) + var_min;
903
904 // Create a split object
905 TreeSplit split = TreeSplit(split_point_chosen);
906
907 // Compute the marginal likelihood of split and no split, given the leaf prior
908 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateProposedSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
909 dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, num_threads, leaf_suff_stat_args...
910 );
911 double split_log_marginal_likelihood = std::get<0>(split_eval);
912 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
913 int32_t left_n = std::get<2>(split_eval);
914 int32_t right_n = std::get<3>(split_eval);
915
916 // Reject the split if either of the left and right nodes are smaller than tree_prior.GetMinSamplesLeaf()
917 bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf();
918 bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf();
919 if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) {
920
921 // Determine probability of growing the split node and its two new left and right nodes
922 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
923 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
924 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
925
926 // Determine whether a "grow" move is possible from the newly formed tree
927 // in order to compute the probability of choosing "prune" from the new tree
928 // (which is always possible by construction)
929 bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
930 bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
931 bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
932 double prob_prune_new;
933 if (non_constant && (min_samples_left_check || min_samples_right_check)) {
934 prob_prune_new = 0.5;
935 } else {
936 prob_prune_new = 1.0;
937 }
938
939 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
940 int num_leaf_parents = tree->NumLeafParents();
941 double p_leaf = 1/static_cast<double>(num_leaves);
942 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);
943
944 // Compute the final MH ratio
945 double log_mh_ratio = (
946 std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
947 std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
948 );
949 // Threshold at 0
950 if (log_mh_ratio > 0) {
951 log_mh_ratio = 0;
952 }
953
954 // Draw a uniform random variable and accept/reject the proposal on this basis
955 double log_acceptance_prob = std::log(standard_uniform_draw_53bit(gen));
956 if (log_acceptance_prob <= log_mh_ratio) {
957 accept = true;
958 AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false, num_threads);
959 } else {
960 accept = false;
961 }
962
963 } else {
964 accept = false;
965 }
966 }
967}
968
969template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
970static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
971 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads,
972 LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
973 // Choose a "leaf parent" node at random
974 int num_leaves = tree->NumLeaves();
975 int num_leaf_parents = tree->NumLeafParents();
976 std::vector<int> leaf_parents = tree->GetLeafParents();
977 std::vector<double> leaf_parent_weights(num_leaf_parents);
978 std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0/num_leaf_parents);
979 walker_vose leaf_parent_dist(leaf_parent_weights.begin(), leaf_parent_weights.end());
980 int leaf_parent_chosen = leaf_parents[leaf_parent_dist(gen)];
981 int leaf_parent_depth = tree->GetDepth(leaf_parent_chosen);
982 int left_node = tree->LeftChild(leaf_parent_chosen);
983 int right_node = tree->RightChild(leaf_parent_chosen);
984 int feature_split = tree->SplitIndex(leaf_parent_chosen);
985
986 // Compute the marginal likelihood for the leaf parent and its left and right nodes
987 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateExistingSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
988 dataset, tracker, residual, leaf_model, global_variance, tree_num, leaf_parent_chosen, left_node, right_node, leaf_suff_stat_args...
989 );
990 double split_log_marginal_likelihood = std::get<0>(split_eval);
991 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
992 int32_t left_n = std::get<2>(split_eval);
993 int32_t right_n = std::get<3>(split_eval);
994
995 // Determine probability of growing the split node and its two new left and right nodes
996 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth, -tree_prior.GetBeta());
997 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
998 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
999
1000 // Determine whether a "prune" move is possible from the new tree,
1001 // in order to compute the probability of choosing "grow" from the new tree
1002 // (which is always possible by construction)
1003 bool non_root_tree = tree->NumNodes() > 1;
1004 double prob_grow_new;
1005 if (non_root_tree) {
1006 prob_grow_new = 0.5;
1007 } else {
1008 prob_grow_new = 1.0;
1009 }
1010
1011 // Determine whether a "grow" move was possible from the old tree,
1012 // in order to compute the probability of choosing "prune" from the old tree
1013 bool non_constant_left = NodeNonConstant(dataset, tracker, tree_num, left_node);
1014 bool non_constant_right = NodeNonConstant(dataset, tracker, tree_num, right_node);
1015 double prob_prune_old;
1016 if (non_constant_left && non_constant_right) {
1017 prob_prune_old = 0.5;
1018 } else {
1019 prob_prune_old = 1.0;
1020 }
1021
1022 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
1023 double p_leaf = 1/static_cast<double>(num_leaves-1);
1024 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents);
1025
1026 // Compute the final MH ratio
1027 double log_mh_ratio = (
1028 std::log(1-pg) - std::log(pg) - std::log(1-pgl) - std::log(1-pgr) + std::log(prob_prune_old) +
1029 std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood
1030 );
1031 // Threshold at 0
1032 if (log_mh_ratio > 0) {
1033 log_mh_ratio = 0;
1034 }
1035
1036 // Draw a uniform random variable and accept/reject the proposal on this basis
1037 bool accept;
1038 double log_acceptance_prob = std::log(standard_uniform_draw_53bit(gen));
1039 if (log_acceptance_prob <= log_mh_ratio) {
1040 accept = true;
1041 RemoveSplitFromModel(tracker, dataset, tree_prior, gen, tree, tree_num, leaf_parent_chosen, left_node, right_node, false);
1042 } else {
1043 accept = false;
1044 }
1045}
1046
1047template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
1048static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
1049 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
1050 int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
1051 // Determine whether it is possible to grow any of the leaves
1052 bool grow_possible = false;
1053 std::vector<int> leaves = tree->GetLeaves();
1054 for (auto& leaf: leaves) {
1055 if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) {
1056 grow_possible = true;
1057 break;
1058 }
1059 }
1060
1061 // Determine whether it is possible to prune the tree
1062 bool prune_possible = false;
1063 if (tree->NumValidNodes() > 1) {
1064 prune_possible = true;
1065 }
1066
1067 // Determine the relative probability of grow vs prune (0 = grow, 1 = prune)
1068 double prob_grow;
1069 std::vector<double> step_probs(2);
1070 if (grow_possible && prune_possible) {
1071 step_probs = {0.5, 0.5};
1072 prob_grow = 0.5;
1073 } else if (!grow_possible && prune_possible) {
1074 step_probs = {0.0, 1.0};
1075 prob_grow = 0.0;
1076 } else if (grow_possible && !prune_possible) {
1077 step_probs = {1.0, 0.0};
1078 prob_grow = 1.0;
1079 } else {
1080 Log::Fatal("In this tree, neither grow nor prune is possible");
1081 }
1082 walker_vose step_dist(step_probs.begin(), step_probs.end());
1083
1084 // Draw a split rule at random
1085 data_size_t step_chosen = step_dist(gen);
1086 bool accept;
1087
1088 if (step_chosen == 0) {
1089 MCMCGrowTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1090 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, num_threads, leaf_suff_stat_args...
1091 );
1092 } else {
1093 MCMCPruneTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1094 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, num_threads, leaf_suff_stat_args...
1095 );
1096 }
1097}
1098
1127template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
1128static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
1129 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
1130 std::vector<int>& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads,
1131 LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
1132 // Run the MCMC algorithm for each tree
1133 int num_trees = forests.NumTrees();
1134 for (const int& i : sweep_update_indices) {
1135 // Adjust any model state needed to run a tree sampler
1136 // For models that involve Bayesian backfitting, this amounts to adding tree i's
1137 // predictions back to the residual (thus, training a model on the "partial residual")
1138 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1139 Tree* tree = active_forest.GetTree(i);
1140 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1141
1142 // Sample tree i
1143 tree = active_forest.GetTree(i);
1144 MCMCSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1145 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i,
1146 global_variance, num_threads, leaf_suff_stat_args...
1147 );
1148
1149 // Sample leaf parameters for tree i
1150 tree = active_forest.GetTree(i);
1151 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
1152
1153 // Adjust any model state needed to run a tree sampler
1154 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
1155 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
1156 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1157 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1158 }
1159
1160 if (keep_forest) {
1161 forests.AddSample(active_forest);
1162 }
1163}
1164
// end of sampling_group
1166
1167} // namespace StochTree
1168
1169#endif // STOCHTREE_TREE_SAMPLER_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:193
Container of TreeEnsemble forest objects. This is the primary (in-memory) storage interface for multi...
Definition container.h:23
void AddSample(TreeEnsemble &forest)
Add a new forest to the container by copying forest.
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:271
Eigen::MatrixXd & GetCovariates()
Return a reference to the raw Eigen::MatrixXd storing the covariate data.
Definition data.h:383
double CovariateValue(data_size_t row, int col)
Returns a dataset's covariate value stored at (row, col)
Definition data.h:364
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:46
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Tree * GetTree(int i)
Return a pointer to a tree in the forest.
Definition ensemble.h:134
void ResetInitTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:163
Definition prior.h:43
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:958
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:990
Decision tree data structure.
Definition tree.h:66
static void MCMCSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, std::vector< int > &sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Runs one iteration of the MCMC sampler for a tree ensemble model, which consists of two steps for eve...
Definition tree_sampler.h:1128
static void VarSplitRange(ForestTracker &tracker, ForestDataset &dataset, int tree_num, int leaf_split, int feature_split, double &var_min, double &var_max)
Compute the range of available split values for a continuous variable, given the current structure o...
Definition tree_sampler.h:48
static bool NodesNonConstantAfterSplit(ForestDataset &dataset, ForestTracker &tracker, TreeSplit &split, int tree_num, int leaf_split, int feature_split)
Determines whether a proposed split creates two leaf nodes with constant values for every feature (th...
Definition tree_sampler.h:78
static void GFRSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, std::vector< int > &sweep_update_indices, double global_variance, std::vector< FeatureType > &feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Definition tree_sampler.h:820
A collection of random number generation utilities.
Definition category_tracker.h:36
double standard_uniform_draw_53bit(std::mt19937 &gen)
Definition distributions.h:19