// StochTree 0.1.1 -- tree.h
1
#ifndef STOCHTREE_TREE_H_
#define STOCHTREE_TREE_H_

#include <nlohmann/json.hpp>
#include <stochtree/data.h>
#include <stochtree/log.h>
#include <stochtree/meta.h>
#include <Eigen/Dense>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <map>
#include <optional>
#include <set>
#include <stack>
#include <string>
#include <vector>

using json = nlohmann::json;
23
24namespace StochTree {
25
/*!
 * \brief Tree node type.
 *
 * kLeafNode marks a leaf; kNumericalSplitNode and kCategoricalSplitNode mark
 * internal nodes whose split rule is numeric or categorical, respectively.
 */
// NOTE(review): the enum's opening line was missing from the extracted listing and
// has been reconstructed as an unscoped enum; confirm whether a fixed underlying
// type was declared upstream.
enum TreeNodeType {
  kLeafNode = 0,
  kNumericalSplitNode = 1,
  kCategoricalSplitNode = 2
};
32
33// template<typename T>
34// int enum_to_int(T& input_enum) {
35// return static_cast<int>(input_enum);
36// }
37
38// template<typename T>
39// T json_to_enum(json& input_json) {
40// return static_cast<T>(input_json);
41// }
42
45
47TreeNodeType TreeNodeTypeFromString(std::string const& name);
48
/*!
 * \brief Feature split type: numeric, ordered categorical, or unordered categorical.
 */
enum FeatureSplitType {
  kNumericSplit = 0,
  kOrderedCategoricalSplit = 1,
  kUnorderedCategoricalSplit = 2
};

/*! \brief Forward declaration of the generic split rule defined later in this header. */
class TreeSplit;
57
69class Tree {
70 public:
71 static constexpr std::int32_t kInvalidNodeId{-1};
72 static constexpr std::int32_t kDeletedNodeMarker = std::numeric_limits<node_t>::max();
73 static constexpr std::int32_t kRoot{0};
74
75 Tree() = default;
76 // ~Tree() = default;
77 Tree(Tree const&) = delete;
78 Tree& operator=(Tree const&) = delete;
79 Tree(Tree&&) noexcept = default;
80 Tree& operator=(Tree&&) noexcept = default;
87 void CloneFromTree(Tree* tree);
88
89 std::int32_t num_nodes{0};
90 std::int32_t num_deleted_nodes{0};
91
93 void Reset();
95 void Init(int output_dimension = 1, bool is_log_scale = false);
97 int AllocNode();
99 void DeleteNode(std::int32_t nid);
101 void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value);
103 void ExpandNode(std::int32_t nid, int split_index, std::vector<std::uint32_t> const& categorical_indices, double left_value, double right_value);
105 void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
107 void ExpandNode(std::int32_t nid, int split_index, std::vector<std::uint32_t> const& categorical_indices, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
109 void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, double left_value, double right_value);
111 void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
112
114 inline bool IsRoot() {return leaves_.size() == 1;}
115
117 json to_json();
123 void from_json(const json& tree_json);
124
125 void ChangeToLeaf(std::int32_t nid, double value) {
126 CHECK(this->IsLeaf(this->LeftChild(nid)));
127 CHECK(this->IsLeaf(this->RightChild(nid)));
128 this->DeleteNode(this->LeftChild(nid));
129 this->DeleteNode(this->RightChild(nid));
130 this->SetLeaf(nid, value);
131
132 // Add nid to leaves and remove from internal nodes and leaf parents (if it was there)
133 leaves_.push_back(nid);
134 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
135 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
136
137 // Check if the other child of nid's parent node is also a leaf, if so, add parent back to leaf parents
138 // TODO refactor and add this to the multivariate case as well
139 if (!IsRoot(nid)) {
140 int parent_id = Parent(nid);
141 if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){
142 leaf_parents_.push_back(parent_id);
143 }
144 }
145 }
146
152 void CollapseToLeaf(std::int32_t nid, double value) {
153 CHECK_EQ(output_dimension_, 1);
154 if (this->IsLeaf(nid)) return;
155 if (!this->IsLeaf(this->LeftChild(nid))) {
156 CollapseToLeaf(this->LeftChild(nid), value);
157 }
158 if (!this->IsLeaf(this->RightChild(nid))) {
159 CollapseToLeaf(this->RightChild(nid), value);
160 }
161 this->ChangeToLeaf(nid, value);
162 }
163
164 void ChangeToLeaf(std::int32_t nid, std::vector<double> value_vector) {
165 CHECK(this->IsLeaf(this->LeftChild(nid)));
166 CHECK(this->IsLeaf(this->RightChild(nid)));
167 this->DeleteNode(this->LeftChild(nid));
168 this->DeleteNode(this->RightChild(nid));
169 this->SetLeafVector(nid, value_vector);
170
171 // Add nid to leaves and remove from internal nodes and leaf parents (if it was there)
172 leaves_.push_back(nid);
173 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
174 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
175
176 // Check if the other child of nid's parent node is also a leaf, if so, add parent back to leaf parents
177 // TODO refactor and add this to the multivariate case as well
178 if (!IsRoot(nid)) {
179 int parent_id = Parent(nid);
180 if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){
181 leaf_parents_.push_back(parent_id);
182 }
183 }
184 }
185
191 void CollapseToLeaf(std::int32_t nid, std::vector<double> value_vector) {
192 CHECK_GT(output_dimension_, 1);
193 CHECK_EQ(output_dimension_, value_vector.size());
194 if (this->IsLeaf(nid)) return;
195 if (!this->IsLeaf(this->LeftChild(nid))) {
196 CollapseToLeaf(this->LeftChild(nid), value_vector);
197 }
198 if (!this->IsLeaf(this->RightChild(nid))) {
199 CollapseToLeaf(this->RightChild(nid), value_vector);
200 }
201 this->ChangeToLeaf(nid, value_vector);
202 }
203
209 void AddValueToLeaves(double constant_value) {
210 if (output_dimension_ == 1) {
211 for (int j = 0; j < leaf_value_.size(); j++) {
212 leaf_value_[j] += constant_value;
213 }
214 } else {
215 for (int j = 0; j < leaf_vector_.size(); j++) {
216 leaf_vector_[j] += constant_value;
217 }
218 }
219 }
220
226 void MultiplyLeavesByValue(double constant_multiple) {
227 if (output_dimension_ == 1) {
228 for (int j = 0; j < leaf_value_.size(); j++) {
229 leaf_value_[j] *= constant_multiple;
230 }
231 } else {
232 for (int j = 0; j < leaf_vector_.size(); j++) {
233 leaf_vector_[j] *= constant_multiple;
234 }
235 }
236 }
237
244 template <typename Func> void WalkTree(Func func) const {
245 std::stack<std::int32_t> nodes;
246 nodes.push(kRoot);
247 auto &self = *this;
248 while (!nodes.empty()) {
249 auto nidx = nodes.top();
250 nodes.pop();
251 if (!func(nidx)) {
252 return;
253 }
254 auto left = self.LeftChild(nidx);
255 auto right = self.RightChild(nidx);
256 if (left != Tree::kInvalidNodeId) {
257 nodes.push(left);
258 }
259 if (right != Tree::kInvalidNodeId) {
260 nodes.push(right);
261 }
262 }
263 }
264
265 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices);
266 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices, Eigen::MatrixXd& basis);
267 double PredictFromNode(std::int32_t node_id);
268 double PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis, int row_idx);
269
274 bool HasVectorOutput() const {
275 return output_dimension_ > 1;
276 }
277
281 std::int32_t OutputDimension() const {
282 return output_dimension_;
283 }
284
288 bool IsLogScale() const {
289 return is_log_scale_;
290 }
291
296 std::int32_t Parent(std::int32_t nid) const {
297 return parent_[nid];
298 }
299
304 std::int32_t LeftChild(std::int32_t nid) const {
305 return cleft_[nid];
306 }
307
312 std::int32_t RightChild(std::int32_t nid) const {
313 return cright_[nid];
314 }
315
320 std::int32_t DefaultChild(std::int32_t nid) const {
321 return cleft_[nid];
322 }
323
328 std::int32_t SplitIndex(std::int32_t nid) const {
329 return split_index_[nid];
330 }
331
336 bool IsLeaf(std::int32_t nid) const {
337 return cleft_[nid] == kInvalidNodeId;
338 }
339
344 bool IsRoot(std::int32_t nid) const {
345 return parent_[nid] == kInvalidNodeId;
346 }
347
352 bool IsDeleted(std::int32_t nid) const {
353 return node_deleted_[nid];
354 }
355
360 double LeafValue(std::int32_t nid) const {
361 return leaf_value_[nid];
362 }
363
369 double LeafValue(std::int32_t nid, std::int32_t dim_id) const {
370 CHECK_LT(dim_id, output_dimension_);
371 if (output_dimension_ == 1 && dim_id == 0) {
372 return leaf_value_[nid];
373 } else {
374 std::size_t const offset_begin = leaf_vector_begin_[nid];
375 std::size_t const offset_end = leaf_vector_end_[nid];
376 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
377 Log::Fatal("No leaf vector set for node nid");
378 }
379 return leaf_vector_[offset_begin + dim_id];
380 }
381 }
382
386 std::int32_t MaxLeafDepth() const {
387 std::int32_t max_depth = 0;
388 std::stack<std::int32_t> nodes;
389 std::stack<std::int32_t> node_depths;
390 nodes.push(kRoot);
391 node_depths.push(0);
392 auto &self = *this;
393 while (!nodes.empty()) {
394 auto nidx = nodes.top();
395 nodes.pop();
396 auto node_depth = node_depths.top();
397 node_depths.pop();
398 bool valid_node = !self.IsDeleted(nidx);
399 if (valid_node) {
400 if (node_depth > max_depth) max_depth = node_depth;
401 auto left = self.LeftChild(nidx);
402 auto right = self.RightChild(nidx);
403 if (left != Tree::kInvalidNodeId) {
404 nodes.push(left);
405 node_depths.push(node_depth+1);
406 }
407 if (right != Tree::kInvalidNodeId) {
408 nodes.push(right);
409 node_depths.push(node_depth+1);
410 }
411 }
412 }
413 return max_depth;
414 }
415
420 std::vector<double> LeafVector(std::int32_t nid) const {
421 std::size_t const offset_begin = leaf_vector_begin_[nid];
422 std::size_t const offset_end = leaf_vector_end_[nid];
423 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
424 // Return empty vector, to indicate the lack of leaf vector
425 return std::vector<double>();
426 }
427 return std::vector<double>(&leaf_vector_[offset_begin], &leaf_vector_[offset_end]);
428 // Use unsafe access here, since we may need to take the address of one past the last
429 // element, to follow with the range semantic of std::vector<>.
430 }
431
436 double SumSquaredNodeValues(std::int32_t nid) const {
437 if (output_dimension_ == 1) {
438 return std::pow(leaf_value_[nid], 2.0);
439 } else {
440 double result = 0.;
441 std::size_t const offset_begin = leaf_vector_begin_[nid];
442 std::size_t const offset_end = leaf_vector_end_[nid];
443 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
444 Log::Fatal("No leaf vector set for node nid");
445 }
446 for (std::size_t i = offset_begin; i < offset_end; i++) {
447 result += std::pow(leaf_vector_[i], 2.0);
448 }
449 return result;
450 }
451 }
452
456 double SumSquaredLeafValues() const {
457 double result = 0.;
458 for (auto& leaf : leaves_) {
459 result += SumSquaredNodeValues(leaf);
460 }
461 return result;
462 }
463
468 bool HasLeafVector(std::int32_t nid) const {
469 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
470 }
471
476 double Threshold(std::int32_t nid) const {
477 return threshold_[nid];
478 }
479
487 std::vector<std::uint32_t> CategoryList(std::int32_t nid) const {
488 std::size_t const offset_begin = category_list_begin_[nid];
489 std::size_t const offset_end = category_list_end_[nid];
490 if (offset_begin >= category_list_.size() || offset_end > category_list_.size()) {
491 // Return empty vector, to indicate the lack of any category list
492 // The node might be a numerical split
493 return {};
494 }
495 return std::vector<std::uint32_t>(&category_list_[offset_begin], &category_list_[offset_end]);
496 // Use unsafe access here, since we may need to take the address of one past the last
497 // element, to follow with the range semantic of std::vector<>.
498 }
499
504 TreeNodeType NodeType(std::int32_t nid) const {
505 return node_type_[nid];
506 }
507
512 bool IsNumericSplitNode(std::int32_t nid) const {
513 return node_type_[nid] == TreeNodeType::kNumericalSplitNode;
514 }
515
520 bool IsCategoricalSplitNode(std::int32_t nid) const {
521 return node_type_[nid] == TreeNodeType::kCategoricalSplitNode;
522 }
523
527 bool HasCategoricalSplit() const {
528 return has_categorical_split_;
529 }
530
531 /* \brief Count number of leaves in tree. */
532 [[nodiscard]] std::int32_t NumLeaves() const;
533 [[nodiscard]] std::int32_t NumLeafParents() const;
534 [[nodiscard]] std::int32_t NumSplitNodes() const;
535
536 /* \brief Determine whether nid is leaf parent */
537 [[nodiscard]] bool IsLeafParent(std::int32_t nid) const {
538 // False until we deduce left and right node are
539 // available and both are leaves
540 bool is_left_leaf = false;
541 bool is_right_leaf = false;
542 // Check if node nidx is a leaf, if so, return false
543 bool is_leaf = this->IsLeaf(nid);
544 if (is_leaf){
545 return false;
546 } else {
547 // If nidx is not a leaf, it must have left and right nodes
548 // so we check if those are leaves
549 std::int32_t left_node = LeftChild(nid);
550 std::int32_t right_node = RightChild(nid);
551 is_left_leaf = IsLeaf(left_node);
552 is_right_leaf = IsLeaf(right_node);
553 }
554 return is_left_leaf && is_right_leaf;
555 }
556
560 [[nodiscard]] std::vector<std::int32_t> const& GetInternalNodes() const {
561 return internal_nodes_;
562 }
563
567 [[nodiscard]] std::vector<std::int32_t> const& GetLeaves() const {
568 return leaves_;
569 }
570
574 [[nodiscard]] std::vector<std::int32_t> const& GetLeafParents() const {
575 return leaf_parents_;
576 }
577
581 [[nodiscard]] std::vector<std::int32_t> GetNodes() {
582 std::vector<std::int32_t> output;
583 auto const& self = *this;
584 this->WalkTree([&output, &self](std::int32_t nidx) {
585 if (!self.IsDeleted(nidx)) {
586 output.push_back(nidx);
587 }
588 return true;
589 });
590 return output;
591 }
592
597 [[nodiscard]] std::int32_t GetDepth(std::int32_t nid) const {
598 int depth = 0;
599 while (!IsRoot(nid)) {
600 ++depth;
601 nid = Parent(nid);
602 }
603 return depth;
604 }
605
609 [[nodiscard]] std::int32_t NumNodes() const noexcept { return num_nodes; }
610
614 [[nodiscard]] std::int32_t NumDeletedNodes() const noexcept { return num_deleted_nodes; }
615
619 [[nodiscard]] std::int32_t NumValidNodes() const noexcept {
620 return num_nodes - num_deleted_nodes;
621 }
622
629 void SetLeftChild(std::int32_t nid, std::int32_t left_child) {
630 cleft_[nid] = left_child;
631 }
632
638 void SetRightChild(std::int32_t nid, std::int32_t right_child) {
639 cright_[nid] = right_child;
640 }
641
648 void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
649 SetLeftChild(nid, left_child);
650 SetRightChild(nid, right_child);
651 }
652
658 void SetParent(std::int32_t child_node, std::int32_t parent_node) {
659 parent_[child_node] = parent_node;
660 }
661
668 void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
669 SetParent(left_child, nid);
670 SetParent(right_child, nid);
671 }
672
680 std::int32_t nid, std::int32_t split_index, double threshold);
681
690 void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index,
691 std::vector<std::uint32_t> const& category_list);
692
698 void SetLeaf(std::int32_t nid, double value);
699
705 void SetLeafVector(std::int32_t nid, std::vector<double> const& leaf_vector);
706
727 void PredictLeafIndexInplace(ForestDataset* dataset, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
728
749 void PredictLeafIndexInplace(Eigen::MatrixXd& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
750
771 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
772
773 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
774 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
775 int column_ind, int32_t offset, int32_t max_leaf);
776
777 // Node info
778 std::vector<TreeNodeType> node_type_;
779 std::vector<std::int32_t> parent_;
780 std::vector<std::int32_t> cleft_;
781 std::vector<std::int32_t> cright_;
782 std::vector<std::int32_t> split_index_;
783 std::vector<double> leaf_value_;
784 std::vector<double> threshold_;
785 std::vector<bool> node_deleted_;
786 std::vector<std::int32_t> internal_nodes_;
787 std::vector<std::int32_t> leaves_;
788 std::vector<std::int32_t> leaf_parents_;
789 std::vector<std::int32_t> deleted_nodes_;
790
791 // Leaf vector
792 std::vector<double> leaf_vector_;
793 std::vector<std::uint64_t> leaf_vector_begin_;
794 std::vector<std::uint64_t> leaf_vector_end_;
795
796 // Category list
797 std::vector<std::uint32_t> category_list_;
798 std::vector<std::uint64_t> category_list_begin_;
799 std::vector<std::uint64_t> category_list_end_;
800
801 bool has_categorical_split_{false};
802 int output_dimension_{1};
803 bool is_log_scale_{false};
804};
805
807inline bool operator==(const Tree& lhs, const Tree& rhs) {
808 return (
809 (lhs.has_categorical_split_ == rhs.has_categorical_split_) &&
810 (lhs.output_dimension_ == rhs.output_dimension_) &&
811 (lhs.is_log_scale_ == rhs.is_log_scale_) &&
812 (lhs.node_type_ == rhs.node_type_) &&
813 (lhs.parent_ == rhs.parent_) &&
814 (lhs.cleft_ == rhs.cleft_) &&
815 (lhs.cright_ == rhs.cright_) &&
816 (lhs.split_index_ == rhs.split_index_) &&
817 (lhs.leaf_value_ == rhs.leaf_value_) &&
818 (lhs.threshold_ == rhs.threshold_) &&
819 (lhs.internal_nodes_ == rhs.internal_nodes_) &&
820 (lhs.leaves_ == rhs.leaves_) &&
821 (lhs.leaf_parents_ == rhs.leaf_parents_) &&
822 (lhs.deleted_nodes_ == rhs.deleted_nodes_) &&
823 (lhs.leaf_vector_ == rhs.leaf_vector_) &&
824 (lhs.leaf_vector_begin_ == rhs.leaf_vector_begin_) &&
825 (lhs.leaf_vector_end_ == rhs.leaf_vector_end_) &&
826 (lhs.category_list_ == rhs.category_list_) &&
827 (lhs.category_list_begin_ == rhs.category_list_begin_) &&
828 (lhs.category_list_end_ == rhs.category_list_end_)
829 );
830}
831
/*!
 * \brief Determine whether an observation produces a "true" value in a numeric split node.
 * \param fvalue Feature value of the observation.
 * \param threshold Split threshold.
 * \return true iff `fvalue <= threshold` (hence false whenever either argument is NaN).
 */
inline bool SplitTrueNumeric(double fvalue, double threshold) {
  bool const at_or_below_threshold = fvalue <= threshold;
  return at_or_below_threshold;
}
841
/*!
 * \brief Determine whether an observation produces a "true" value in a categorical split node.
 *
 * \param fvalue Feature value of the observation; it is truncated toward zero to obtain a
 *        category id.
 * \param category_list Categories for which the split evaluates to true.
 * \return true iff `fvalue` lies in [0, largest valid category id] and its truncated value
 *         appears in `category_list`. NaN, negative and infinite inputs always return false.
 */
inline bool SplitTrueCategorical(double fvalue, std::vector<std::uint32_t> const& category_list) {
  // A valid (integer) category must satisfy two criteria:
  // 1) it must be exactly representable as double
  // 2) it must fit into uint32_t
  constexpr double max_representable_int =
      std::min(static_cast<double>(std::numeric_limits<std::uint32_t>::max()),
               static_cast<double>(std::uint64_t(1) << std::numeric_limits<double>::digits));
  // Written as a positive range test so that NaN (for which every comparison is false) is
  // rejected here: converting NaN to an integer type would be undefined behavior.
  if (!(fvalue >= 0.0 && fvalue <= max_representable_int)) {
    return false;
  }
  auto const category_value = static_cast<std::uint32_t>(fvalue);
  return std::find(category_list.begin(), category_list.end(), category_value) != category_list.end();
}
865
872inline int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child) {
873 return (SplitTrueNumeric(fvalue, threshold) ? left_child : right_child);
874}
875
882inline int NextNodeCategorical(double fvalue, std::vector<std::uint32_t> const& category_list, int left_child, int right_child) {
883 return SplitTrueCategorical(fvalue, category_list) ? left_child : right_child;
884}
885
893inline int EvaluateTree(Tree const& tree, Eigen::MatrixXd& data, int row) {
894 int node_id = 0;
895 while (!tree.IsLeaf(node_id)) {
896 auto const split_index = tree.SplitIndex(node_id);
897 double const fvalue = data(row, split_index);
898 if (std::isnan(fvalue)) {
899 node_id = tree.DefaultChild(node_id);
900 } else {
901 if (tree.NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
902 node_id = NextNodeCategorical(fvalue, tree.CategoryList(node_id),
903 tree.LeftChild(node_id), tree.RightChild(node_id));
904 } else {
905 node_id = NextNodeNumeric(fvalue, tree.Threshold(node_id), tree.LeftChild(node_id), tree.RightChild(node_id));
906 }
907 }
908 }
909 return node_id;
910}
911
919inline int EvaluateTree(Tree const& tree, Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& data, int row) {
920 int node_id = 0;
921 while (!tree.IsLeaf(node_id)) {
922 auto const split_index = tree.SplitIndex(node_id);
923 double const fvalue = data(row, split_index);
924 if (std::isnan(fvalue)) {
925 node_id = tree.DefaultChild(node_id);
926 } else {
927 if (tree.NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
928 node_id = NextNodeCategorical(fvalue, tree.CategoryList(node_id),
929 tree.LeftChild(node_id), tree.RightChild(node_id));
930 } else {
931 node_id = NextNodeNumeric(fvalue, tree.Threshold(node_id), tree.LeftChild(node_id), tree.RightChild(node_id));
932 }
933 }
934 }
935 return node_id;
936}
937
944inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, double split_value) {
945 double const fvalue = covariates(row, split_index);
946 return SplitTrueNumeric(fvalue, split_value);
947}
948
955inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, std::vector<std::uint32_t> const& category_list) {
956 double const fvalue = covariates(row, split_index);
957 return SplitTrueCategorical(fvalue, category_list);
958}
959
962 public:
963 TreeSplit() {}
969 TreeSplit(double split_value) {
970 numeric_ = true;
971 split_value_ = split_value;
972 split_set_ = true;
973 }
979 TreeSplit(std::vector<std::uint32_t>& split_categories) {
980 numeric_ = false;
981 split_categories_ = split_categories;
982 split_set_ = true;
983 }
984 ~TreeSplit() {}
985 bool SplitSet() {return split_set_;}
987 bool NumericSplit() {return numeric_;}
993 bool SplitTrue(double fvalue) {
994 if (numeric_) return SplitTrueNumeric(fvalue, split_value_);
995 else return SplitTrueCategorical(fvalue, split_categories_);
996 }
998 double SplitValue() {return split_value_;}
1000 std::vector<std::uint32_t> SplitCategories() {return split_categories_;}
1001 private:
1002 bool split_set_{false};
1003 bool numeric_;
1004 double split_value_;
1005 std::vector<std::uint32_t> split_categories_;
1006};
1007
// end of tree_group
1009
1010} // namespace StochTree
1011
1012#endif // STOCHTREE_TREE_H_
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:272
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:961
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:993
bool NumericSplit()
Whether or not a TreeSplit rule is numeric.
Definition tree.h:987
std::vector< std::uint32_t > SplitCategories()
Categories defining a TreeSplit object.
Definition tree.h:1000
TreeSplit(std::vector< std::uint32_t > &split_categories)
Construct a categorical TreeSplit.
Definition tree.h:979
double SplitValue()
Numeric cutoff value defining a TreeSplit object.
Definition tree.h:998
TreeSplit(double split_value)
Construct a numeric TreeSplit.
Definition tree.h:969
Decision tree data structure.
Definition tree.h:69
void SetNumericSplit(std::int32_t nid, std::int32_t split_index, double threshold)
Create a numerical split.
void SetLeftChild(std::int32_t nid, std::int32_t left_child)
Identify left child node.
Definition tree.h:629
std::int32_t LeftChild(std::int32_t nid) const
Index of the node's left child.
Definition tree.h:304
void PredictLeafIndexInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a (mapped) covariate matrix. Internally, trees are stored as ...
void CloneFromTree(Tree *tree)
Copy the structure and parameters of another tree. If the Tree object calling this method already has...
std::int32_t RightChild(std::int32_t nid) const
Index of the node's right child.
Definition tree.h:312
std::int32_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition tree.h:619
void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index, std::vector< std::uint32_t > const &category_list)
Create a categorical split.
std::int32_t SplitIndex(std::int32_t nid) const
Feature index defining the node's split rule.
Definition tree.h:328
void PredictLeafIndexInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a covariate matrix. Internally, trees are stored as ...
bool IsLeaf(std::int32_t nid) const
Whether the node is a leaf node.
Definition tree.h:336
bool IsNumericSplitNode(std::int32_t nid) const
Whether the node is a numeric split node.
Definition tree.h:512
void WalkTree(Func func) const
Iterate through all nodes in this tree.
Definition tree.h:244
double SumSquaredLeafValues() const
Sum of squared values for all leaves in a tree.
Definition tree.h:456
std::int32_t DefaultChild(std::int32_t nid) const
Index of the node's "default" child (potentially used in the case of a missing feature at prediction ...
Definition tree.h:320
void AddValueToLeaves(double constant_value)
Add a constant value to every leaf of a tree. If leaves are multi-dimensional, constant_value will be...
Definition tree.h:209
int AllocNode()
Allocate a new node and return the node's ID.
void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a numeric split rule.
std::vector< double > LeafVector(std::int32_t nid) const
Get vector-valued parameters of a node (typically leaf)
Definition tree.h:420
std::int32_t Parent(std::int32_t nid) const
Index of the node's parent.
Definition tree.h:296
void PredictLeafIndexInplace(ForestDataset *dataset, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsCategoricalSplitNode(std::int32_t nid) const
Whether the node is a categorical split node.
Definition tree.h:520
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition tree.h:527
double LeafValue(std::int32_t nid, std::int32_t dim_id) const
Get parameter value of a node (typically though not necessarily a leaf node) at a given output dimens...
Definition tree.h:369
void from_json(const json &tree_json)
Load from JSON.
std::int32_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition tree.h:609
TreeNodeType NodeType(std::int32_t nid) const
Get the type of a node (i.e. numeric split, categorical split, leaf)
Definition tree.h:504
double SumSquaredNodeValues(std::int32_t nid) const
Sum of squared parameter values for a given node (typically though not necessarily a leaf node)
Definition tree.h:436
void DeleteNode(std::int32_t nid)
Deletes node indexed by node ID.
std::int32_t MaxLeafDepth() const
Get maximum depth of all of the leaf nodes.
Definition tree.h:386
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, double left_value, double right_value)
Expand a node based on a generic split rule.
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a generic split rule.
std::int32_t OutputDimension() const
Dimension of tree output.
Definition tree.h:281
bool HasVectorOutput() const
Whether or not a tree has vector output.
Definition tree.h:274
void MultiplyLeavesByValue(double constant_multiple)
Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional,...
Definition tree.h:226
std::int32_t NumDeletedNodes() const noexcept
Get the total number of deleted nodes in this tree.
Definition tree.h:614
void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify the two child nodes of the node. The children's parent links are not updated here; use SetParents for that.
Definition tree.h:648
std::vector< std::uint32_t > CategoryList(std::int32_t nid) const
Get list of all categories belonging to the left child node. Categories are integers ranging from 0 t...
Definition tree.h:487
void SetParent(std::int32_t child_node, std::int32_t parent_node)
Identify parent node.
Definition tree.h:658
void SetLeaf(std::int32_t nid, double value)
Set the leaf value of the node.
bool IsRoot()
Whether or not a tree is a "stump" consisting of a single root node.
Definition tree.h:114
double LeafValue(std::int32_t nid) const
Get parameter value of a node (typically though not necessarily a leaf node)
Definition tree.h:360
void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify parent node of the left and right node ids.
Definition tree.h:668
std::vector< std::int32_t > const & GetInternalNodes() const
Get indices of all internal nodes.
Definition tree.h:560
std::vector< std::int32_t > const & GetLeafParents() const
Get indices of all leaf parent nodes.
Definition tree.h:574
std::vector< std::int32_t > GetNodes()
Get indices of all valid (non-deleted) nodes.
Definition tree.h:581
bool HasLeafVector(std::int32_t nid) const
Tests whether the leaf node has a non-empty leaf vector.
Definition tree.h:468
void Init(int output_dimension=1, bool is_log_scale=false)
Initialize the tree with a single root node.
void CollapseToLeaf(std::int32_t nid, std::vector< double > value_vector)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:191
double Threshold(std::int32_t nid) const
Get split threshold of the node.
Definition tree.h:476
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, double left_value, double right_value)
Expand a node based on a categorical split rule.
bool IsRoot(std::int32_t nid) const
Whether the node is root.
Definition tree.h:344
void Reset()
Reset tree to empty vectors and default values of boolean / integer variables.
bool IsDeleted(std::int32_t nid) const
Whether the node has been deleted.
Definition tree.h:352
void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value)
Expand a node based on a numeric split rule.
void SetRightChild(std::int32_t nid, std::int32_t right_child)
Identify right child node.
Definition tree.h:638
std::vector< std::int32_t > const & GetLeaves() const
Get indices of all leaf nodes.
Definition tree.h:567
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a categorical split rule.
bool IsLogScale() const
Whether or not tree parameters should be exponentiated at prediction time.
Definition tree.h:288
void CollapseToLeaf(std::int32_t nid, double value)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:152
void SetLeafVector(std::int32_t nid, std::vector< double > const &leaf_vector)
Set the leaf vector of the node; useful for multi-output trees.
std::int32_t GetDepth(std::int32_t nid) const
Get the depth of a node.
Definition tree.h:597
json to_json()
Convert tree to JSON and return JSON in-memory.
bool SplitTrueNumeric(double fvalue, double threshold)
Determine whether an observation produces a "true" value in a numeric split node.
Definition tree.h:838
int NextNodeCategorical(double fvalue, std::vector< std::uint32_t > const &category_list, int left_child, int right_child)
Return left or right node id based on a categorical split.
Definition tree.h:882
bool SplitTrueCategorical(double fvalue, std::vector< std::uint32_t > const &category_list)
Determine whether an observation produces a "true" value in a categorical split node.
Definition tree.h:848
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:893
int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child)
Return left or right node id based on a numeric split.
Definition tree.h:872
bool operator==(const Tree &lhs, const Tree &rhs)
Comparison operator for trees.
Definition tree.h:807
bool RowSplitLeft(Eigen::MatrixXd &covariates, int row, int split_index, double split_value)
Determine whether a given observation is "true" at a split proposed by split_index and split_value.
Definition tree.h:944
Definition category_tracker.h:40
std::string TreeNodeTypeToString(TreeNodeType type)
Get string representation of TreeNodeType.
TreeNodeType
Tree node type.
Definition tree.h:27
TreeNodeType TreeNodeTypeFromString(std::string const &name)
Get NodeType from string.