StochTree 0.0.1
Loading...
Searching...
No Matches
tree.h
1
6#ifndef STOCHTREE_TREE_H_
7#define STOCHTREE_TREE_H_
8
9#include <nlohmann/json.hpp>
10#include <stochtree/data.h>
11#include <stochtree/log.h>
12#include <stochtree/meta.h>
13#include <Eigen/Dense>
14
15#include <cstdint>
16#include <map>
17#include <optional>
18#include <set>
19#include <stack>
20#include <string>
21
22using json = nlohmann::json;
23
24namespace StochTree {
25
28 kLeafNode = 0,
29 kNumericalSplitNode = 1,
30 kCategoricalSplitNode = 2
31};
32
33// template<typename T>
34// int enum_to_int(T& input_enum) {
35// return static_cast<int>(input_enum);
36// }
37
38// template<typename T>
39// T json_to_enum(json& input_json) {
40// return static_cast<T>(input_json);
41// }
42
45
47TreeNodeType TreeNodeTypeFromString(std::string const& name);
48
49enum FeatureSplitType {
50 kNumericSplit,
51 kOrderedCategoricalSplit,
52 kUnorderedCategoricalSplit
53};
54
56class TreeSplit;
57
69class Tree {
70 public:
71 static constexpr std::int32_t kInvalidNodeId{-1};
72 static constexpr std::int32_t kDeletedNodeMarker = std::numeric_limits<node_t>::max();
73 static constexpr std::int32_t kRoot{0};
74
75 Tree() = default;
76 // ~Tree() = default;
77 Tree(Tree const&) = delete;
78 Tree& operator=(Tree const&) = delete;
79 Tree(Tree&&) noexcept = default;
80 Tree& operator=(Tree&&) noexcept = default;
87 void CloneFromTree(Tree* tree);
88
89 std::int32_t num_nodes{0};
90 std::int32_t num_deleted_nodes{0};
91
93 void Reset();
95 void Init(int output_dimension = 1, bool is_log_scale = false);
97 int AllocNode();
99 void DeleteNode(std::int32_t nid);
101 void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value);
103 void ExpandNode(std::int32_t nid, int split_index, std::vector<std::uint32_t> const& categorical_indices, double left_value, double right_value);
105 void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
107 void ExpandNode(std::int32_t nid, int split_index, std::vector<std::uint32_t> const& categorical_indices, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
109 void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, double left_value, double right_value);
111 void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
112
114 inline bool IsRoot() {return leaves_.size() == 1;}
115
117 json to_json();
123 void from_json(const json& tree_json);
124
125 void ChangeToLeaf(std::int32_t nid, double value) {
126 CHECK(this->IsLeaf(this->LeftChild(nid)));
127 CHECK(this->IsLeaf(this->RightChild(nid)));
128 this->DeleteNode(this->LeftChild(nid));
129 this->DeleteNode(this->RightChild(nid));
130 this->SetLeaf(nid, value);
131
132 // Add nid to leaves and remove from internal nodes and leaf parents (if it was there)
133 leaves_.push_back(nid);
134 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
135 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
136
137 // Check if the other child of nid's parent node is also a leaf, if so, add parent back to leaf parents
138 // TODO refactor and add this to the multivariate case as well
139 if (!IsRoot(nid)) {
140 int parent_id = Parent(nid);
141 if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){
142 leaf_parents_.push_back(parent_id);
143 }
144 }
145 }
146
152 void CollapseToLeaf(std::int32_t nid, double value) {
153 CHECK_EQ(output_dimension_, 1);
154 if (this->IsLeaf(nid)) return;
155 if (!this->IsLeaf(this->LeftChild(nid))) {
156 CollapseToLeaf(this->LeftChild(nid), value);
157 }
158 if (!this->IsLeaf(this->RightChild(nid))) {
159 CollapseToLeaf(this->RightChild(nid), value);
160 }
161 this->ChangeToLeaf(nid, value);
162 }
163
164 void ChangeToLeaf(std::int32_t nid, std::vector<double> value_vector) {
165 CHECK(this->IsLeaf(this->LeftChild(nid)));
166 CHECK(this->IsLeaf(this->RightChild(nid)));
167 this->DeleteNode(this->LeftChild(nid));
168 this->DeleteNode(this->RightChild(nid));
169 this->SetLeafVector(nid, value_vector);
170
171 // Add nid to leaves and remove from internal nodes and leaf parents (if it was there)
172 leaves_.push_back(nid);
173 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
174 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
175
176 // Check if the other child of nid's parent node is also a leaf, if so, add parent back to leaf parents
177 // TODO refactor and add this to the multivariate case as well
178 if (!IsRoot(nid)) {
179 int parent_id = Parent(nid);
180 if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){
181 leaf_parents_.push_back(parent_id);
182 }
183 }
184 }
185
191 void CollapseToLeaf(std::int32_t nid, std::vector<double> value_vector) {
192 CHECK_GT(output_dimension_, 1);
193 CHECK_EQ(output_dimension_, value_vector.size());
194 if (this->IsLeaf(nid)) return;
195 if (!this->IsLeaf(this->LeftChild(nid))) {
196 CollapseToLeaf(this->LeftChild(nid), value_vector);
197 }
198 if (!this->IsLeaf(this->RightChild(nid))) {
199 CollapseToLeaf(this->RightChild(nid), value_vector);
200 }
201 this->ChangeToLeaf(nid, value_vector);
202 }
203
210 template <typename Func> void WalkTree(Func func) const {
211 std::stack<std::int32_t> nodes;
212 nodes.push(kRoot);
213 auto &self = *this;
214 while (!nodes.empty()) {
215 auto nidx = nodes.top();
216 nodes.pop();
217 if (!func(nidx)) {
218 return;
219 }
220 auto left = self.LeftChild(nidx);
221 auto right = self.RightChild(nidx);
222 if (left != Tree::kInvalidNodeId) {
223 nodes.push(left);
224 }
225 if (right != Tree::kInvalidNodeId) {
226 nodes.push(right);
227 }
228 }
229 }
230
231 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices);
232 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices, Eigen::MatrixXd& basis);
233 double PredictFromNode(std::int32_t node_id);
234 double PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis, int row_idx);
235
240 bool HasVectorOutput() const {
241 return output_dimension_ > 1;
242 }
243
247 std::int32_t OutputDimension() const {
248 return output_dimension_;
249 }
250
254 bool IsLogScale() const {
255 return is_log_scale_;
256 }
257
262 std::int32_t Parent(std::int32_t nid) const {
263 return parent_[nid];
264 }
265
270 std::int32_t LeftChild(std::int32_t nid) const {
271 return cleft_[nid];
272 }
273
278 std::int32_t RightChild(std::int32_t nid) const {
279 return cright_[nid];
280 }
281
286 std::int32_t DefaultChild(std::int32_t nid) const {
287 return cleft_[nid];
288 }
289
294 std::int32_t SplitIndex(std::int32_t nid) const {
295 return split_index_[nid];
296 }
297
302 bool IsLeaf(std::int32_t nid) const {
303 return cleft_[nid] == kInvalidNodeId;
304 }
305
310 bool IsRoot(std::int32_t nid) const {
311 return parent_[nid] == kInvalidNodeId;
312 }
313
318 bool IsDeleted(std::int32_t nid) const {
319 return node_deleted_[nid];
320 }
321
326 double LeafValue(std::int32_t nid) const {
327 return leaf_value_[nid];
328 }
329
335 double LeafValue(std::int32_t nid, std::int32_t dim_id) const {
336 CHECK_LT(dim_id, output_dimension_);
337 if (output_dimension_ == 1 && dim_id == 0) {
338 return leaf_value_[nid];
339 } else {
340 std::size_t const offset_begin = leaf_vector_begin_[nid];
341 std::size_t const offset_end = leaf_vector_end_[nid];
342 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
343 Log::Fatal("No leaf vector set for node nid");
344 }
345 return leaf_vector_[offset_begin + dim_id];
346 }
347 }
348
352 std::int32_t MaxLeafDepth() const {
353 std::int32_t max_depth = 0;
354 std::stack<std::int32_t> nodes;
355 std::stack<std::int32_t> node_depths;
356 nodes.push(kRoot);
357 node_depths.push(0);
358 auto &self = *this;
359 while (!nodes.empty()) {
360 auto nidx = nodes.top();
361 nodes.pop();
362 auto node_depth = node_depths.top();
363 node_depths.pop();
364 bool valid_node = !self.IsDeleted(nidx);
365 if (valid_node) {
366 if (node_depth > max_depth) max_depth = node_depth;
367 auto left = self.LeftChild(nidx);
368 auto right = self.RightChild(nidx);
369 if (left != Tree::kInvalidNodeId) {
370 nodes.push(left);
371 node_depths.push(node_depth+1);
372 }
373 if (right != Tree::kInvalidNodeId) {
374 nodes.push(right);
375 node_depths.push(node_depth+1);
376 }
377 }
378 }
379 return max_depth;
380 }
381
386 std::vector<double> LeafVector(std::int32_t nid) const {
387 std::size_t const offset_begin = leaf_vector_begin_[nid];
388 std::size_t const offset_end = leaf_vector_end_[nid];
389 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
390 // Return empty vector, to indicate the lack of leaf vector
391 return std::vector<double>();
392 }
393 return std::vector<double>(&leaf_vector_[offset_begin], &leaf_vector_[offset_end]);
394 // Use unsafe access here, since we may need to take the address of one past the last
395 // element, to follow with the range semantic of std::vector<>.
396 }
397
402 double SumSquaredNodeValues(std::int32_t nid) const {
403 if (output_dimension_ == 1) {
404 return std::pow(leaf_value_[nid], 2.0);
405 } else {
406 double result = 0.;
407 std::size_t const offset_begin = leaf_vector_begin_[nid];
408 std::size_t const offset_end = leaf_vector_end_[nid];
409 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
410 Log::Fatal("No leaf vector set for node nid");
411 }
412 for (std::size_t i = offset_begin; i < offset_end; i++) {
413 result += std::pow(leaf_vector_[i], 2.0);
414 }
415 return result;
416 }
417 }
418
422 double SumSquaredLeafValues() const {
423 double result = 0.;
424 for (auto& leaf : leaves_) {
425 result += SumSquaredNodeValues(leaf);
426 }
427 return result;
428 }
429
434 bool HasLeafVector(std::int32_t nid) const {
435 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
436 }
437
442 double Threshold(std::int32_t nid) const {
443 return threshold_[nid];
444 }
445
453 std::vector<std::uint32_t> CategoryList(std::int32_t nid) const {
454 std::size_t const offset_begin = category_list_begin_[nid];
455 std::size_t const offset_end = category_list_end_[nid];
456 if (offset_begin >= category_list_.size() || offset_end > category_list_.size()) {
457 // Return empty vector, to indicate the lack of any category list
458 // The node might be a numerical split
459 return {};
460 }
461 return std::vector<std::uint32_t>(&category_list_[offset_begin], &category_list_[offset_end]);
462 // Use unsafe access here, since we may need to take the address of one past the last
463 // element, to follow with the range semantic of std::vector<>.
464 }
465
470 TreeNodeType NodeType(std::int32_t nid) const {
471 return node_type_[nid];
472 }
473
478 bool IsNumericSplitNode(std::int32_t nid) const {
479 return node_type_[nid] == TreeNodeType::kNumericalSplitNode;
480 }
481
486 bool IsCategoricalSplitNode(std::int32_t nid) const {
487 return node_type_[nid] == TreeNodeType::kCategoricalSplitNode;
488 }
489
493 bool HasCategoricalSplit() const {
494 return has_categorical_split_;
495 }
496
497 /* \brief Count number of leaves in tree. */
498 [[nodiscard]] std::int32_t NumLeaves() const;
499 [[nodiscard]] std::int32_t NumLeafParents() const;
500 [[nodiscard]] std::int32_t NumSplitNodes() const;
501
502 /* \brief Determine whether nid is leaf parent */
503 [[nodiscard]] bool IsLeafParent(std::int32_t nid) const {
504 // False until we deduce left and right node are
505 // available and both are leaves
506 bool is_left_leaf = false;
507 bool is_right_leaf = false;
508 // Check if node nidx is a leaf, if so, return false
509 bool is_leaf = this->IsLeaf(nid);
510 if (is_leaf){
511 return false;
512 } else {
513 // If nidx is not a leaf, it must have left and right nodes
514 // so we check if those are leaves
515 std::int32_t left_node = LeftChild(nid);
516 std::int32_t right_node = RightChild(nid);
517 is_left_leaf = IsLeaf(left_node);
518 is_right_leaf = IsLeaf(right_node);
519 }
520 return is_left_leaf && is_right_leaf;
521 }
522
526 [[nodiscard]] std::vector<std::int32_t> const& GetInternalNodes() const {
527 return internal_nodes_;
528 }
529
533 [[nodiscard]] std::vector<std::int32_t> const& GetLeaves() const {
534 return leaves_;
535 }
536
540 [[nodiscard]] std::vector<std::int32_t> const& GetLeafParents() const {
541 return leaf_parents_;
542 }
543
547 [[nodiscard]] std::vector<std::int32_t> GetNodes() {
548 std::vector<std::int32_t> output;
549 auto const& self = *this;
550 this->WalkTree([&output, &self](std::int32_t nidx) {
551 if (!self.IsDeleted(nidx)) {
552 output.push_back(nidx);
553 }
554 return true;
555 });
556 return output;
557 }
558
563 [[nodiscard]] std::int32_t GetDepth(std::int32_t nid) const {
564 int depth = 0;
565 while (!IsRoot(nid)) {
566 ++depth;
567 nid = Parent(nid);
568 }
569 return depth;
570 }
571
575 [[nodiscard]] std::int32_t NumNodes() const noexcept { return num_nodes; }
576
580 [[nodiscard]] std::int32_t NumDeletedNodes() const noexcept { return num_deleted_nodes; }
581
585 [[nodiscard]] std::int32_t NumValidNodes() const noexcept {
586 return num_nodes - num_deleted_nodes;
587 }
588
595 void SetLeftChild(std::int32_t nid, std::int32_t left_child) {
596 cleft_[nid] = left_child;
597 }
598
604 void SetRightChild(std::int32_t nid, std::int32_t right_child) {
605 cright_[nid] = right_child;
606 }
607
614 void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
615 SetLeftChild(nid, left_child);
616 SetRightChild(nid, right_child);
617 }
618
624 void SetParent(std::int32_t child_node, std::int32_t parent_node) {
625 parent_[child_node] = parent_node;
626 }
627
634 void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
635 SetParent(left_child, nid);
636 SetParent(right_child, nid);
637 }
638
646 std::int32_t nid, std::int32_t split_index, double threshold);
647
656 void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index,
657 std::vector<std::uint32_t> const& category_list);
658
664 void SetLeaf(std::int32_t nid, double value);
665
671 void SetLeafVector(std::int32_t nid, std::vector<double> const& leaf_vector);
672
693 void PredictLeafIndexInplace(ForestDataset* dataset, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
694
715 void PredictLeafIndexInplace(Eigen::MatrixXd& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
716
737 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
738
739 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
740 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
741 int column_ind, int32_t offset, int32_t max_leaf);
742
743 // Node info
744 std::vector<TreeNodeType> node_type_;
745 std::vector<std::int32_t> parent_;
746 std::vector<std::int32_t> cleft_;
747 std::vector<std::int32_t> cright_;
748 std::vector<std::int32_t> split_index_;
749 std::vector<double> leaf_value_;
750 std::vector<double> threshold_;
751 std::vector<bool> node_deleted_;
752 std::vector<std::int32_t> internal_nodes_;
753 std::vector<std::int32_t> leaves_;
754 std::vector<std::int32_t> leaf_parents_;
755 std::vector<std::int32_t> deleted_nodes_;
756
757 // Leaf vector
758 std::vector<double> leaf_vector_;
759 std::vector<std::uint64_t> leaf_vector_begin_;
760 std::vector<std::uint64_t> leaf_vector_end_;
761
762 // Category list
763 std::vector<std::uint32_t> category_list_;
764 std::vector<std::uint64_t> category_list_begin_;
765 std::vector<std::uint64_t> category_list_end_;
766
767 bool has_categorical_split_{false};
768 int output_dimension_{1};
769 bool is_log_scale_{false};
770};
771
773inline bool operator==(const Tree& lhs, const Tree& rhs) {
774 return (
775 (lhs.has_categorical_split_ == rhs.has_categorical_split_) &&
776 (lhs.output_dimension_ == rhs.output_dimension_) &&
777 (lhs.is_log_scale_ == rhs.is_log_scale_) &&
778 (lhs.node_type_ == rhs.node_type_) &&
779 (lhs.parent_ == rhs.parent_) &&
780 (lhs.cleft_ == rhs.cleft_) &&
781 (lhs.cright_ == rhs.cright_) &&
782 (lhs.split_index_ == rhs.split_index_) &&
783 (lhs.leaf_value_ == rhs.leaf_value_) &&
784 (lhs.threshold_ == rhs.threshold_) &&
785 (lhs.internal_nodes_ == rhs.internal_nodes_) &&
786 (lhs.leaves_ == rhs.leaves_) &&
787 (lhs.leaf_parents_ == rhs.leaf_parents_) &&
788 (lhs.deleted_nodes_ == rhs.deleted_nodes_) &&
789 (lhs.leaf_vector_ == rhs.leaf_vector_) &&
790 (lhs.leaf_vector_begin_ == rhs.leaf_vector_begin_) &&
791 (lhs.leaf_vector_end_ == rhs.leaf_vector_end_) &&
792 (lhs.category_list_ == rhs.category_list_) &&
793 (lhs.category_list_begin_ == rhs.category_list_begin_) &&
794 (lhs.category_list_end_ == rhs.category_list_end_)
795 );
796}
797
/*!
 * \brief Evaluate the numeric split rule `fvalue <= threshold`.
 * \param fvalue Feature value of the observation
 * \param threshold Cutoff of the split
 * \return true when the feature value is at or below the cutoff
 */
inline bool SplitTrueNumeric(double fvalue, double threshold) {
  bool const at_or_below_cutoff = (fvalue <= threshold);
  return at_or_below_cutoff;
}
807
/*!
 * \brief Evaluate a categorical split rule: is the category encoded by `fvalue` in `category_list`?
 *
 * The feature value is converted to an unsigned 32-bit category (fractional part dropped)
 * and looked up in the list. Values that cannot encode a category never match: negative
 * values, values above the largest representable category, and NaN.
 * \param fvalue Feature value of the observation
 * \param category_list Categories for which the rule is true
 * \return true when the converted category is contained in `category_list`
 */
inline bool SplitTrueCategorical(double fvalue, std::vector<std::uint32_t> const& category_list) {
  // A valid (integer) category must satisfy two criteria:
  //   1) it must be exactly representable as double
  //   2) it must fit into uint32_t
  double const max_representable_int
    = std::min(static_cast<double>(std::numeric_limits<std::uint32_t>::max()),
               static_cast<double>(std::uint64_t(1) << std::numeric_limits<double>::digits));

  // Written as a negated range test on purpose: every comparison involving NaN is
  // false, so NaN is rejected here instead of reaching the cast below, where
  // converting NaN to an integer type would be undefined behaviour.
  bool const is_valid_category = (fvalue >= 0.0) && (fvalue <= max_representable_int);
  if (!is_valid_category) {
    return false;
  }

  auto const category_value = static_cast<std::uint32_t>(fvalue);
  return std::find(category_list.begin(), category_list.end(), category_value)
         != category_list.end();
}
831
838inline int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child) {
839 return (SplitTrueNumeric(fvalue, threshold) ? left_child : right_child);
840}
841
848inline int NextNodeCategorical(double fvalue, std::vector<std::uint32_t> const& category_list, int left_child, int right_child) {
849 return SplitTrueCategorical(fvalue, category_list) ? left_child : right_child;
850}
851
859inline int EvaluateTree(Tree const& tree, Eigen::MatrixXd& data, int row) {
860 int node_id = 0;
861 while (!tree.IsLeaf(node_id)) {
862 auto const split_index = tree.SplitIndex(node_id);
863 double const fvalue = data(row, split_index);
864 if (std::isnan(fvalue)) {
865 node_id = tree.DefaultChild(node_id);
866 } else {
867 if (tree.NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
868 node_id = NextNodeCategorical(fvalue, tree.CategoryList(node_id),
869 tree.LeftChild(node_id), tree.RightChild(node_id));
870 } else {
871 node_id = NextNodeNumeric(fvalue, tree.Threshold(node_id), tree.LeftChild(node_id), tree.RightChild(node_id));
872 }
873 }
874 }
875 return node_id;
876}
877
885inline int EvaluateTree(Tree const& tree, Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& data, int row) {
886 int node_id = 0;
887 while (!tree.IsLeaf(node_id)) {
888 auto const split_index = tree.SplitIndex(node_id);
889 double const fvalue = data(row, split_index);
890 if (std::isnan(fvalue)) {
891 node_id = tree.DefaultChild(node_id);
892 } else {
893 if (tree.NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
894 node_id = NextNodeCategorical(fvalue, tree.CategoryList(node_id),
895 tree.LeftChild(node_id), tree.RightChild(node_id));
896 } else {
897 node_id = NextNodeNumeric(fvalue, tree.Threshold(node_id), tree.LeftChild(node_id), tree.RightChild(node_id));
898 }
899 }
900 }
901 return node_id;
902}
903
910inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, double split_value) {
911 double const fvalue = covariates(row, split_index);
912 return SplitTrueNumeric(fvalue, split_value);
913}
914
921inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, std::vector<std::uint32_t> const& category_list) {
922 double const fvalue = covariates(row, split_index);
923 return SplitTrueCategorical(fvalue, category_list);
924}
925
928 public:
929 TreeSplit() {}
935 TreeSplit(double split_value) {
936 numeric_ = true;
937 split_value_ = split_value;
938 split_set_ = true;
939 }
945 TreeSplit(std::vector<std::uint32_t>& split_categories) {
946 numeric_ = false;
947 split_categories_ = split_categories;
948 split_set_ = true;
949 }
950 ~TreeSplit() {}
951 bool SplitSet() {return split_set_;}
953 bool NumericSplit() {return numeric_;}
959 bool SplitTrue(double fvalue) {
960 if (numeric_) return SplitTrueNumeric(fvalue, split_value_);
961 else return SplitTrueCategorical(fvalue, split_categories_);
962 }
964 double SplitValue() {return split_value_;}
966 std::vector<std::uint32_t> SplitCategories() {return split_categories_;}
967 private:
968 bool split_set_{false};
969 bool numeric_;
970 double split_value_;
971 std::vector<std::uint32_t> split_categories_;
972};
973
// end of tree_group
975
976} // namespace StochTree
977
978#endif // STOCHTREE_TREE_H_
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:272
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:927
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:959
bool NumericSplit()
Whether or not a TreeSplit rule is numeric.
Definition tree.h:953
std::vector< std::uint32_t > SplitCategories()
Categories defining a TreeSplit object.
Definition tree.h:966
TreeSplit(std::vector< std::uint32_t > &split_categories)
Construct a categorical TreeSplit.
Definition tree.h:945
double SplitValue()
Numeric cutoff value defining a TreeSplit object.
Definition tree.h:964
TreeSplit(double split_value)
Construct a numeric TreeSplit.
Definition tree.h:935
Decision tree data structure.
Definition tree.h:69
void SetNumericSplit(std::int32_t nid, std::int32_t split_index, double threshold)
Create a numerical split.
void SetLeftChild(std::int32_t nid, std::int32_t left_child)
Identify left child node.
Definition tree.h:595
std::int32_t LeftChild(std::int32_t nid) const
Index of the node's left child.
Definition tree.h:270
void PredictLeafIndexInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
void CloneFromTree(Tree *tree)
Copy the structure and parameters of another tree. If the Tree object calling this method already has...
std::int32_t RightChild(std::int32_t nid) const
Index of the node's right child.
Definition tree.h:278
std::int32_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition tree.h:585
void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index, std::vector< std::uint32_t > const &category_list)
Create a categorical split.
std::int32_t SplitIndex(std::int32_t nid) const
Feature index defining the node's split rule.
Definition tree.h:294
void PredictLeafIndexInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsLeaf(std::int32_t nid) const
Whether the node is a leaf node.
Definition tree.h:302
bool IsNumericSplitNode(std::int32_t nid) const
Whether the node is a numeric split node.
Definition tree.h:478
void WalkTree(Func func) const
Iterate through all nodes in this tree.
Definition tree.h:210
double SumSquaredLeafValues() const
Sum of squared values for all leaves in a tree.
Definition tree.h:422
std::int32_t DefaultChild(std::int32_t nid) const
Index of the node's "default" child (potentially used in the case of a missing feature at prediction ...
Definition tree.h:286
int AllocNode()
Allocate a new node and return the node's ID.
void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a numeric split rule.
std::vector< double > LeafVector(std::int32_t nid) const
Get vector-valued parameters of a node (typically leaf)
Definition tree.h:386
std::int32_t Parent(std::int32_t nid) const
Index of the node's parent.
Definition tree.h:262
void PredictLeafIndexInplace(ForestDataset *dataset, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsCategoricalSplitNode(std::int32_t nid) const
Whether the node is a categorical split node.
Definition tree.h:486
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition tree.h:493
double LeafValue(std::int32_t nid, std::int32_t dim_id) const
Get parameter value of a node (typically though not necessarily a leaf node) at a given output dimens...
Definition tree.h:335
void from_json(const json &tree_json)
Load from JSON.
std::int32_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition tree.h:575
TreeNodeType NodeType(std::int32_t nid) const
Get the type of a node (i.e. numeric split, categorical split, leaf)
Definition tree.h:470
double SumSquaredNodeValues(std::int32_t nid) const
Sum of squared parameter values for a given node (typically though not necessarily a leaf node)
Definition tree.h:402
void DeleteNode(std::int32_t nid)
Deletes node indexed by node ID.
std::int32_t MaxLeafDepth() const
Get maximum depth of all of the leaf nodes.
Definition tree.h:352
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, double left_value, double right_value)
Expand a node based on a generic split rule.
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a generic split rule.
std::int32_t OutputDimension() const
Dimension of tree output.
Definition tree.h:247
bool HasVectorOutput() const
Whether or not a tree has vector output.
Definition tree.h:240
std::int32_t NumDeletedNodes() const noexcept
Get the total number of deleted nodes in this tree.
Definition tree.h:580
void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify two child nodes of the node and the corresponding parent node of the child nodes.
Definition tree.h:614
std::vector< std::uint32_t > CategoryList(std::int32_t nid) const
Get list of all categories belonging to the left child node. Categories are integers ranging from 0 t...
Definition tree.h:453
void SetParent(std::int32_t child_node, std::int32_t parent_node)
Identify parent node.
Definition tree.h:624
void SetLeaf(std::int32_t nid, double value)
Set the leaf value of the node.
bool IsRoot()
Whether or not a tree is a "stump" consisting of a single root node.
Definition tree.h:114
double LeafValue(std::int32_t nid) const
Get parameter value of a node (typically though not necessarily a leaf node)
Definition tree.h:326
void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify parent node of the left and right node ids.
Definition tree.h:634
std::vector< std::int32_t > const & GetInternalNodes() const
Get indices of all internal nodes.
Definition tree.h:526
std::vector< std::int32_t > const & GetLeafParents() const
Get indices of all leaf parent nodes.
Definition tree.h:540
std::vector< std::int32_t > GetNodes()
Get indices of all valid (non-deleted) nodes.
Definition tree.h:547
bool HasLeafVector(std::int32_t nid) const
Tests whether the leaf node has a non-empty leaf vector.
Definition tree.h:434
void Init(int output_dimension=1, bool is_log_scale=false)
Initialize the tree with a single root node.
void CollapseToLeaf(std::int32_t nid, std::vector< double > value_vector)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:191
double Threshold(std::int32_t nid) const
Get split threshold of the node.
Definition tree.h:442
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, double left_value, double right_value)
Expand a node based on a categorical split rule.
bool IsRoot(std::int32_t nid) const
Whether the node is root.
Definition tree.h:310
void Reset()
Reset tree to empty vectors and default values of boolean / integer variables.
bool IsDeleted(std::int32_t nid) const
Whether the node has been deleted.
Definition tree.h:318
void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value)
Expand a node based on a numeric split rule.
void SetRightChild(std::int32_t nid, std::int32_t right_child)
Identify right child node.
Definition tree.h:604
std::vector< std::int32_t > const & GetLeaves() const
Get indices of all leaf nodes.
Definition tree.h:533
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a categorical split rule.
bool IsLogScale() const
Whether or not tree parameters should be exponentiated at prediction time.
Definition tree.h:254
void CollapseToLeaf(std::int32_t nid, double value)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:152
void SetLeafVector(std::int32_t nid, std::vector< double > const &leaf_vector)
Set the leaf vector of the node; useful for multi-output trees.
std::int32_t GetDepth(std::int32_t nid) const
Get the depth of a node.
Definition tree.h:563
json to_json()
Convert tree to JSON and return JSON in-memory.
bool SplitTrueNumeric(double fvalue, double threshold)
Determine whether an observation produces a "true" value in a numeric split node.
Definition tree.h:804
int NextNodeCategorical(double fvalue, std::vector< std::uint32_t > const &category_list, int left_child, int right_child)
Return left or right node id based on a categorical split.
Definition tree.h:848
bool SplitTrueCategorical(double fvalue, std::vector< std::uint32_t > const &category_list)
Determine whether an observation produces a "true" value in a categorical split node.
Definition tree.h:814
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:859
int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child)
Return left or right node id based on a numeric split.
Definition tree.h:838
bool operator==(const Tree &lhs, const Tree &rhs)
Comparison operator for trees.
Definition tree.h:773
bool RowSplitLeft(Eigen::MatrixXd &covariates, int row, int split_index, double split_value)
Determine whether a given observation is "true" at a split proposed by split_index and split_value.
Definition tree.h:910
Definition category_tracker.h:40
std::string TreeNodeTypeToString(TreeNodeType type)
Get string representation of TreeNodeType.
TreeNodeType
Tree node type.
Definition tree.h:27
TreeNodeType TreeNodeTypeFromString(std::string const &name)
Get NodeType from string.