StochTree 0.1.1
Loading...
Searching...
No Matches
tree.h
1
6#ifndef STOCHTREE_TREE_H_
7#define STOCHTREE_TREE_H_
8
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <stack>
#include <string>
#include <vector>

#include <Eigen/Dense>
#include <nlohmann/json.hpp>

#include <stochtree/data.h>
#include <stochtree/log.h>
#include <stochtree/meta.h>

using json = nlohmann::json;
20
21namespace StochTree {
22
/*! \brief Tree node type: a leaf, a numeric split node, or a categorical split node. */
enum TreeNodeType {
  kLeafNode = 0,
  kNumericalSplitNode = 1,
  kCategoricalSplitNode = 2
};
29
30// template<typename T>
31// int enum_to_int(T& input_enum) {
32// return static_cast<int>(input_enum);
33// }
34
35// template<typename T>
36// T json_to_enum(json& input_json) {
37// return static_cast<T>(input_json);
38// }
39
42
44TreeNodeType TreeNodeTypeFromString(std::string const& name);
45
/*!
 * \brief Type of split rule associated with a feature: numeric, ordered categorical,
 * or unordered categorical. Enumerator values are spelled out explicitly.
 */
enum FeatureSplitType {
  kNumericSplit = 0,
  kOrderedCategoricalSplit = 1,
  kUnorderedCategoricalSplit = 2
};

/*! \brief Forward declaration; TreeSplit is defined further down in this header. */
class TreeSplit;
54
66class Tree {
67 public:
68 static constexpr std::int32_t kInvalidNodeId{-1};
69 static constexpr std::int32_t kDeletedNodeMarker = std::numeric_limits<node_t>::max();
70 static constexpr std::int32_t kRoot{0};
71
72 Tree() = default;
73 // ~Tree() = default;
74 Tree(Tree const&) = delete;
75 Tree& operator=(Tree const&) = delete;
76 Tree(Tree&&) noexcept = default;
77 Tree& operator=(Tree&&) noexcept = default;
84 void CloneFromTree(Tree* tree);
85
86 std::int32_t num_nodes{0};
87 std::int32_t num_deleted_nodes{0};
88
90 void Reset();
92 void Init(int output_dimension = 1, bool is_log_scale = false);
94 int AllocNode();
96 void DeleteNode(std::int32_t nid);
98 void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value);
100 void ExpandNode(std::int32_t nid, int split_index, std::vector<std::uint32_t> const& categorical_indices, double left_value, double right_value);
102 void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
104 void ExpandNode(std::int32_t nid, int split_index, std::vector<std::uint32_t> const& categorical_indices, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
106 void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, double left_value, double right_value);
108 void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
109
111 inline bool IsRoot() {return leaves_.size() == 1;}
112
114 json to_json();
120 void from_json(const json& tree_json);
121
122 void ChangeToLeaf(std::int32_t nid, double value) {
123 CHECK(this->IsLeaf(this->LeftChild(nid)));
124 CHECK(this->IsLeaf(this->RightChild(nid)));
125 this->DeleteNode(this->LeftChild(nid));
126 this->DeleteNode(this->RightChild(nid));
127 this->SetLeaf(nid, value);
128
129 // Add nid to leaves and remove from internal nodes and leaf parents (if it was there)
130 leaves_.push_back(nid);
131 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
132 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
133
134 // Check if the other child of nid's parent node is also a leaf, if so, add parent back to leaf parents
135 // TODO refactor and add this to the multivariate case as well
136 if (!IsRoot(nid)) {
137 int parent_id = Parent(nid);
138 if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){
139 leaf_parents_.push_back(parent_id);
140 }
141 }
142 }
143
149 void CollapseToLeaf(std::int32_t nid, double value) {
150 CHECK_EQ(output_dimension_, 1);
151 if (this->IsLeaf(nid)) return;
152 if (!this->IsLeaf(this->LeftChild(nid))) {
153 CollapseToLeaf(this->LeftChild(nid), value);
154 }
155 if (!this->IsLeaf(this->RightChild(nid))) {
156 CollapseToLeaf(this->RightChild(nid), value);
157 }
158 this->ChangeToLeaf(nid, value);
159 }
160
161 void ChangeToLeaf(std::int32_t nid, std::vector<double> value_vector) {
162 CHECK(this->IsLeaf(this->LeftChild(nid)));
163 CHECK(this->IsLeaf(this->RightChild(nid)));
164 this->DeleteNode(this->LeftChild(nid));
165 this->DeleteNode(this->RightChild(nid));
166 this->SetLeafVector(nid, value_vector);
167
168 // Add nid to leaves and remove from internal nodes and leaf parents (if it was there)
169 leaves_.push_back(nid);
170 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
171 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
172
173 // Check if the other child of nid's parent node is also a leaf, if so, add parent back to leaf parents
174 // TODO refactor and add this to the multivariate case as well
175 if (!IsRoot(nid)) {
176 int parent_id = Parent(nid);
177 if ((IsLeaf(LeftChild(parent_id))) && (IsLeaf(RightChild(parent_id)))){
178 leaf_parents_.push_back(parent_id);
179 }
180 }
181 }
182
188 void CollapseToLeaf(std::int32_t nid, std::vector<double> value_vector) {
189 CHECK_GT(output_dimension_, 1);
190 CHECK_EQ(output_dimension_, value_vector.size());
191 if (this->IsLeaf(nid)) return;
192 if (!this->IsLeaf(this->LeftChild(nid))) {
193 CollapseToLeaf(this->LeftChild(nid), value_vector);
194 }
195 if (!this->IsLeaf(this->RightChild(nid))) {
196 CollapseToLeaf(this->RightChild(nid), value_vector);
197 }
198 this->ChangeToLeaf(nid, value_vector);
199 }
200
206 void AddValueToLeaves(double constant_value) {
207 if (output_dimension_ == 1) {
208 for (int j = 0; j < leaf_value_.size(); j++) {
209 leaf_value_[j] += constant_value;
210 }
211 } else {
212 for (int j = 0; j < leaf_vector_.size(); j++) {
213 leaf_vector_[j] += constant_value;
214 }
215 }
216 }
217
223 void MultiplyLeavesByValue(double constant_multiple) {
224 if (output_dimension_ == 1) {
225 for (int j = 0; j < leaf_value_.size(); j++) {
226 leaf_value_[j] *= constant_multiple;
227 }
228 } else {
229 for (int j = 0; j < leaf_vector_.size(); j++) {
230 leaf_vector_[j] *= constant_multiple;
231 }
232 }
233 }
234
241 template <typename Func> void WalkTree(Func func) const {
242 std::stack<std::int32_t> nodes;
243 nodes.push(kRoot);
244 auto &self = *this;
245 while (!nodes.empty()) {
246 auto nidx = nodes.top();
247 nodes.pop();
248 if (!func(nidx)) {
249 return;
250 }
251 auto left = self.LeftChild(nidx);
252 auto right = self.RightChild(nidx);
253 if (left != Tree::kInvalidNodeId) {
254 nodes.push(left);
255 }
256 if (right != Tree::kInvalidNodeId) {
257 nodes.push(right);
258 }
259 }
260 }
261
262 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices);
263 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices, Eigen::MatrixXd& basis);
264 double PredictFromNode(std::int32_t node_id);
265 double PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis, int row_idx);
266
271 bool HasVectorOutput() const {
272 return output_dimension_ > 1;
273 }
274
278 std::int32_t OutputDimension() const {
279 return output_dimension_;
280 }
281
285 bool IsLogScale() const {
286 return is_log_scale_;
287 }
288
293 std::int32_t Parent(std::int32_t nid) const {
294 return parent_[nid];
295 }
296
301 std::int32_t LeftChild(std::int32_t nid) const {
302 return cleft_[nid];
303 }
304
309 std::int32_t RightChild(std::int32_t nid) const {
310 return cright_[nid];
311 }
312
317 std::int32_t DefaultChild(std::int32_t nid) const {
318 return cleft_[nid];
319 }
320
325 std::int32_t SplitIndex(std::int32_t nid) const {
326 return split_index_[nid];
327 }
328
333 bool IsLeaf(std::int32_t nid) const {
334 return cleft_[nid] == kInvalidNodeId;
335 }
336
341 bool IsRoot(std::int32_t nid) const {
342 return parent_[nid] == kInvalidNodeId;
343 }
344
349 bool IsDeleted(std::int32_t nid) const {
350 return node_deleted_[nid];
351 }
352
357 double LeafValue(std::int32_t nid) const {
358 return leaf_value_[nid];
359 }
360
366 double LeafValue(std::int32_t nid, std::int32_t dim_id) const {
367 CHECK_LT(dim_id, output_dimension_);
368 if (output_dimension_ == 1 && dim_id == 0) {
369 return leaf_value_[nid];
370 } else {
371 std::size_t const offset_begin = leaf_vector_begin_[nid];
372 std::size_t const offset_end = leaf_vector_end_[nid];
373 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
374 Log::Fatal("No leaf vector set for node nid");
375 }
376 return leaf_vector_[offset_begin + dim_id];
377 }
378 }
379
383 std::int32_t MaxLeafDepth() const {
384 std::int32_t max_depth = 0;
385 std::stack<std::int32_t> nodes;
386 std::stack<std::int32_t> node_depths;
387 nodes.push(kRoot);
388 node_depths.push(0);
389 auto &self = *this;
390 while (!nodes.empty()) {
391 auto nidx = nodes.top();
392 nodes.pop();
393 auto node_depth = node_depths.top();
394 node_depths.pop();
395 bool valid_node = !self.IsDeleted(nidx);
396 if (valid_node) {
397 if (node_depth > max_depth) max_depth = node_depth;
398 auto left = self.LeftChild(nidx);
399 auto right = self.RightChild(nidx);
400 if (left != Tree::kInvalidNodeId) {
401 nodes.push(left);
402 node_depths.push(node_depth+1);
403 }
404 if (right != Tree::kInvalidNodeId) {
405 nodes.push(right);
406 node_depths.push(node_depth+1);
407 }
408 }
409 }
410 return max_depth;
411 }
412
417 std::vector<double> LeafVector(std::int32_t nid) const {
418 std::size_t const offset_begin = leaf_vector_begin_[nid];
419 std::size_t const offset_end = leaf_vector_end_[nid];
420 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
421 // Return empty vector, to indicate the lack of leaf vector
422 return std::vector<double>();
423 }
424 return std::vector<double>(&leaf_vector_[offset_begin], &leaf_vector_[offset_end]);
425 // Use unsafe access here, since we may need to take the address of one past the last
426 // element, to follow with the range semantic of std::vector<>.
427 }
428
433 double SumSquaredNodeValues(std::int32_t nid) const {
434 if (output_dimension_ == 1) {
435 return std::pow(leaf_value_[nid], 2.0);
436 } else {
437 double result = 0.;
438 std::size_t const offset_begin = leaf_vector_begin_[nid];
439 std::size_t const offset_end = leaf_vector_end_[nid];
440 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
441 Log::Fatal("No leaf vector set for node nid");
442 }
443 for (std::size_t i = offset_begin; i < offset_end; i++) {
444 result += std::pow(leaf_vector_[i], 2.0);
445 }
446 return result;
447 }
448 }
449
453 double SumSquaredLeafValues() const {
454 double result = 0.;
455 for (auto& leaf : leaves_) {
456 result += SumSquaredNodeValues(leaf);
457 }
458 return result;
459 }
460
465 bool HasLeafVector(std::int32_t nid) const {
466 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
467 }
468
473 double Threshold(std::int32_t nid) const {
474 return threshold_[nid];
475 }
476
484 std::vector<std::uint32_t> CategoryList(std::int32_t nid) const {
485 std::size_t const offset_begin = category_list_begin_[nid];
486 std::size_t const offset_end = category_list_end_[nid];
487 if (offset_begin >= category_list_.size() || offset_end > category_list_.size()) {
488 // Return empty vector, to indicate the lack of any category list
489 // The node might be a numerical split
490 return {};
491 }
492 return std::vector<std::uint32_t>(&category_list_[offset_begin], &category_list_[offset_end]);
493 // Use unsafe access here, since we may need to take the address of one past the last
494 // element, to follow with the range semantic of std::vector<>.
495 }
496
501 TreeNodeType NodeType(std::int32_t nid) const {
502 return node_type_[nid];
503 }
504
509 bool IsNumericSplitNode(std::int32_t nid) const {
510 return node_type_[nid] == TreeNodeType::kNumericalSplitNode;
511 }
512
517 bool IsCategoricalSplitNode(std::int32_t nid) const {
518 return node_type_[nid] == TreeNodeType::kCategoricalSplitNode;
519 }
520
524 bool HasCategoricalSplit() const {
525 return has_categorical_split_;
526 }
527
528 /* \brief Count number of leaves in tree. */
529 [[nodiscard]] std::int32_t NumLeaves() const;
530 [[nodiscard]] std::int32_t NumLeafParents() const;
531 [[nodiscard]] std::int32_t NumSplitNodes() const;
532
533 /* \brief Determine whether nid is leaf parent */
534 [[nodiscard]] bool IsLeafParent(std::int32_t nid) const {
535 // False until we deduce left and right node are
536 // available and both are leaves
537 bool is_left_leaf = false;
538 bool is_right_leaf = false;
539 // Check if node nidx is a leaf, if so, return false
540 bool is_leaf = this->IsLeaf(nid);
541 if (is_leaf){
542 return false;
543 } else {
544 // If nidx is not a leaf, it must have left and right nodes
545 // so we check if those are leaves
546 std::int32_t left_node = LeftChild(nid);
547 std::int32_t right_node = RightChild(nid);
548 is_left_leaf = IsLeaf(left_node);
549 is_right_leaf = IsLeaf(right_node);
550 }
551 return is_left_leaf && is_right_leaf;
552 }
553
557 [[nodiscard]] std::vector<std::int32_t> const& GetInternalNodes() const {
558 return internal_nodes_;
559 }
560
564 [[nodiscard]] std::vector<std::int32_t> const& GetLeaves() const {
565 return leaves_;
566 }
567
571 [[nodiscard]] std::vector<std::int32_t> const& GetLeafParents() const {
572 return leaf_parents_;
573 }
574
578 [[nodiscard]] std::vector<std::int32_t> GetNodes() {
579 std::vector<std::int32_t> output;
580 auto const& self = *this;
581 this->WalkTree([&output, &self](std::int32_t nidx) {
582 if (!self.IsDeleted(nidx)) {
583 output.push_back(nidx);
584 }
585 return true;
586 });
587 return output;
588 }
589
594 [[nodiscard]] std::int32_t GetDepth(std::int32_t nid) const {
595 int depth = 0;
596 while (!IsRoot(nid)) {
597 ++depth;
598 nid = Parent(nid);
599 }
600 return depth;
601 }
602
606 [[nodiscard]] std::int32_t NumNodes() const noexcept { return num_nodes; }
607
611 [[nodiscard]] std::int32_t NumDeletedNodes() const noexcept { return num_deleted_nodes; }
612
616 [[nodiscard]] std::int32_t NumValidNodes() const noexcept {
617 return num_nodes - num_deleted_nodes;
618 }
619
626 void SetLeftChild(std::int32_t nid, std::int32_t left_child) {
627 cleft_[nid] = left_child;
628 }
629
635 void SetRightChild(std::int32_t nid, std::int32_t right_child) {
636 cright_[nid] = right_child;
637 }
638
645 void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
646 SetLeftChild(nid, left_child);
647 SetRightChild(nid, right_child);
648 }
649
655 void SetParent(std::int32_t child_node, std::int32_t parent_node) {
656 parent_[child_node] = parent_node;
657 }
658
665 void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
666 SetParent(left_child, nid);
667 SetParent(right_child, nid);
668 }
669
677 std::int32_t nid, std::int32_t split_index, double threshold);
678
687 void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index,
688 std::vector<std::uint32_t> const& category_list);
689
695 void SetLeaf(std::int32_t nid, double value);
696
702 void SetLeafVector(std::int32_t nid, std::vector<double> const& leaf_vector);
703
724 void PredictLeafIndexInplace(ForestDataset* dataset, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
725
746 void PredictLeafIndexInplace(Eigen::MatrixXd& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
747
768 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
769
770 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
771 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
772 int column_ind, int32_t offset, int32_t max_leaf);
773
774 // Node info
775 std::vector<TreeNodeType> node_type_;
776 std::vector<std::int32_t> parent_;
777 std::vector<std::int32_t> cleft_;
778 std::vector<std::int32_t> cright_;
779 std::vector<std::int32_t> split_index_;
780 std::vector<double> leaf_value_;
781 std::vector<double> threshold_;
782 std::vector<bool> node_deleted_;
783 std::vector<std::int32_t> internal_nodes_;
784 std::vector<std::int32_t> leaves_;
785 std::vector<std::int32_t> leaf_parents_;
786 std::vector<std::int32_t> deleted_nodes_;
787
788 // Leaf vector
789 std::vector<double> leaf_vector_;
790 std::vector<std::uint64_t> leaf_vector_begin_;
791 std::vector<std::uint64_t> leaf_vector_end_;
792
793 // Category list
794 std::vector<std::uint32_t> category_list_;
795 std::vector<std::uint64_t> category_list_begin_;
796 std::vector<std::uint64_t> category_list_end_;
797
798 bool has_categorical_split_{false};
799 int output_dimension_{1};
800 bool is_log_scale_{false};
801};
802
804inline bool operator==(const Tree& lhs, const Tree& rhs) {
805 return (
806 (lhs.has_categorical_split_ == rhs.has_categorical_split_) &&
807 (lhs.output_dimension_ == rhs.output_dimension_) &&
808 (lhs.is_log_scale_ == rhs.is_log_scale_) &&
809 (lhs.node_type_ == rhs.node_type_) &&
810 (lhs.parent_ == rhs.parent_) &&
811 (lhs.cleft_ == rhs.cleft_) &&
812 (lhs.cright_ == rhs.cright_) &&
813 (lhs.split_index_ == rhs.split_index_) &&
814 (lhs.leaf_value_ == rhs.leaf_value_) &&
815 (lhs.threshold_ == rhs.threshold_) &&
816 (lhs.internal_nodes_ == rhs.internal_nodes_) &&
817 (lhs.leaves_ == rhs.leaves_) &&
818 (lhs.leaf_parents_ == rhs.leaf_parents_) &&
819 (lhs.deleted_nodes_ == rhs.deleted_nodes_) &&
820 (lhs.leaf_vector_ == rhs.leaf_vector_) &&
821 (lhs.leaf_vector_begin_ == rhs.leaf_vector_begin_) &&
822 (lhs.leaf_vector_end_ == rhs.leaf_vector_end_) &&
823 (lhs.category_list_ == rhs.category_list_) &&
824 (lhs.category_list_begin_ == rhs.category_list_begin_) &&
825 (lhs.category_list_end_ == rhs.category_list_end_)
826 );
827}
828
/*!
 * \brief Determine whether an observation produces a "true" value in a numeric split node.
 * \param fvalue Feature value of the observation.
 * \param threshold Split cutoff.
 * \return true iff fvalue <= threshold (false for NaN).
 */
inline bool SplitTrueNumeric(double fvalue, double threshold) {
  return threshold >= fvalue;
}
838
/*!
 * \brief Determine whether an observation produces a "true" value in a categorical split node.
 *
 * The feature value is truncated to an unsigned 32-bit category and looked up in
 * category_list. Values that cannot be valid categories are never matched: negative values,
 * NaN, and values above the largest integer that is both exactly representable as a double
 * and fits in std::uint32_t.
 * \param fvalue Feature value of the observation.
 * \param category_list Categories that evaluate to "true".
 * \return Whether the (truncated) category of fvalue is in category_list.
 */
inline bool SplitTrueCategorical(double fvalue, std::vector<std::uint32_t> const& category_list) {
  // A valid (integer) category must satisfy two criteria:
  // 1) it must be exactly representable as double
  // 2) it must fit into uint32_t
  constexpr double max_representable_int =
      std::min(static_cast<double>(std::numeric_limits<std::uint32_t>::max()),
               static_cast<double>(std::uint64_t(1) << std::numeric_limits<double>::digits));
  // Written as !(fvalue >= 0) so that NaN is rejected too: converting NaN to an integer is UB.
  if (!(fvalue >= 0.0) || fvalue > max_representable_int) {
    return false;
  }
  auto const category_value = static_cast<std::uint32_t>(fvalue);
  return std::find(category_list.begin(), category_list.end(), category_value) != category_list.end();
}
862
869inline int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child) {
870 return (SplitTrueNumeric(fvalue, threshold) ? left_child : right_child);
871}
872
879inline int NextNodeCategorical(double fvalue, std::vector<std::uint32_t> const& category_list, int left_child, int right_child) {
880 return SplitTrueCategorical(fvalue, category_list) ? left_child : right_child;
881}
882
890inline int EvaluateTree(Tree const& tree, Eigen::MatrixXd& data, int row) {
891 int node_id = 0;
892 while (!tree.IsLeaf(node_id)) {
893 auto const split_index = tree.SplitIndex(node_id);
894 double const fvalue = data(row, split_index);
895 if (std::isnan(fvalue)) {
896 node_id = tree.DefaultChild(node_id);
897 } else {
898 if (tree.NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
899 node_id = NextNodeCategorical(fvalue, tree.CategoryList(node_id),
900 tree.LeftChild(node_id), tree.RightChild(node_id));
901 } else {
902 node_id = NextNodeNumeric(fvalue, tree.Threshold(node_id), tree.LeftChild(node_id), tree.RightChild(node_id));
903 }
904 }
905 }
906 return node_id;
907}
908
916inline int EvaluateTree(Tree const& tree, Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& data, int row) {
917 int node_id = 0;
918 while (!tree.IsLeaf(node_id)) {
919 auto const split_index = tree.SplitIndex(node_id);
920 double const fvalue = data(row, split_index);
921 if (std::isnan(fvalue)) {
922 node_id = tree.DefaultChild(node_id);
923 } else {
924 if (tree.NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
925 node_id = NextNodeCategorical(fvalue, tree.CategoryList(node_id),
926 tree.LeftChild(node_id), tree.RightChild(node_id));
927 } else {
928 node_id = NextNodeNumeric(fvalue, tree.Threshold(node_id), tree.LeftChild(node_id), tree.RightChild(node_id));
929 }
930 }
931 }
932 return node_id;
933}
934
941inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, double split_value) {
942 double const fvalue = covariates(row, split_index);
943 return SplitTrueNumeric(fvalue, split_value);
944}
945
952inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, std::vector<std::uint32_t> const& category_list) {
953 double const fvalue = covariates(row, split_index);
954 return SplitTrueCategorical(fvalue, category_list);
955}
956
959 public:
960 TreeSplit() {}
966 TreeSplit(double split_value) {
967 numeric_ = true;
968 split_value_ = split_value;
969 split_set_ = true;
970 }
976 TreeSplit(std::vector<std::uint32_t>& split_categories) {
977 numeric_ = false;
978 split_categories_ = split_categories;
979 split_set_ = true;
980 }
981 ~TreeSplit() {}
982 bool SplitSet() {return split_set_;}
984 bool NumericSplit() {return numeric_;}
990 bool SplitTrue(double fvalue) {
991 if (numeric_) return SplitTrueNumeric(fvalue, split_value_);
992 else return SplitTrueCategorical(fvalue, split_categories_);
993 }
995 double SplitValue() {return split_value_;}
997 std::vector<std::uint32_t> SplitCategories() {return split_categories_;}
998 private:
999 bool split_set_{false};
1000 bool numeric_;
1001 double split_value_;
1002 std::vector<std::uint32_t> split_categories_;
1003};
1004
// end of tree_group
1006
1007} // namespace StochTree
1008
1009#endif // STOCHTREE_TREE_H_
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:271
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:958
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:990
bool NumericSplit()
Whether or not a TreeSplit rule is numeric.
Definition tree.h:984
std::vector< std::uint32_t > SplitCategories()
Categories defining a TreeSplit object.
Definition tree.h:997
TreeSplit(std::vector< std::uint32_t > &split_categories)
Construct a categorical TreeSplit.
Definition tree.h:976
double SplitValue()
Numeric cutoff value defining a TreeSplit object.
Definition tree.h:995
TreeSplit(double split_value)
Construct a numeric TreeSplit.
Definition tree.h:966
Decision tree data structure.
Definition tree.h:66
void SetNumericSplit(std::int32_t nid, std::int32_t split_index, double threshold)
Create a numerical split.
void SetLeftChild(std::int32_t nid, std::int32_t left_child)
Identify left child node.
Definition tree.h:626
std::int32_t LeftChild(std::int32_t nid) const
Index of the node's left child.
Definition tree.h:301
void PredictLeafIndexInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation (row) of a column-major covariate matrix (Eigen::Map). Internally, trees are stored as ...
void CloneFromTree(Tree *tree)
Copy the structure and parameters of another tree. If the Tree object calling this method already has...
std::int32_t RightChild(std::int32_t nid) const
Index of the node's right child.
Definition tree.h:309
std::int32_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition tree.h:616
void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index, std::vector< std::uint32_t > const &category_list)
Create a categorical split.
std::int32_t SplitIndex(std::int32_t nid) const
Feature index defining the node's split rule.
Definition tree.h:325
void PredictLeafIndexInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation (row) of a covariate matrix (Eigen::MatrixXd). Internally, trees are stored as ...
bool IsLeaf(std::int32_t nid) const
Whether the node is a leaf node.
Definition tree.h:333
bool IsNumericSplitNode(std::int32_t nid) const
Whether the node is a numeric split node.
Definition tree.h:509
void WalkTree(Func func) const
Iterate through all nodes in this tree.
Definition tree.h:241
double SumSquaredLeafValues() const
Sum of squared values for all leaves in a tree.
Definition tree.h:453
std::int32_t DefaultChild(std::int32_t nid) const
Index of the node's "default" child (potentially used in the case of a missing feature at prediction ...
Definition tree.h:317
void AddValueToLeaves(double constant_value)
Add a constant value to every leaf of a tree. If leaves are multi-dimensional, constant_value will be...
Definition tree.h:206
int AllocNode()
Allocate a new node and return the node's ID.
void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a numeric split rule.
std::vector< double > LeafVector(std::int32_t nid) const
Get vector-valued parameters of a node (typically leaf)
Definition tree.h:417
std::int32_t Parent(std::int32_t nid) const
Index of the node's parent.
Definition tree.h:293
void PredictLeafIndexInplace(ForestDataset *dataset, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsCategoricalSplitNode(std::int32_t nid) const
Whether the node is a categorical split node.
Definition tree.h:517
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition tree.h:524
double LeafValue(std::int32_t nid, std::int32_t dim_id) const
Get parameter value of a node (typically though not necessarily a leaf node) at a given output dimens...
Definition tree.h:366
void from_json(const json &tree_json)
Load from JSON.
std::int32_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition tree.h:606
TreeNodeType NodeType(std::int32_t nid) const
Get the type of a node (i.e. numeric split, categorical split, leaf)
Definition tree.h:501
double SumSquaredNodeValues(std::int32_t nid) const
Sum of squared parameter values for a given node (typically though not necessarily a leaf node)
Definition tree.h:433
void DeleteNode(std::int32_t nid)
Deletes node indexed by node ID.
std::int32_t MaxLeafDepth() const
Get maximum depth of all of the leaf nodes.
Definition tree.h:383
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, double left_value, double right_value)
Expand a node based on a generic split rule.
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a generic split rule.
std::int32_t OutputDimension() const
Dimension of tree output.
Definition tree.h:278
bool HasVectorOutput() const
Whether or not a tree has vector output.
Definition tree.h:271
void MultiplyLeavesByValue(double constant_multiple)
Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional,...
Definition tree.h:223
std::int32_t NumDeletedNodes() const noexcept
Get the total number of deleted nodes in this tree.
Definition tree.h:611
void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify the left and right child nodes of the node (the children's parent pointers are not updated; see SetParents).
Definition tree.h:645
std::vector< std::uint32_t > CategoryList(std::int32_t nid) const
Get list of all categories belonging to the left child node. Categories are integers ranging from 0 t...
Definition tree.h:484
void SetParent(std::int32_t child_node, std::int32_t parent_node)
Identify parent node.
Definition tree.h:655
void SetLeaf(std::int32_t nid, double value)
Set the leaf value of the node.
bool IsRoot()
Whether or not a tree is a "stump" consisting of a single root node.
Definition tree.h:111
double LeafValue(std::int32_t nid) const
Get parameter value of a node (typically though not necessarily a leaf node)
Definition tree.h:357
void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify parent node of the left and right node ids.
Definition tree.h:665
std::vector< std::int32_t > const & GetInternalNodes() const
Get indices of all internal nodes.
Definition tree.h:557
std::vector< std::int32_t > const & GetLeafParents() const
Get indices of all leaf parent nodes.
Definition tree.h:571
std::vector< std::int32_t > GetNodes()
Get indices of all valid (non-deleted) nodes.
Definition tree.h:578
bool HasLeafVector(std::int32_t nid) const
Tests whether the leaf node has a non-empty leaf vector.
Definition tree.h:465
void Init(int output_dimension=1, bool is_log_scale=false)
Initialize the tree with a single root node.
void CollapseToLeaf(std::int32_t nid, std::vector< double > value_vector)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:188
double Threshold(std::int32_t nid) const
Get split threshold of the node.
Definition tree.h:473
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, double left_value, double right_value)
Expand a node based on a categorical split rule.
bool IsRoot(std::int32_t nid) const
Whether the node is root.
Definition tree.h:341
void Reset()
Reset tree to empty vectors and default values of boolean / integer variables.
bool IsDeleted(std::int32_t nid) const
Whether the node has been deleted.
Definition tree.h:349
void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value)
Expand a node based on a numeric split rule.
void SetRightChild(std::int32_t nid, std::int32_t right_child)
Identify right child node.
Definition tree.h:635
std::vector< std::int32_t > const & GetLeaves() const
Get indices of all leaf nodes.
Definition tree.h:564
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a categorical split rule.
bool IsLogScale() const
Whether or not tree parameters should be exponentiated at prediction time.
Definition tree.h:285
void CollapseToLeaf(std::int32_t nid, double value)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:149
void SetLeafVector(std::int32_t nid, std::vector< double > const &leaf_vector)
Set the leaf vector of the node; useful for multi-output trees.
std::int32_t GetDepth(std::int32_t nid) const
Get the depth of a node.
Definition tree.h:594
json to_json()
Convert tree to JSON and return JSON in-memory.
bool SplitTrueNumeric(double fvalue, double threshold)
Determine whether an observation produces a "true" value in a numeric split node.
Definition tree.h:835
int NextNodeCategorical(double fvalue, std::vector< std::uint32_t > const &category_list, int left_child, int right_child)
Return left or right node id based on a categorical split.
Definition tree.h:879
bool SplitTrueCategorical(double fvalue, std::vector< std::uint32_t > const &category_list)
Determine whether an observation produces a "true" value in a categorical split node.
Definition tree.h:845
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:890
int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child)
Return left or right node id based on a numeric split.
Definition tree.h:869
bool operator==(const Tree &lhs, const Tree &rhs)
Comparison operator for trees.
Definition tree.h:804
bool RowSplitLeft(Eigen::MatrixXd &covariates, int row, int split_index, double split_value)
Determine whether a given observation is "true" at a split proposed by split_index and split_value.
Definition tree.h:941
Definition category_tracker.h:36
std::string TreeNodeTypeToString(TreeNodeType type)
Get string representation of TreeNodeType.
TreeNodeType
Tree node type.
Definition tree.h:24
TreeNodeType TreeNodeTypeFromString(std::string const &name)
Get NodeType from string.