StochTree 0.1.1
Loading...
Searching...
No Matches
tree_sampler.h
1
2#ifndef STOCHTREE_TREE_SAMPLER_H_
3#define STOCHTREE_TREE_SAMPLER_H_
4
5#include <stochtree/container.h>
6#include <stochtree/cutpoint_candidates.h>
7#include <stochtree/data.h>
8#include <stochtree/discrete_sampler.h>
9#include <stochtree/ensemble.h>
10#include <stochtree/leaf_model.h>
11#include <stochtree/partition_tracker.h>
12#include <stochtree/prior.h>
13
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <deque>
#include <functional>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <variant>
#include <vector>
24
25namespace StochTree {
26
52static inline void VarSplitRange(ForestTracker& tracker, ForestDataset& dataset, int tree_num, int leaf_split, int feature_split, double& var_min, double& var_max) {
53 var_min = std::numeric_limits<double>::max();
54 var_max = std::numeric_limits<double>::min();
55 double feature_value;
56
57 std::vector<data_size_t>::iterator node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
58 std::vector<data_size_t>::iterator node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
59
60 for (auto i = node_begin_iter; i != node_end_iter; i++) {
61 auto idx = *i;
62 feature_value = dataset.CovariateValue(idx, feature_split);
63 if (feature_value < var_min) {
64 var_min = feature_value;
65 } else if (feature_value > var_max) {
66 var_max = feature_value;
67 }
68 }
69}
70
82static inline bool NodesNonConstantAfterSplit(ForestDataset& dataset, ForestTracker& tracker, TreeSplit& split, int tree_num, int leaf_split, int feature_split) {
83 int p = dataset.GetCovariates().cols();
84 data_size_t idx;
85 double feature_value;
86 double split_feature_value;
87 double var_max_left;
88 double var_min_left;
89 double var_max_right;
90 double var_min_right;
91
92 for (int j = 0; j < p; j++) {
93 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, leaf_split);
94 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, leaf_split);
95 var_max_left = std::numeric_limits<double>::min();
96 var_min_left = std::numeric_limits<double>::max();
97 var_max_right = std::numeric_limits<double>::min();
98 var_min_right = std::numeric_limits<double>::max();
99
100 for (auto i = node_begin_iter; i != node_end_iter; i++) {
101 auto idx = *i;
102 split_feature_value = dataset.CovariateValue(idx, feature_split);
103 feature_value = dataset.CovariateValue(idx, j);
104 if (split.SplitTrue(split_feature_value)) {
105 if (var_max_left < feature_value) {
106 var_max_left = feature_value;
107 } else if (var_min_left > feature_value) {
108 var_min_left = feature_value;
109 }
110 } else {
111 if (var_max_right < feature_value) {
112 var_max_right = feature_value;
113 } else if (var_min_right > feature_value) {
114 var_min_right = feature_value;
115 }
116 }
117 }
118 if ((var_max_left > var_min_left) && (var_max_right > var_min_right)) {
119 return true;
120 }
121 }
122 return false;
123}
124
125static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracker, int tree_num, int node_id) {
126 int p = dataset.GetCovariates().cols();
127 data_size_t idx;
128 double feature_value;
129 double var_max;
130 double var_min;
131
132 for (int j = 0; j < p; j++) {
133 auto node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
134 auto node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
135 var_max = std::numeric_limits<double>::min();
136 var_min = std::numeric_limits<double>::max();
137
138 for (auto i = node_begin_iter; i != node_end_iter; i++) {
139 auto idx = *i;
140 feature_value = dataset.CovariateValue(idx, j);
141 if (var_max < feature_value) {
142 var_max = feature_value;
143 } else if (var_min > feature_value) {
144 var_min = feature_value;
145 }
146 }
147 if (var_max > var_min) {
148 return true;
149 }
150 }
151 return false;
152}
153
154static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree,
155 int tree_num, int leaf_node, int feature_split, bool keep_sorted = false) {
156 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
157 if (tree->OutputDimension() > 1) {
158 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
159 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_values, temp_leaf_values);
160 } else {
161 double temp_leaf_value = 0.;
162 tree->ExpandNode(leaf_node, feature_split, split, temp_leaf_value, temp_leaf_value);
163 }
164 int left_node = tree->LeftChild(leaf_node);
165 int right_node = tree->RightChild(leaf_node);
166
167 // Update the ForestTracker
168 tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted);
169}
170
171static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree,
172 int tree_num, int leaf_node, int left_node, int right_node, bool keep_sorted = false) {
173 // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete
174 if (tree->OutputDimension() > 1) {
175 std::vector<double> temp_leaf_values(tree->OutputDimension(), 0.);
176 tree->CollapseToLeaf(leaf_node, temp_leaf_values);
177 } else {
178 double temp_leaf_value = 0.;
179 tree->CollapseToLeaf(leaf_node, temp_leaf_value);
180 }
181
182 // Update the ForestTracker
183 tracker.RemoveSplit(dataset.GetCovariates(), tree, tree_num, leaf_node, left_node, right_node, keep_sorted);
184}
185
186static inline double ComputeMeanOutcome(ColumnVector& residual) {
187 int n = residual.NumRows();
188 double sum_y = 0.;
189 double y;
190 for (data_size_t i = 0; i < n; i++) {
191 y = residual.GetElement(i);
192 sum_y += y;
193 }
194 return sum_y / static_cast<double>(n);
195}
196
197static inline double ComputeVarianceOutcome(ColumnVector& residual) {
198 int n = residual.NumRows();
199 double sum_y = 0.;
200 double sum_y_sq = 0.;
201 double y;
202 for (data_size_t i = 0; i < n; i++) {
203 y = residual.GetElement(i);
204 sum_y += y;
205 sum_y_sq += y * y;
206 }
207 return sum_y_sq / static_cast<double>(n) - (sum_y * sum_y) / (static_cast<double>(n) * static_cast<double>(n));
208}
209
210static inline void UpdateModelVarianceForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual,
211 TreeEnsemble* forest, bool requires_basis, std::function<double(double, double)> op) {
212 data_size_t n = dataset.GetCovariates().rows();
213 double tree_pred = 0.;
214 double pred_value = 0.;
215 double new_resid = 0.;
216 int32_t leaf_pred;
217 for (data_size_t i = 0; i < n; i++) {
218 for (int j = 0; j < forest->NumTrees(); j++) {
219 Tree* tree = forest->GetTree(j);
220 leaf_pred = tracker.GetNodeId(i, j);
221 if (requires_basis) {
222 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
223 } else {
224 tree_pred += tree->PredictFromNode(leaf_pred);
225 }
226 tracker.SetTreeSamplePrediction(i, j, tree_pred);
227 pred_value += tree_pred;
228 }
229
230 // Run op (either plus or minus) on the residual and the new prediction
231 new_resid = op(residual.GetElement(i), pred_value);
232 residual.SetElement(i, new_resid);
233 }
234 tracker.SyncPredictions();
235}
236
237static inline void UpdateResidualNoTrackerUpdate(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
238 bool requires_basis, std::function<double(double, double)> op) {
239 data_size_t n = dataset.GetCovariates().rows();
240 double tree_pred = 0.;
241 double pred_value = 0.;
242 double new_resid = 0.;
243 int32_t leaf_pred;
244 for (data_size_t i = 0; i < n; i++) {
245 for (int j = 0; j < forest->NumTrees(); j++) {
246 Tree* tree = forest->GetTree(j);
247 leaf_pred = tracker.GetNodeId(i, j);
248 if (requires_basis) {
249 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
250 } else {
251 tree_pred += tree->PredictFromNode(leaf_pred);
252 }
253 pred_value += tree_pred;
254 }
255
256 // Run op (either plus or minus) on the residual and the new prediction
257 new_resid = op(residual.GetElement(i), pred_value);
258 residual.SetElement(i, new_resid);
259 }
260}
261
262static inline void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest,
263 bool requires_basis, std::function<double(double, double)> op) {
264 data_size_t n = dataset.GetCovariates().rows();
265 double tree_pred = 0.;
266 double pred_value = 0.;
267 double new_resid = 0.;
268 int32_t leaf_pred;
269 for (data_size_t i = 0; i < n; i++) {
270 for (int j = 0; j < forest->NumTrees(); j++) {
271 Tree* tree = forest->GetTree(j);
272 leaf_pred = tracker.GetNodeId(i, j);
273 if (requires_basis) {
274 tree_pred += tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
275 } else {
276 tree_pred += tree->PredictFromNode(leaf_pred);
277 }
278 tracker.SetTreeSamplePrediction(i, j, tree_pred);
279 pred_value += tree_pred;
280 }
281
282 // Run op (either plus or minus) on the residual and the new prediction
283 new_resid = op(residual.GetElement(i), pred_value);
284 residual.SetElement(i, new_resid);
285 }
286 tracker.SyncPredictions();
287}
288
289static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector& residual) {
290 data_size_t n = residual.NumRows();
291 double pred_value;
292 double prev_outcome;
293 double new_resid;
294 for (data_size_t i = 0; i < n; i++) {
295 prev_outcome = residual.GetElement(i);
296 pred_value = tracker.GetSamplePrediction(i);
297 // Run op (either plus or minus) on the residual and the new prediction
298 new_resid = prev_outcome - pred_value;
299 residual.SetElement(i, new_resid);
300 }
301}
302
303static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
304 bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
305 data_size_t n = dataset.GetCovariates().rows();
306 double pred_value;
307 int32_t leaf_pred;
308 double new_resid;
309 double pred_delta;
310 for (data_size_t i = 0; i < n; i++) {
311 if (tree_new) {
312 // If the tree has been newly sampled or adjusted, we must rerun the prediction
313 // method and update the SamplePredMapper stored in tracker
314 leaf_pred = tracker.GetNodeId(i, tree_num);
315 if (requires_basis) {
316 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
317 } else {
318 pred_value = tree->PredictFromNode(leaf_pred);
319 }
320 pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num);
321 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
322 tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta);
323 } else {
324 // If the tree has not yet been modified via a sampling step,
325 // we can query its prediction directly from the SamplePredMapper stored in tracker
326 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
327 }
328 // Run op (either plus or minus) on the residual and the new prediction
329 new_resid = op(residual.GetElement(i), pred_value);
330 residual.SetElement(i, new_resid);
331 }
332}
333
334static inline void UpdateResidualNewBasis(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest) {
335 CHECK(dataset.HasBasis());
336 data_size_t n = dataset.GetCovariates().rows();
337 int num_trees = forest->NumTrees();
338 double prev_tree_pred;
339 double new_tree_pred;
340 int32_t leaf_pred;
341 double new_resid;
342 for (int tree_num = 0; tree_num < num_trees; tree_num++) {
343 Tree* tree = forest->GetTree(tree_num);
344 for (data_size_t i = 0; i < n; i++) {
345 // Add back the currently stored tree prediction
346 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
347 new_resid = residual.GetElement(i) + prev_tree_pred;
348
349 // Compute new prediction based on updated basis
350 leaf_pred = tracker.GetNodeId(i, tree_num);
351 new_tree_pred = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
352
353 // Cache the new prediction in the tracker
354 tracker.SetTreeSamplePrediction(i, tree_num, new_tree_pred);
355
356 // Subtract out the updated tree prediction
357 new_resid -= new_tree_pred;
358
359 // Propagate the change back to the residual
360 residual.SetElement(i, new_resid);
361 }
362 }
363 tracker.SyncPredictions();
364}
365
366static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree,
367 int tree_num, bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
368 data_size_t n = dataset.GetCovariates().rows();
369 double pred_value;
370 int32_t leaf_pred;
371 double new_weight;
372 double pred_delta;
373 double prev_tree_pred;
374 double prev_pred;
375 for (data_size_t i = 0; i < n; i++) {
376 if (tree_new) {
377 // If the tree has been newly sampled or adjusted, we must rerun the prediction
378 // method and update the SamplePredMapper stored in tracker
379 leaf_pred = tracker.GetNodeId(i, tree_num);
380 if (requires_basis) {
381 pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
382 } else {
383 pred_value = tree->PredictFromNode(leaf_pred);
384 }
385 prev_tree_pred = tracker.GetTreeSamplePrediction(i, tree_num);
386 prev_pred = tracker.GetSamplePrediction(i);
387 pred_delta = pred_value - prev_tree_pred;
388 tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
389 tracker.SetSamplePrediction(i, prev_pred + pred_delta);
390 new_weight = std::log(dataset.VarWeightValue(i)) + pred_value;
391 dataset.SetVarWeightValue(i, new_weight, true);
392 } else {
393 // If the tree has not yet been modified via a sampling step,
394 // we can query its prediction directly from the SamplePredMapper stored in tracker
395 pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
396 new_weight = std::log(dataset.VarWeightValue(i)) - pred_value;
397 dataset.SetVarWeightValue(i, new_weight, true);
398 }
399 }
400}
401
402template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
403static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(
404 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
405 TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance,
406 LeafSuffStatConstructorArgs&... leaf_suff_stat_args
407) {
408 // Initialize sufficient statistics
409 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
410 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
411 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
412
413 // Accumulate sufficient statistics
414 AccumulateSuffStatProposed<LeafSuffStat>(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
415 residual, global_variance, split, tree_num, leaf_num, split_feature);
416 data_size_t left_n = left_suff_stat.n;
417 data_size_t right_n = right_suff_stat.n;
418
419 // Evaluate split
420 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
421 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
422
423 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
424}
425
426template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
427static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(
428 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
429 double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id,
430 LeafSuffStatConstructorArgs&... leaf_suff_stat_args
431) {
432 // Initialize sufficient statistics
433 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
434 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
435 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
436
437 // Accumulate sufficient statistics
438 AccumulateSuffStatExisting<LeafSuffStat>(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker,
439 residual, global_variance, tree_num, split_node_id, left_node_id, right_node_id);
440 data_size_t left_n = left_suff_stat.n;
441 data_size_t right_n = right_suff_stat.n;
442
443 // Evaluate split
444 double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
445 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
446
447 return std::tuple<double, double, data_size_t, data_size_t>(split_log_ml, no_split_log_ml, left_n, right_n);
448}
449
450template <typename LeafModel>
451static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
452 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
453 if (backfitting) {
454 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), false);
455 } else {
456 // TODO: think about a generic way to store "state" corresponding to the other models?
457 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), false);
458 }
459}
460
461template <typename LeafModel>
462static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
463 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
464 if (backfitting) {
465 UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), true);
466 } else {
467 // TODO: think about a generic way to store "state" corresponding to the other models?
468 UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), true);
469 }
470}
471
472template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
473static inline void EvaluateAllPossibleSplits(
474 ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id,
475 std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
476 data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
477 std::vector<FeatureType>& feature_types, std::vector<bool>& feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
478) {
479 // Initialize sufficient statistics
480 LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
481 LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
482 LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
483
484 // Accumulate aggregate sufficient statistic for the node to be split
485 AccumulateSingleNodeSuffStat<LeafSuffStat, false>(node_suff_stat, dataset, tracker, residual, tree_num, split_node_id);
486
487 // Compute the "no split" log marginal likelihood
488 double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
489
490 // Unpack data
491 Eigen::MatrixXd covariates = dataset.GetCovariates();
492 Eigen::VectorXd outcome = residual.GetData();
493 Eigen::VectorXd var_weights;
494 bool has_weights = dataset.HasVarWeights();
495 if (has_weights) var_weights = dataset.GetVarWeights();
496
497 // Minimum size of newly created leaf nodes (used to rule out invalid splits)
498 int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();
499
500 // Compute sufficient statistics for each possible split
501 data_size_t num_cutpoints = 0;
502 bool valid_split = false;
503 data_size_t node_row_iter;
504 data_size_t current_bin_begin, current_bin_size, next_bin_begin;
505 data_size_t feature_sort_idx;
506 data_size_t row_iter_idx;
507 double outcome_val, outcome_val_sq;
508 FeatureType feature_type;
509 double feature_value = 0.0;
510 double cutoff_value = 0.0;
511 double log_split_eval = 0.0;
512 double split_log_ml;
513 for (int j = 0; j < covariates.cols(); j++) {
514
515 if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) {
516 // Enumerate cutpoint strides
517 cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), split_node_id, node_begin, node_end, j, feature_types);
518
519 // Reset sufficient statistics
520 left_suff_stat.ResetSuffStat();
521 right_suff_stat.ResetSuffStat();
522
523 // Iterate through possible cutpoints
524 int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j);
525 feature_type = feature_types[j];
526 // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins
527 for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) {
528 current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j);
529 current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j);
530 next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j);
531
532 // Accumulate sufficient statistics for the left node
533 AccumulateCutpointBinSuffStat<LeafSuffStat>(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual,
534 global_variance, tree_num, split_node_id, j, cutpoint_idx);
535
536 // Compute the corresponding right node sufficient statistics
537 right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat);
538
539 // Store the bin index as the "cutpoint value" - we can use this to query the actual split
540 // value or the set of split categories later on once a split is chose
541 cutoff_value = cutpoint_idx;
542
543 // Only include cutpoint for consideration if it defines a valid split in the training data
544 valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) &&
545 right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
546 if (valid_split) {
547 num_cutpoints++;
548 // Add to split rule vector
549 cutpoint_feature_types.push_back(feature_type);
550 cutpoint_features.push_back(j);
551 cutpoint_values.push_back(cutoff_value);
552 // Add the log marginal likelihood of the split to the split eval vector
553 split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance);
554 log_cutpoint_evaluations.push_back(split_log_ml);
555 }
556 }
557 }
558
559 }
560
561 // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
562 cutpoint_features.push_back(-1);
563 cutpoint_values.push_back(std::numeric_limits<double>::max());
564 cutpoint_feature_types.push_back(FeatureType::kNumeric);
565 log_cutpoint_evaluations.push_back(no_split_log_ml);
566
567 // Update valid cutpoint count
568 valid_cutpoint_count = num_cutpoints;
569}
570
571template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
572static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
573 std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end,
574 std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values,
575 std::vector<FeatureType>& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector<double>& variable_weights,
576 std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container,
577 std::vector<bool>& feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
578 // Evaluate all possible cutpoints according to the leaf node model,
579 // recording their log-likelihood and other split information in a series of vectors.
580 // The last element of these vectors concerns the "no-split" option.
581 EvaluateAllPossibleSplits<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
582 dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations,
583 cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container,
584 node_begin, node_end, variable_weights, feature_types, feature_subset, leaf_suff_stat_args...
585 );
586
587 // Compute an adjustment to reflect the no split prior probability and the number of cutpoints
588 double bart_prior_no_split_adj;
589 double alpha = tree_prior.GetAlpha();
590 double beta = tree_prior.GetBeta();
591 int node_depth = tree->GetDepth(node_id);
592 if (valid_cutpoint_count == 0) {
593 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0);
594 } else {
595 bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count);
596 }
597 log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj;
598}
599
600template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
601static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
602 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size,
603 std::unordered_map<int, std::pair<data_size_t, data_size_t>>& node_index_map, std::deque<node_t>& split_queue,
604 int node_id, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
605 std::vector<FeatureType>& feature_types, std::vector<bool> feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
606 // Leaf depth
607 int leaf_depth = tree->GetDepth(node_id);
608
609 // Maximum leaf depth
610 int32_t max_depth = tree_prior.GetMaxDepth();
611
612 if ((max_depth == -1) || (leaf_depth < max_depth)) {
613
614 // Cutpoint enumeration
615 std::vector<double> log_cutpoint_evaluations;
616 std::vector<int> cutpoint_features;
617 std::vector<double> cutpoint_values;
618 std::vector<FeatureType> cutpoint_feature_types;
619 StochTree::data_size_t valid_cutpoint_count;
620 CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
621 EvaluateCutpoints<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
622 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance,
623 cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features,
624 cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types,
625 cutpoint_grid_container, feature_subset, leaf_suff_stat_args...
626 );
627 // TODO: maybe add some checks here?
628
629 // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood
630 double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end());
631 std::vector<double> cutpoint_evaluations(log_cutpoint_evaluations.size());
632 for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){
633 cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll);
634 }
635
636 // Sample the split (including a "no split" option)
637 std::discrete_distribution<data_size_t> split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end());
638 data_size_t split_chosen = split_dist(gen);
639
640 if (split_chosen == valid_cutpoint_count){
641 // "No split" sampled, don't split or add any nodes to split queue
642 return;
643 } else {
644 // Split sampled
645 int feature_split = cutpoint_features[split_chosen];
646 FeatureType feature_type = cutpoint_feature_types[split_chosen];
647 double split_value = cutpoint_values[split_chosen];
648 // Perform all of the relevant "split" operations in the model, tree and training dataset
649
650 // Compute node sample size
651 data_size_t node_n = node_end - node_begin;
652
653 // Actual numeric cutpoint used for ordered categorical and numeric features
654 double split_value_numeric;
655 TreeSplit tree_split;
656
657 // We will use these later in the model expansion
658 data_size_t left_n = 0;
659 data_size_t right_n = 0;
660 data_size_t sort_idx;
661 double feature_value;
662 bool split_true;
663
664 if (feature_type == FeatureType::kUnorderedCategorical) {
665 // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split
666 int num_categories;
667 std::vector<std::uint32_t> categories = cutpoint_grid_container.CutpointVector(static_cast<std::uint32_t>(split_value), feature_split);
668 tree_split = TreeSplit(categories);
669 } else if (feature_type == FeatureType::kOrderedCategorical) {
670 // Convert the bin split to an actual split value
671 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
672 tree_split = TreeSplit(split_value_numeric);
673 } else if (feature_type == FeatureType::kNumeric) {
674 // Convert the bin split to an actual split value
675 split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast<std::uint32_t>(split_value), feature_split);
676 tree_split = TreeSplit(split_value_numeric);
677 } else {
678 Log::Fatal("Invalid split type");
679 }
680
681 // Add split to tree and trackers
682 AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true);
683
684 // Determine the number of observation in the newly created left node
685 int left_node = tree->LeftChild(node_id);
686 int right_node = tree->RightChild(node_id);
687 auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split);
688 auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split);
689 for (auto i = left_begin_iter; i < left_end_iter; i++) {
690 left_n += 1;
691 }
692
693 // Add the begin and end indices for the new left and right nodes to node_index_map
694 node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)});
695 node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)});
696
697 // Add the left and right nodes to the split tracker
698 split_queue.push_front(right_node);
699 split_queue.push_front(left_node);
700 }
701 }
702}
703
704template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
705static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
706 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
707 int tree_num, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
708 int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
709 int root_id = Tree::kRoot;
710 int curr_node_id;
711 data_size_t curr_node_begin;
712 data_size_t curr_node_end;
713 data_size_t n = dataset.GetCovariates().rows();
714 int p = dataset.GetCovariates().cols();
715
716 // Subsample features (if requested)
717 std::vector<bool> feature_subset(p, true);
718 if (num_features_subsample < p) {
719 // Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
720 int number_nonzero_weights = 0;
721 for (int j = 0; j < p; j++) {
722 if (std::abs(variable_weights.at(j)) > kEpsilon) {
723 number_nonzero_weights++;
724 }
725 }
726 if (number_nonzero_weights > num_features_subsample) {
727 // Sample with replacement according to variable_weights
728 std::vector<int> feature_indices(p);
729 std::iota(feature_indices.begin(), feature_indices.end(), 0);
730 std::vector<int> features_selected(num_features_subsample);
731 sample_without_replacement<int, double>(
732 features_selected.data(), variable_weights.data(), feature_indices.data(),
733 p, num_features_subsample, gen
734 );
735 for (int i = 0; i < p; i++) {
736 feature_subset.at(i) = false;
737 }
738 for (const auto& feat : features_selected) {
739 feature_subset.at(feat) = true;
740 }
741 }
742 }
743
744 // Mapping from node id to start and end points of sorted indices
745 std::unordered_map<int, std::pair<data_size_t, data_size_t>> node_index_map;
746 node_index_map.insert({root_id, std::make_pair(0, n)});
747 std::pair<data_size_t, data_size_t> begin_end;
748 // Add root node to the split queue
749 std::deque<node_t> split_queue;
750 split_queue.push_back(Tree::kRoot);
751 // Run the "GrowFromRoot" procedure using a stack in place of recursion
752 while (!split_queue.empty()) {
753 // Remove the next node from the queue
754 curr_node_id = split_queue.front();
755 split_queue.pop_front();
756 // Determine the beginning and ending indices of the left and right nodes
757 begin_end = node_index_map[curr_node_id];
758 curr_node_begin = begin_end.first;
759 curr_node_end = begin_end.second;
760 // Draw a split rule at random
761 SampleSplitRule<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
762 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size,
763 node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types,
764 feature_subset, leaf_suff_stat_args...);
765 }
766}
767
798template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
799static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
800 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
801 std::vector<int>& sweep_update_indices, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size,
802 bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
803 // Run the GFR algorithm for each tree
804 int num_trees = forests.NumTrees();
805 for (const int& i : sweep_update_indices) {
806 // Adjust any model state needed to run a tree sampler
807 // For models that involve Bayesian backfitting, this amounts to adding tree i's
808 // predictions back to the residual (thus, training a model on the "partial residual")
809 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
810 Tree* tree = active_forest.GetTree(i);
811 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
812
813 // Reset the tree and sample trackers
814 active_forest.ResetInitTree(i);
815 tracker.ResetRoot(dataset.GetCovariates(), feature_types, i);
816 tree = active_forest.GetTree(i);
817
818 // Sample tree i
819 GFRSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
820 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen,
821 variable_weights, i, global_variance, feature_types, cutpoint_grid_size,
822 num_features_subsample, leaf_suff_stat_args...
823 );
824
825 // Sample leaf parameters for tree i
826 tree = active_forest.GetTree(i);
827 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
828
829 // Adjust any model state needed to run a tree sampler
830 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
831 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
832 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
833 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
834 }
835
836 if (keep_forest) {
837 forests.AddSample(active_forest);
838 }
839}
840
841template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
842static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
843 TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector<double>& variable_weights,
844 double global_variance, double prob_grow_old, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
845 // Extract dataset information
846 data_size_t n = dataset.GetCovariates().rows();
847
848 // Choose a leaf node at random
849 int num_leaves = tree->NumLeaves();
850 std::vector<int> leaves = tree->GetLeaves();
851 std::vector<double> leaf_weights(num_leaves);
852 std::fill(leaf_weights.begin(), leaf_weights.end(), 1.0/num_leaves);
853 std::discrete_distribution<> leaf_dist(leaf_weights.begin(), leaf_weights.end());
854 int leaf_chosen = leaves[leaf_dist(gen)];
855 int leaf_depth = tree->GetDepth(leaf_chosen);
856
857 // Maximum leaf depth
858 int32_t max_depth = tree_prior.GetMaxDepth();
859
860 // Terminate early if cannot be split
861 bool accept;
862 if ((leaf_depth >= max_depth) && (max_depth != -1)) {
863 accept = false;
864 } else {
865
866 // Select a split variable at random
867 int p = dataset.GetCovariates().cols();
868 CHECK_EQ(variable_weights.size(), p);
869 std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end());
870 int var_chosen = var_dist(gen);
871
872 // Determine the range of possible cutpoints
873 // TODO: specialize this for binary / ordered categorical / unordered categorical variables
874 double var_min, var_max;
875 VarSplitRange(tracker, dataset, tree_num, leaf_chosen, var_chosen, var_min, var_max);
876 if (var_max <= var_min) {
877 return;
878 }
879
880 // Split based on var_min to var_max in a given node
881 std::uniform_real_distribution<double> split_point_dist(var_min, var_max);
882 double split_point_chosen = split_point_dist(gen);
883
884 // Create a split object
885 TreeSplit split = TreeSplit(split_point_chosen);
886
887 // Compute the marginal likelihood of split and no split, given the leaf prior
888 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateProposedSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
889 dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, leaf_suff_stat_args...
890 );
891 double split_log_marginal_likelihood = std::get<0>(split_eval);
892 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
893 int32_t left_n = std::get<2>(split_eval);
894 int32_t right_n = std::get<3>(split_eval);
895
896 // Reject the split if either of the left and right nodes are smaller than tree_prior.GetMinSamplesLeaf()
897 bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf();
898 bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf();
899 if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) {
900
901 // Determine probability of growing the split node and its two new left and right nodes
902 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
903 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
904 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
905
906 // Determine whether a "grow" move is possible from the newly formed tree
907 // in order to compute the probability of choosing "prune" from the new tree
908 // (which is always possible by construction)
909 bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
910 bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
911 bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
912 double prob_prune_new;
913 if (non_constant && (min_samples_left_check || min_samples_right_check)) {
914 prob_prune_new = 0.5;
915 } else {
916 prob_prune_new = 1.0;
917 }
918
919 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
920 int num_leaf_parents = tree->NumLeafParents();
921 double p_leaf = 1/static_cast<double>(num_leaves);
922 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);
923
924 // Compute the final MH ratio
925 double log_mh_ratio = (
926 std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
927 std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
928 );
929 // Threshold at 0
930 if (log_mh_ratio > 0) {
931 log_mh_ratio = 0;
932 }
933
934 // Draw a uniform random variable and accept/reject the proposal on this basis
935 std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
936 double log_acceptance_prob = std::log(mh_accept(gen));
937 if (log_acceptance_prob <= log_mh_ratio) {
938 accept = true;
939 AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false);
940 } else {
941 accept = false;
942 }
943
944 } else {
945 accept = false;
946 }
947 }
948}
949
950template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
951static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
952 TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
953 // Choose a "leaf parent" node at random
954 int num_leaves = tree->NumLeaves();
955 int num_leaf_parents = tree->NumLeafParents();
956 std::vector<int> leaf_parents = tree->GetLeafParents();
957 std::vector<double> leaf_parent_weights(num_leaf_parents);
958 std::fill(leaf_parent_weights.begin(), leaf_parent_weights.end(), 1.0/num_leaf_parents);
959 std::discrete_distribution<> leaf_parent_dist(leaf_parent_weights.begin(), leaf_parent_weights.end());
960 int leaf_parent_chosen = leaf_parents[leaf_parent_dist(gen)];
961 int leaf_parent_depth = tree->GetDepth(leaf_parent_chosen);
962 int left_node = tree->LeftChild(leaf_parent_chosen);
963 int right_node = tree->RightChild(leaf_parent_chosen);
964 int feature_split = tree->SplitIndex(leaf_parent_chosen);
965
966 // Compute the marginal likelihood for the leaf parent and its left and right nodes
967 std::tuple<double, double, int32_t, int32_t> split_eval = EvaluateExistingSplit<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
968 dataset, tracker, residual, leaf_model, global_variance, tree_num, leaf_parent_chosen, left_node, right_node, leaf_suff_stat_args...
969 );
970 double split_log_marginal_likelihood = std::get<0>(split_eval);
971 double no_split_log_marginal_likelihood = std::get<1>(split_eval);
972 int32_t left_n = std::get<2>(split_eval);
973 int32_t right_n = std::get<3>(split_eval);
974
975 // Determine probability of growing the split node and its two new left and right nodes
976 double pg = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth, -tree_prior.GetBeta());
977 double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
978 double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_parent_depth+1, -tree_prior.GetBeta());
979
980 // Determine whether a "prune" move is possible from the new tree,
981 // in order to compute the probability of choosing "grow" from the new tree
982 // (which is always possible by construction)
983 bool non_root_tree = tree->NumNodes() > 1;
984 double prob_grow_new;
985 if (non_root_tree) {
986 prob_grow_new = 0.5;
987 } else {
988 prob_grow_new = 1.0;
989 }
990
991 // Determine whether a "grow" move was possible from the old tree,
992 // in order to compute the probability of choosing "prune" from the old tree
993 bool non_constant_left = NodeNonConstant(dataset, tracker, tree_num, left_node);
994 bool non_constant_right = NodeNonConstant(dataset, tracker, tree_num, right_node);
995 double prob_prune_old;
996 if (non_constant_left && non_constant_right) {
997 prob_prune_old = 0.5;
998 } else {
999 prob_prune_old = 1.0;
1000 }
1001
1002 // Determine the number of leaves in the current tree and leaf parents in the proposed tree
1003 double p_leaf = 1/static_cast<double>(num_leaves-1);
1004 double p_leaf_parent = 1/static_cast<double>(num_leaf_parents);
1005
1006 // Compute the final MH ratio
1007 double log_mh_ratio = (
1008 std::log(1-pg) - std::log(pg) - std::log(1-pgl) - std::log(1-pgr) + std::log(prob_prune_old) +
1009 std::log(p_leaf) - std::log(prob_grow_new) - std::log(p_leaf_parent) + no_split_log_marginal_likelihood - split_log_marginal_likelihood
1010 );
1011 // Threshold at 0
1012 if (log_mh_ratio > 0) {
1013 log_mh_ratio = 0;
1014 }
1015
1016 // Draw a uniform random variable and accept/reject the proposal on this basis
1017 bool accept;
1018 std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
1019 double log_acceptance_prob = std::log(mh_accept(gen));
1020 if (log_acceptance_prob <= log_mh_ratio) {
1021 accept = true;
1022 RemoveSplitFromModel(tracker, dataset, tree_prior, gen, tree, tree_num, leaf_parent_chosen, left_node, right_node, false);
1023 } else {
1024 accept = false;
1025 }
1026}
1027
1028template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
1029static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
1030 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
1031 int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
1032 // Determine whether it is possible to grow any of the leaves
1033 bool grow_possible = false;
1034 std::vector<int> leaves = tree->GetLeaves();
1035 for (auto& leaf: leaves) {
1036 if (tracker.UnsortedNodeSize(tree_num, leaf) > 2 * tree_prior.GetMinSamplesLeaf()) {
1037 grow_possible = true;
1038 break;
1039 }
1040 }
1041
1042 // Determine whether it is possible to prune the tree
1043 bool prune_possible = false;
1044 if (tree->NumValidNodes() > 1) {
1045 prune_possible = true;
1046 }
1047
1048 // Determine the relative probability of grow vs prune (0 = grow, 1 = prune)
1049 double prob_grow;
1050 std::vector<double> step_probs(2);
1051 if (grow_possible && prune_possible) {
1052 step_probs = {0.5, 0.5};
1053 prob_grow = 0.5;
1054 } else if (!grow_possible && prune_possible) {
1055 step_probs = {0.0, 1.0};
1056 prob_grow = 0.0;
1057 } else if (grow_possible && !prune_possible) {
1058 step_probs = {1.0, 0.0};
1059 prob_grow = 1.0;
1060 } else {
1061 Log::Fatal("In this tree, neither grow nor prune is possible");
1062 }
1063 std::discrete_distribution<> step_dist(step_probs.begin(), step_probs.end());
1064
1065 // Draw a split rule at random
1066 data_size_t step_chosen = step_dist(gen);
1067 bool accept;
1068
1069 if (step_chosen == 0) {
1070 MCMCGrowTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1071 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, leaf_suff_stat_args...
1072 );
1073 } else {
1074 MCMCPruneTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1075 tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, leaf_suff_stat_args...
1076 );
1077 }
1078}
1079
1107template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
1108static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
1109 ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights,
1110 std::vector<int>& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
1111 // Run the MCMC algorithm for each tree
1112 int num_trees = forests.NumTrees();
1113 for (const int& i : sweep_update_indices) {
1114 // Adjust any model state needed to run a tree sampler
1115 // For models that involve Bayesian backfitting, this amounts to adding tree i's
1116 // predictions back to the residual (thus, training a model on the "partial residual")
1117 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1118 Tree* tree = active_forest.GetTree(i);
1119 AdjustStateBeforeTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1120
1121 // Sample tree i
1122 tree = active_forest.GetTree(i);
1123 MCMCSampleTreeOneIter<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
1124 tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i,
1125 global_variance, leaf_suff_stat_args...
1126 );
1127
1128 // Sample leaf parameters for tree i
1129 tree = active_forest.GetTree(i);
1130 leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
1131
1132 // Adjust any model state needed to run a tree sampler
1133 // For models that involve Bayesian backfitting, this amounts to subtracting tree i's
1134 // predictions back out of the residual (thus, using an updated "partial residual" in the following interation).
1135 // For more general "blocked MCMC" models, this might require changes to a ForestTracker or Dataset object
1136 AdjustStateAfterTreeSampling<LeafModel>(tracker, leaf_model, dataset, residual, tree_prior, backfitting, tree, i);
1137 }
1138
1139 if (keep_forest) {
1140 forests.AddSample(active_forest);
1141 }
1142}
1143
// end of sampling_group
1145
1146} // namespace StochTree
1147
1148#endif // STOCHTREE_TREE_SAMPLER_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:194
Container of TreeEnsemble forest objects. This is the primary (in-memory) storage interface for multi...
Definition container.h:28
void AddSample(TreeEnsemble &forest)
Add a new forest to the container by copying forest.
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:272
Eigen::MatrixXd & GetCovariates()
Return a reference to the raw Eigen::MatrixXd storing the covariate data.
Definition data.h:384
double CovariateValue(data_size_t row, int col)
Returns a dataset's covariate value stored at (row, col)
Definition data.h:365
"Superclass" wrapper around tracking data structures for forest sampling algorithms
Definition partition_tracker.h:50
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Tree * GetTree(int i)
Return a pointer to a tree in the forest.
Definition ensemble.h:140
void ResetInitTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:169
Definition prior.h:43
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:961
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:993
Decision tree data structure.
Definition tree.h:69
static void VarSplitRange(ForestTracker &tracker, ForestDataset &dataset, int tree_num, int leaf_split, int feature_split, double &var_min, double &var_max)
Compute the range of available split values for a continuous variable, given the current structure o...
Definition tree_sampler.h:52
static void MCMCSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, std::vector< int > &sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Runs one iteration of the MCMC sampler for a tree ensemble model, which consists of two steps for eve...
Definition tree_sampler.h:1108
static bool NodesNonConstantAfterSplit(ForestDataset &dataset, ForestTracker &tracker, TreeSplit &split, int tree_num, int leaf_split, int feature_split)
Determines whether a proposed split creates two leaf nodes with constant values for every feature (th...
Definition tree_sampler.h:82
static void GFRSampleOneIter(TreeEnsemble &active_forest, ForestTracker &tracker, ForestContainer &forests, LeafModel &leaf_model, ForestDataset &dataset, ColumnVector &residual, TreePrior &tree_prior, std::mt19937 &gen, std::vector< double > &variable_weights, std::vector< int > &sweep_update_indices, double global_variance, std::vector< FeatureType > &feature_types, int cutpoint_grid_size, bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs &... leaf_suff_stat_args)
Definition tree_sampler.h:799
Definition category_tracker.h:40