6#ifndef STOCHTREE_TREE_H_
7#define STOCHTREE_TREE_H_
9#include <nlohmann/json.hpp>
10#include <stochtree/data.h>
11#include <stochtree/log.h>
12#include <stochtree/meta.h>
22using json = nlohmann::json;
29 kNumericalSplitNode = 1,
30 kCategoricalSplitNode = 2
49enum FeatureSplitType {
51 kOrderedCategoricalSplit,
52 kUnorderedCategoricalSplit
// Sentinel node ID meaning "no node": IsLeaf(nid) compares cleft_[nid] against it,
// and IsRoot(nid) compares parent_[nid] against it.
71 static constexpr std::int32_t kInvalidNodeId{-1};
// Marker value for deleted nodes; where it is consumed is not visible here.
// NOTE(review): declared std::int32_t but initialized from
// std::numeric_limits<node_t>::max() - confirm node_t is a 32-bit signed type,
// otherwise this conversion narrows.
72 static constexpr std::int32_t kDeletedNodeMarker = std::numeric_limits<node_t>::max();
// Node ID 0; by its name, the ID of the root node.
73 static constexpr std::int32_t kRoot{0};
78 Tree& operator=(
Tree const&) =
delete;
80 Tree& operator=(
Tree&&)
noexcept =
default;
89 std::int32_t num_nodes{0};
90 std::int32_t num_deleted_nodes{0};
95 void Init(
int output_dimension = 1,
bool is_log_scale =
false);
101 void ExpandNode(std::int32_t nid,
int split_index,
double split_value,
double left_value,
double right_value);
103 void ExpandNode(std::int32_t nid,
int split_index, std::vector<std::uint32_t>
const& categorical_indices,
double left_value,
double right_value);
105 void ExpandNode(std::int32_t nid,
int split_index,
double split_value, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
107 void ExpandNode(std::int32_t nid,
int split_index, std::vector<std::uint32_t>
const& categorical_indices, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
109 void ExpandNode(std::int32_t nid,
int split_index,
TreeSplit& split,
double left_value,
double right_value);
111 void ExpandNode(std::int32_t nid,
int split_index,
TreeSplit& split, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
// Whether the tree is a "stump" made of a single root node: true exactly when
// the leaves_ list holds a single entry.
// NOTE(review): unlike the neighbouring accessors this is not const-qualified,
// so it cannot be called through a const Tree&; consider adding const.
114 inline bool IsRoot() {
return leaves_.size() == 1;}
125 void ChangeToLeaf(std::int32_t nid,
double value) {
133 leaves_.push_back(nid);
134 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
135 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
140 int parent_id =
Parent(nid);
142 leaf_parents_.push_back(parent_id);
153 CHECK_EQ(output_dimension_, 1);
154 if (this->
IsLeaf(nid))
return;
161 this->ChangeToLeaf(nid, value);
164 void ChangeToLeaf(std::int32_t nid, std::vector<double> value_vector) {
172 leaves_.push_back(nid);
173 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
174 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
179 int parent_id =
Parent(nid);
181 leaf_parents_.push_back(parent_id);
192 CHECK_GT(output_dimension_, 1);
193 CHECK_EQ(output_dimension_, value_vector.size());
194 if (this->
IsLeaf(nid))
return;
201 this->ChangeToLeaf(nid, value_vector);
210 template <
typename Func>
void WalkTree(Func func)
const {
211 std::stack<std::int32_t> nodes;
214 while (!nodes.empty()) {
215 auto nidx = nodes.top();
220 auto left = self.LeftChild(nidx);
221 auto right = self.RightChild(nidx);
222 if (left != Tree::kInvalidNodeId) {
225 if (right != Tree::kInvalidNodeId) {
231 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices);
232 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices, Eigen::MatrixXd& basis);
233 double PredictFromNode(std::int32_t node_id);
234 double PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis,
int row_idx);
241 return output_dimension_ > 1;
248 return output_dimension_;
255 return is_log_scale_;
262 std::int32_t
Parent(std::int32_t nid)
const {
295 return split_index_[nid];
303 return cleft_[nid] == kInvalidNodeId;
311 return parent_[nid] == kInvalidNodeId;
319 return node_deleted_[nid];
327 return leaf_value_[nid];
335 double LeafValue(std::int32_t nid, std::int32_t dim_id)
const {
336 CHECK_LT(dim_id, output_dimension_);
337 if (output_dimension_ == 1 && dim_id == 0) {
338 return leaf_value_[nid];
340 std::size_t
const offset_begin = leaf_vector_begin_[nid];
341 std::size_t
const offset_end = leaf_vector_end_[nid];
342 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
343 Log::Fatal(
"No leaf vector set for node nid");
345 return leaf_vector_[offset_begin + dim_id];
353 std::int32_t max_depth = 0;
354 std::stack<std::int32_t> nodes;
355 std::stack<std::int32_t> node_depths;
359 while (!nodes.empty()) {
360 auto nidx = nodes.top();
362 auto node_depth = node_depths.top();
364 bool valid_node = !self.IsDeleted(nidx);
366 if (node_depth > max_depth) max_depth = node_depth;
367 auto left = self.LeftChild(nidx);
368 auto right = self.RightChild(nidx);
369 if (left != Tree::kInvalidNodeId) {
371 node_depths.push(node_depth+1);
373 if (right != Tree::kInvalidNodeId) {
375 node_depths.push(node_depth+1);
387 std::size_t
const offset_begin = leaf_vector_begin_[nid];
388 std::size_t
const offset_end = leaf_vector_end_[nid];
389 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
391 return std::vector<double>();
393 return std::vector<double>(&leaf_vector_[offset_begin], &leaf_vector_[offset_end]);
403 if (output_dimension_ == 1) {
404 return std::pow(leaf_value_[nid], 2.0);
407 std::size_t
const offset_begin = leaf_vector_begin_[nid];
408 std::size_t
const offset_end = leaf_vector_end_[nid];
409 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
410 Log::Fatal(
"No leaf vector set for node nid");
412 for (std::size_t i = offset_begin; i < offset_end; i++) {
413 result += std::pow(leaf_vector_[i], 2.0);
424 for (
auto& leaf : leaves_) {
435 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
443 return threshold_[nid];
454 std::size_t
const offset_begin = category_list_begin_[nid];
455 std::size_t
const offset_end = category_list_end_[nid];
456 if (offset_begin >= category_list_.size() || offset_end > category_list_.size()) {
461 return std::vector<std::uint32_t>(&category_list_[offset_begin], &category_list_[offset_end]);
471 return node_type_[nid];
479 return node_type_[nid] == TreeNodeType::kNumericalSplitNode;
487 return node_type_[nid] == TreeNodeType::kCategoricalSplitNode;
494 return has_categorical_split_;
498 [[nodiscard]] std::int32_t NumLeaves()
const;
499 [[nodiscard]] std::int32_t NumLeafParents()
const;
500 [[nodiscard]] std::int32_t NumSplitNodes()
const;
503 [[nodiscard]]
bool IsLeafParent(std::int32_t nid)
const {
506 bool is_left_leaf =
false;
507 bool is_right_leaf =
false;
509 bool is_leaf = this->
IsLeaf(nid);
517 is_left_leaf =
IsLeaf(left_node);
518 is_right_leaf =
IsLeaf(right_node);
520 return is_left_leaf && is_right_leaf;
527 return internal_nodes_;
533 [[nodiscard]] std::vector<std::int32_t>
const&
GetLeaves()
const {
541 return leaf_parents_;
547 [[nodiscard]] std::vector<std::int32_t>
GetNodes() {
548 std::vector<std::int32_t> output;
549 auto const& self = *
this;
550 this->
WalkTree([&output, &self](std::int32_t nidx) {
551 if (!self.IsDeleted(nidx)) {
552 output.push_back(nidx);
563 [[nodiscard]] std::int32_t
GetDepth(std::int32_t nid)
const {
// Total number of node slots in the tree, including deleted nodes
// (num_nodes - num_deleted_nodes gives the number of valid nodes).
575 [[nodiscard]] std::int32_t
NumNodes() const noexcept {
return num_nodes; }
// Number of nodes that have been deleted (the num_deleted_nodes counter).
580 [[nodiscard]] std::int32_t
NumDeletedNodes() const noexcept {
return num_deleted_nodes; }
586 return num_nodes - num_deleted_nodes;
596 cleft_[nid] = left_child;
605 cright_[nid] = right_child;
614 void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
624 void SetParent(std::int32_t child_node, std::int32_t parent_node) {
625 parent_[child_node] = parent_node;
634 void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
646 std::int32_t nid, std::int32_t split_index,
double threshold);
657 std::vector<std::uint32_t>
const& category_list);
671 void SetLeafVector(std::int32_t nid, std::vector<double>
const& leaf_vector);
737 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
739 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
740 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
741 int column_ind, int32_t offset, int32_t max_leaf);
// Per-node type; read by NodeType(nid) and compared against
// kNumericalSplitNode / kCategoricalSplitNode.
744 std::vector<TreeNodeType> node_type_;
// Per-node parent / left child / right child IDs; kInvalidNodeId when absent
// (a root has no parent, a leaf has no left child).
745 std::vector<std::int32_t> parent_;
746 std::vector<std::int32_t> cleft_;
747 std::vector<std::int32_t> cright_;
// Per-node feature index of the split rule (returned by SplitIndex).
748 std::vector<std::int32_t> split_index_;
// Per-node scalar parameter (used when output_dimension_ == 1).
749 std::vector<double> leaf_value_;
// Per-node numeric split threshold (returned by Threshold).
750 std::vector<double> threshold_;
// Per-node deletion flag (returned by IsDeleted).
751 std::vector<bool> node_deleted_;
// ID lists returned by GetInternalNodes / GetLeaves / GetLeafParents;
// updated when a node is converted to a leaf.
752 std::vector<std::int32_t> internal_nodes_;
753 std::vector<std::int32_t> leaves_;
754 std::vector<std::int32_t> leaf_parents_;
// IDs of deleted nodes (maintained by code not visible here).
755 std::vector<std::int32_t> deleted_nodes_;
// Flattened vector-valued leaf parameters: node nid's values occupy
// leaf_vector_[leaf_vector_begin_[nid], leaf_vector_end_[nid]).
758 std::vector<double> leaf_vector_;
759 std::vector<std::uint64_t> leaf_vector_begin_;
760 std::vector<std::uint64_t> leaf_vector_end_;
// Flattened categorical split lists: node nid's categories occupy
// category_list_[category_list_begin_[nid], category_list_end_[nid]).
763 std::vector<std::uint32_t> category_list_;
764 std::vector<std::uint64_t> category_list_begin_;
765 std::vector<std::uint64_t> category_list_end_;
// Whether any node uses a categorical split (returned by HasCategoricalSplit).
767 bool has_categorical_split_{
false};
// Dimension of each leaf's parameter; > 1 means vector-valued leaves.
768 int output_dimension_{1};
// Whether parameters are exponentiated at prediction time (returned by IsLogScale).
769 bool is_log_scale_{
false};
775 (lhs.has_categorical_split_ == rhs.has_categorical_split_) &&
776 (lhs.output_dimension_ == rhs.output_dimension_) &&
777 (lhs.is_log_scale_ == rhs.is_log_scale_) &&
778 (lhs.node_type_ == rhs.node_type_) &&
779 (lhs.parent_ == rhs.parent_) &&
780 (lhs.cleft_ == rhs.cleft_) &&
781 (lhs.cright_ == rhs.cright_) &&
782 (lhs.split_index_ == rhs.split_index_) &&
783 (lhs.leaf_value_ == rhs.leaf_value_) &&
784 (lhs.threshold_ == rhs.threshold_) &&
785 (lhs.internal_nodes_ == rhs.internal_nodes_) &&
786 (lhs.leaves_ == rhs.leaves_) &&
787 (lhs.leaf_parents_ == rhs.leaf_parents_) &&
788 (lhs.deleted_nodes_ == rhs.deleted_nodes_) &&
789 (lhs.leaf_vector_ == rhs.leaf_vector_) &&
790 (lhs.leaf_vector_begin_ == rhs.leaf_vector_begin_) &&
791 (lhs.leaf_vector_end_ == rhs.leaf_vector_end_) &&
792 (lhs.category_list_ == rhs.category_list_) &&
793 (lhs.category_list_begin_ == rhs.category_list_begin_) &&
794 (lhs.category_list_end_ == rhs.category_list_end_)
805 return (fvalue <= threshold);
815 bool category_matched;
819 auto max_representable_int
820 = std::min(
static_cast<double>(std::numeric_limits<std::uint32_t>::max()),
821 static_cast<double>(std::uint64_t(1) << std::numeric_limits<double>::digits));
822 if (fvalue < 0 || std::fabs(fvalue) > max_representable_int) {
823 category_matched =
false;
825 auto const category_value =
static_cast<std::uint32_t
>(fvalue);
826 category_matched = (std::find(category_list.begin(), category_list.end(), category_value)
827 != category_list.end());
829 return category_matched;
838inline int NextNodeNumeric(
double fvalue,
double threshold,
int left_child,
int right_child) {
848inline int NextNodeCategorical(
double fvalue, std::vector<std::uint32_t>
const& category_list,
int left_child,
int right_child) {
861 while (!tree.
IsLeaf(node_id)) {
862 auto const split_index = tree.
SplitIndex(node_id);
863 double const fvalue = data(row, split_index);
864 if (std::isnan(fvalue)) {
867 if (tree.
NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
885inline int EvaluateTree(
Tree const& tree, Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& data,
int row) {
887 while (!tree.
IsLeaf(node_id)) {
888 auto const split_index = tree.
SplitIndex(node_id);
889 double const fvalue = data(row, split_index);
890 if (std::isnan(fvalue)) {
893 if (tree.
NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
910inline bool RowSplitLeft(Eigen::MatrixXd& covariates,
int row,
int split_index,
double split_value) {
911 double const fvalue = covariates(row, split_index);
921inline bool RowSplitLeft(Eigen::MatrixXd& covariates,
int row,
int split_index, std::vector<std::uint32_t>
const& category_list) {
922 double const fvalue = covariates(row, split_index);
937 split_value_ = split_value;
945 TreeSplit(std::vector<std::uint32_t>& split_categories) {
947 split_categories_ = split_categories;
// Returns the split_set_ flag (default-initialized to false), i.e. whether this
// TreeSplit holds a split rule. Presumably set by the constructors - TODO confirm.
951 bool SplitSet() {
return split_set_;}
968 bool split_set_{
false};
971 std::vector<std::uint32_t> split_categories_;
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:272
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:927
bool SplitTrue(double fvalue)
Whether a given covariate value evaluates to true under the rule defined by a TreeSplit object.
Definition tree.h:959
bool NumericSplit()
Whether or not a TreeSplit rule is numeric.
Definition tree.h:953
std::vector< std::uint32_t > SplitCategories()
Categories defining a TreeSplit object.
Definition tree.h:966
TreeSplit(std::vector< std::uint32_t > &split_categories)
Construct a categorical TreeSplit.
Definition tree.h:945
double SplitValue()
Numeric cutoff value defining a TreeSplit object.
Definition tree.h:964
TreeSplit(double split_value)
Construct a numeric TreeSplit.
Definition tree.h:935
Decision tree data structure.
Definition tree.h:69
void SetNumericSplit(std::int32_t nid, std::int32_t split_index, double threshold)
Create a numerical split.
void SetLeftChild(std::int32_t nid, std::int32_t left_child)
Identify left child node.
Definition tree.h:595
std::int32_t LeftChild(std::int32_t nid) const
Index of the node's left child.
Definition tree.h:270
void PredictLeafIndexInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
void CloneFromTree(Tree *tree)
Copy the structure and parameters of another tree. If the Tree object calling this method already has...
std::int32_t RightChild(std::int32_t nid) const
Index of the node's right child.
Definition tree.h:278
std::int32_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition tree.h:585
void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index, std::vector< std::uint32_t > const &category_list)
Create a categorical split.
std::int32_t SplitIndex(std::int32_t nid) const
Feature index defining the node's split rule.
Definition tree.h:294
void PredictLeafIndexInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsLeaf(std::int32_t nid) const
Whether the node is a leaf node.
Definition tree.h:302
bool IsNumericSplitNode(std::int32_t nid) const
Whether the node is a numeric split node.
Definition tree.h:478
void WalkTree(Func func) const
Iterate through all nodes in this tree.
Definition tree.h:210
double SumSquaredLeafValues() const
Sum of squared values for all leaves in a tree.
Definition tree.h:422
std::int32_t DefaultChild(std::int32_t nid) const
Index of the node's "default" child (potentially used in the case of a missing feature at prediction ...
Definition tree.h:286
int AllocNode()
Allocate a new node and return the node's ID.
void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a numeric split rule.
std::vector< double > LeafVector(std::int32_t nid) const
Get vector-valued parameters of a node (typically a leaf node).
Definition tree.h:386
std::int32_t Parent(std::int32_t nid) const
Index of the node's parent.
Definition tree.h:262
void PredictLeafIndexInplace(ForestDataset *dataset, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsCategoricalSplitNode(std::int32_t nid) const
Whether the node is a categorical split node.
Definition tree.h:486
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition tree.h:493
double LeafValue(std::int32_t nid, std::int32_t dim_id) const
Get parameter value of a node (typically, though not necessarily, a leaf node) at a given output dimension.
Definition tree.h:335
void from_json(const json &tree_json)
Load from JSON.
std::int32_t NumNodes() const noexcept
Get the total number of nodes in this tree, including deleted ones.
Definition tree.h:575
TreeNodeType NodeType(std::int32_t nid) const
Get the type of a node (i.e. numeric split, categorical split, leaf)
Definition tree.h:470
double SumSquaredNodeValues(std::int32_t nid) const
Sum of squared parameter values for a given node (typically, though not necessarily, a leaf node).
Definition tree.h:402
void DeleteNode(std::int32_t nid)
Deletes node indexed by node ID.
std::int32_t MaxLeafDepth() const
Get the maximum depth across all leaf nodes.
Definition tree.h:352
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, double left_value, double right_value)
Expand a node based on a generic split rule.
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a generic split rule.
std::int32_t OutputDimension() const
Dimension of tree output.
Definition tree.h:247
bool HasVectorOutput() const
Whether or not a tree has vector output.
Definition tree.h:240
std::int32_t NumDeletedNodes() const noexcept
Get the total number of deleted nodes in this tree.
Definition tree.h:580
void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify the two child nodes of a node, and record that node as the parent of both children.
Definition tree.h:614
std::vector< std::uint32_t > CategoryList(std::int32_t nid) const
Get list of all categories belonging to the left child node. Categories are integers ranging from 0 t...
Definition tree.h:453
void SetParent(std::int32_t child_node, std::int32_t parent_node)
Identify parent node.
Definition tree.h:624
void SetLeaf(std::int32_t nid, double value)
Set the leaf value of the node.
bool IsRoot()
Whether or not a tree is a "stump" consisting of a single root node.
Definition tree.h:114
double LeafValue(std::int32_t nid) const
Get parameter value of a node (typically, though not necessarily, a leaf node).
Definition tree.h:326
void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify the given node as the parent of the left and right child node IDs.
Definition tree.h:634
std::vector< std::int32_t > const & GetInternalNodes() const
Get indices of all internal nodes.
Definition tree.h:526
std::vector< std::int32_t > const & GetLeafParents() const
Get indices of all leaf parent nodes.
Definition tree.h:540
std::vector< std::int32_t > GetNodes()
Get indices of all valid (non-deleted) nodes.
Definition tree.h:547
bool HasLeafVector(std::int32_t nid) const
Tests whether the leaf node has a non-empty leaf vector.
Definition tree.h:434
void Init(int output_dimension=1, bool is_log_scale=false)
Initialize the tree with a single root node.
void CollapseToLeaf(std::int32_t nid, std::vector< double > value_vector)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:191
double Threshold(std::int32_t nid) const
Get split threshold of the node.
Definition tree.h:442
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, double left_value, double right_value)
Expand a node based on a categorical split rule.
bool IsRoot(std::int32_t nid) const
Whether the node is root.
Definition tree.h:310
void Reset()
Reset tree to empty vectors and default values of boolean / integer variables.
bool IsDeleted(std::int32_t nid) const
Whether the node has been deleted.
Definition tree.h:318
void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value)
Expand a node based on a numeric split rule.
void SetRightChild(std::int32_t nid, std::int32_t right_child)
Identify right child node.
Definition tree.h:604
std::vector< std::int32_t > const & GetLeaves() const
Get indices of all leaf nodes.
Definition tree.h:533
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a categorical split rule.
bool IsLogScale() const
Whether or not tree parameters should be exponentiated at prediction time.
Definition tree.h:254
void CollapseToLeaf(std::int32_t nid, double value)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:152
void SetLeafVector(std::int32_t nid, std::vector< double > const &leaf_vector)
Set the leaf vector of the node; useful for multi-output trees.
std::int32_t GetDepth(std::int32_t nid) const
Get the depth of a node.
Definition tree.h:563
json to_json()
Convert tree to JSON and return JSON in-memory.
bool SplitTrueNumeric(double fvalue, double threshold)
Determine whether an observation produces a "true" value in a numeric split node.
Definition tree.h:804
int NextNodeCategorical(double fvalue, std::vector< std::uint32_t > const &category_list, int left_child, int right_child)
Return left or right node id based on a categorical split.
Definition tree.h:848
bool SplitTrueCategorical(double fvalue, std::vector< std::uint32_t > const &category_list)
Determine whether an observation produces a "true" value in a categorical split node.
Definition tree.h:814
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:859
int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child)
Return left or right node id based on a numeric split.
Definition tree.h:838
bool operator==(const Tree &lhs, const Tree &rhs)
Comparison operator for trees.
Definition tree.h:773
bool RowSplitLeft(Eigen::MatrixXd &covariates, int row, int split_index, double split_value)
Determine whether a given observation is "true" at a split proposed by split_index and split_value.
Definition tree.h:910
Definition category_tracker.h:40
std::string TreeNodeTypeToString(TreeNodeType type)
Get string representation of TreeNodeType.
TreeNodeType
Tree node type.
Definition tree.h:27
TreeNodeType TreeNodeTypeFromString(std::string const &name)
Get TreeNodeType from string.