6#ifndef STOCHTREE_TREE_H_
7#define STOCHTREE_TREE_H_
9#include <nlohmann/json.hpp>
10#include <stochtree/data.h>
11#include <stochtree/log.h>
12#include <stochtree/meta.h>
19using json = nlohmann::json;
26 kNumericalSplitNode = 1,
27 kCategoricalSplitNode = 2
46enum FeatureSplitType {
48 kOrderedCategoricalSplit,
49 kUnorderedCategoricalSplit
68 static constexpr std::int32_t kInvalidNodeId{-1};
69 static constexpr std::int32_t kDeletedNodeMarker = std::numeric_limits<node_t>::max();
70 static constexpr std::int32_t kRoot{0};
75 Tree& operator=(
Tree const&) =
delete;
77 Tree& operator=(
Tree&&)
noexcept =
default;
86 std::int32_t num_nodes{0};
87 std::int32_t num_deleted_nodes{0};
92 void Init(
int output_dimension = 1,
bool is_log_scale =
false);
98 void ExpandNode(std::int32_t nid,
int split_index,
double split_value,
double left_value,
double right_value);
100 void ExpandNode(std::int32_t nid,
int split_index, std::vector<std::uint32_t>
const& categorical_indices,
double left_value,
double right_value);
102 void ExpandNode(std::int32_t nid,
int split_index,
double split_value, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
104 void ExpandNode(std::int32_t nid,
int split_index, std::vector<std::uint32_t>
const& categorical_indices, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
106 void ExpandNode(std::int32_t nid,
int split_index,
TreeSplit& split,
double left_value,
double right_value);
108 void ExpandNode(std::int32_t nid,
int split_index,
TreeSplit& split, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
111 inline bool IsRoot() {
return leaves_.size() == 1;}
122 void ChangeToLeaf(std::int32_t nid,
double value) {
130 leaves_.push_back(nid);
131 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
132 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
137 int parent_id =
Parent(nid);
139 leaf_parents_.push_back(parent_id);
150 CHECK_EQ(output_dimension_, 1);
151 if (this->
IsLeaf(nid))
return;
158 this->ChangeToLeaf(nid, value);
161 void ChangeToLeaf(std::int32_t nid, std::vector<double> value_vector) {
169 leaves_.push_back(nid);
170 leaf_parents_.erase(std::remove(leaf_parents_.begin(), leaf_parents_.end(), nid), leaf_parents_.end());
171 internal_nodes_.erase(std::remove(internal_nodes_.begin(), internal_nodes_.end(), nid), internal_nodes_.end());
176 int parent_id =
Parent(nid);
178 leaf_parents_.push_back(parent_id);
189 CHECK_GT(output_dimension_, 1);
190 CHECK_EQ(output_dimension_, value_vector.size());
191 if (this->
IsLeaf(nid))
return;
198 this->ChangeToLeaf(nid, value_vector);
207 if (output_dimension_ == 1) {
208 for (
int j = 0; j < leaf_value_.size(); j++) {
209 leaf_value_[j] += constant_value;
212 for (
int j = 0; j < leaf_vector_.size(); j++) {
213 leaf_vector_[j] += constant_value;
224 if (output_dimension_ == 1) {
225 for (
int j = 0; j < leaf_value_.size(); j++) {
226 leaf_value_[j] *= constant_multiple;
229 for (
int j = 0; j < leaf_vector_.size(); j++) {
230 leaf_vector_[j] *= constant_multiple;
241 template <
typename Func>
void WalkTree(Func func)
const {
242 std::stack<std::int32_t> nodes;
245 while (!nodes.empty()) {
246 auto nidx = nodes.top();
251 auto left = self.LeftChild(nidx);
252 auto right = self.RightChild(nidx);
253 if (left != Tree::kInvalidNodeId) {
256 if (right != Tree::kInvalidNodeId) {
262 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices);
263 std::vector<double> PredictFromNodes(std::vector<std::int32_t> node_indices, Eigen::MatrixXd& basis);
264 double PredictFromNode(std::int32_t node_id);
265 double PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis,
int row_idx);
272 return output_dimension_ > 1;
279 return output_dimension_;
286 return is_log_scale_;
293 std::int32_t
Parent(std::int32_t nid)
const {
326 return split_index_[nid];
334 return cleft_[nid] == kInvalidNodeId;
342 return parent_[nid] == kInvalidNodeId;
350 return node_deleted_[nid];
358 return leaf_value_[nid];
366 double LeafValue(std::int32_t nid, std::int32_t dim_id)
const {
367 CHECK_LT(dim_id, output_dimension_);
368 if (output_dimension_ == 1 && dim_id == 0) {
369 return leaf_value_[nid];
371 std::size_t
const offset_begin = leaf_vector_begin_[nid];
372 std::size_t
const offset_end = leaf_vector_end_[nid];
373 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
374 Log::Fatal(
"No leaf vector set for node nid");
376 return leaf_vector_[offset_begin + dim_id];
384 std::int32_t max_depth = 0;
385 std::stack<std::int32_t> nodes;
386 std::stack<std::int32_t> node_depths;
390 while (!nodes.empty()) {
391 auto nidx = nodes.top();
393 auto node_depth = node_depths.top();
395 bool valid_node = !self.IsDeleted(nidx);
397 if (node_depth > max_depth) max_depth = node_depth;
398 auto left = self.LeftChild(nidx);
399 auto right = self.RightChild(nidx);
400 if (left != Tree::kInvalidNodeId) {
402 node_depths.push(node_depth+1);
404 if (right != Tree::kInvalidNodeId) {
406 node_depths.push(node_depth+1);
418 std::size_t
const offset_begin = leaf_vector_begin_[nid];
419 std::size_t
const offset_end = leaf_vector_end_[nid];
420 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
422 return std::vector<double>();
424 return std::vector<double>(&leaf_vector_[offset_begin], &leaf_vector_[offset_end]);
434 if (output_dimension_ == 1) {
435 return std::pow(leaf_value_[nid], 2.0);
438 std::size_t
const offset_begin = leaf_vector_begin_[nid];
439 std::size_t
const offset_end = leaf_vector_end_[nid];
440 if (offset_begin >= leaf_vector_.size() || offset_end > leaf_vector_.size()) {
441 Log::Fatal(
"No leaf vector set for node nid");
443 for (std::size_t i = offset_begin; i < offset_end; i++) {
444 result += std::pow(leaf_vector_[i], 2.0);
455 for (
auto& leaf : leaves_) {
466 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
474 return threshold_[nid];
485 std::size_t
const offset_begin = category_list_begin_[nid];
486 std::size_t
const offset_end = category_list_end_[nid];
487 if (offset_begin >= category_list_.size() || offset_end > category_list_.size()) {
492 return std::vector<std::uint32_t>(&category_list_[offset_begin], &category_list_[offset_end]);
502 return node_type_[nid];
510 return node_type_[nid] == TreeNodeType::kNumericalSplitNode;
518 return node_type_[nid] == TreeNodeType::kCategoricalSplitNode;
525 return has_categorical_split_;
529 [[nodiscard]] std::int32_t NumLeaves()
const;
530 [[nodiscard]] std::int32_t NumLeafParents()
const;
531 [[nodiscard]] std::int32_t NumSplitNodes()
const;
534 [[nodiscard]]
bool IsLeafParent(std::int32_t nid)
const {
537 bool is_left_leaf =
false;
538 bool is_right_leaf =
false;
540 bool is_leaf = this->
IsLeaf(nid);
548 is_left_leaf =
IsLeaf(left_node);
549 is_right_leaf =
IsLeaf(right_node);
551 return is_left_leaf && is_right_leaf;
558 return internal_nodes_;
564 [[nodiscard]] std::vector<std::int32_t>
const&
GetLeaves()
const {
572 return leaf_parents_;
578 [[nodiscard]] std::vector<std::int32_t>
GetNodes() {
579 std::vector<std::int32_t> output;
580 auto const& self = *
this;
581 this->
WalkTree([&output, &self](std::int32_t nidx) {
582 if (!self.IsDeleted(nidx)) {
583 output.push_back(nidx);
594 [[nodiscard]] std::int32_t
GetDepth(std::int32_t nid)
const {
606 [[nodiscard]] std::int32_t
NumNodes() const noexcept {
return num_nodes; }
611 [[nodiscard]] std::int32_t
NumDeletedNodes() const noexcept {
return num_deleted_nodes; }
617 return num_nodes - num_deleted_nodes;
627 cleft_[nid] = left_child;
636 cright_[nid] = right_child;
645 void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
655 void SetParent(std::int32_t child_node, std::int32_t parent_node) {
656 parent_[child_node] = parent_node;
665 void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child) {
677 std::int32_t nid, std::int32_t split_index,
double threshold);
688 std::vector<std::uint32_t>
const& category_list);
702 void SetLeafVector(std::int32_t nid, std::vector<double>
const& leaf_vector);
768 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int32_t offset, int32_t max_leaf);
770 void PredictLeafIndexInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
771 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
772 int column_ind, int32_t offset, int32_t max_leaf);
775 std::vector<TreeNodeType> node_type_;
776 std::vector<std::int32_t> parent_;
777 std::vector<std::int32_t> cleft_;
778 std::vector<std::int32_t> cright_;
779 std::vector<std::int32_t> split_index_;
780 std::vector<double> leaf_value_;
781 std::vector<double> threshold_;
782 std::vector<bool> node_deleted_;
783 std::vector<std::int32_t> internal_nodes_;
784 std::vector<std::int32_t> leaves_;
785 std::vector<std::int32_t> leaf_parents_;
786 std::vector<std::int32_t> deleted_nodes_;
789 std::vector<double> leaf_vector_;
790 std::vector<std::uint64_t> leaf_vector_begin_;
791 std::vector<std::uint64_t> leaf_vector_end_;
794 std::vector<std::uint32_t> category_list_;
795 std::vector<std::uint64_t> category_list_begin_;
796 std::vector<std::uint64_t> category_list_end_;
798 bool has_categorical_split_{
false};
799 int output_dimension_{1};
800 bool is_log_scale_{
false};
806 (lhs.has_categorical_split_ == rhs.has_categorical_split_) &&
807 (lhs.output_dimension_ == rhs.output_dimension_) &&
808 (lhs.is_log_scale_ == rhs.is_log_scale_) &&
809 (lhs.node_type_ == rhs.node_type_) &&
810 (lhs.parent_ == rhs.parent_) &&
811 (lhs.cleft_ == rhs.cleft_) &&
812 (lhs.cright_ == rhs.cright_) &&
813 (lhs.split_index_ == rhs.split_index_) &&
814 (lhs.leaf_value_ == rhs.leaf_value_) &&
815 (lhs.threshold_ == rhs.threshold_) &&
816 (lhs.internal_nodes_ == rhs.internal_nodes_) &&
817 (lhs.leaves_ == rhs.leaves_) &&
818 (lhs.leaf_parents_ == rhs.leaf_parents_) &&
819 (lhs.deleted_nodes_ == rhs.deleted_nodes_) &&
820 (lhs.leaf_vector_ == rhs.leaf_vector_) &&
821 (lhs.leaf_vector_begin_ == rhs.leaf_vector_begin_) &&
822 (lhs.leaf_vector_end_ == rhs.leaf_vector_end_) &&
823 (lhs.category_list_ == rhs.category_list_) &&
824 (lhs.category_list_begin_ == rhs.category_list_begin_) &&
825 (lhs.category_list_end_ == rhs.category_list_end_)
836 return (fvalue <= threshold);
846 bool category_matched;
850 auto max_representable_int
851 = std::min(
static_cast<double>(std::numeric_limits<std::uint32_t>::max()),
852 static_cast<double>(std::uint64_t(1) << std::numeric_limits<double>::digits));
853 if (fvalue < 0 || std::fabs(fvalue) > max_representable_int) {
854 category_matched =
false;
856 auto const category_value =
static_cast<std::uint32_t
>(fvalue);
857 category_matched = (std::find(category_list.begin(), category_list.end(), category_value)
858 != category_list.end());
860 return category_matched;
869inline int NextNodeNumeric(
double fvalue,
double threshold,
int left_child,
int right_child) {
879inline int NextNodeCategorical(
double fvalue, std::vector<std::uint32_t>
const& category_list,
int left_child,
int right_child) {
892 while (!tree.
IsLeaf(node_id)) {
893 auto const split_index = tree.
SplitIndex(node_id);
894 double const fvalue = data(row, split_index);
895 if (std::isnan(fvalue)) {
898 if (tree.
NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
916inline int EvaluateTree(
Tree const& tree, Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& data,
int row) {
918 while (!tree.
IsLeaf(node_id)) {
919 auto const split_index = tree.
SplitIndex(node_id);
920 double const fvalue = data(row, split_index);
921 if (std::isnan(fvalue)) {
924 if (tree.
NodeType(node_id) == StochTree::TreeNodeType::kCategoricalSplitNode) {
941inline bool RowSplitLeft(Eigen::MatrixXd& covariates,
int row,
int split_index,
double split_value) {
942 double const fvalue = covariates(row, split_index);
952inline bool RowSplitLeft(Eigen::MatrixXd& covariates,
int row,
int split_index, std::vector<std::uint32_t>
const& category_list) {
953 double const fvalue = covariates(row, split_index);
968 split_value_ = split_value;
976 TreeSplit(std::vector<std::uint32_t>& split_categories) {
978 split_categories_ = split_categories;
982 bool SplitSet() {
return split_set_;}
999 bool split_set_{
false};
1001 double split_value_;
1002 std::vector<std::uint32_t> split_categories_;
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:271
Representation of arbitrary tree split rules, including numeric split rules (X[,i] <= c) and categori...
Definition tree.h:958
bool SplitTrue(double fvalue)
Whether a given covariate value is True or False on the rule defined by a TreeSplit object.
Definition tree.h:990
bool NumericSplit()
Whether or not a TreeSplit rule is numeric.
Definition tree.h:984
std::vector< std::uint32_t > SplitCategories()
Categories defining a TreeSplit object.
Definition tree.h:997
TreeSplit(std::vector< std::uint32_t > &split_categories)
Construct a categorical TreeSplit.
Definition tree.h:976
double SplitValue()
Numeric cutoff value defining a TreeSplit object.
Definition tree.h:995
TreeSplit(double split_value)
Construct a numeric TreeSplit.
Definition tree.h:966
Decision tree data structure.
Definition tree.h:66
void SetNumericSplit(std::int32_t nid, std::int32_t split_index, double threshold)
Create a numerical split.
void SetLeftChild(std::int32_t nid, std::int32_t left_child)
Identify left child node.
Definition tree.h:626
std::int32_t LeftChild(std::int32_t nid) const
Index of the node's left child.
Definition tree.h:301
void PredictLeafIndexInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
void CloneFromTree(Tree *tree)
Copy the structure and parameters of another tree. If the Tree object calling this method already has...
std::int32_t RightChild(std::int32_t nid) const
Index of the node's right child.
Definition tree.h:309
std::int32_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition tree.h:616
void SetCategoricalSplit(std::int32_t nid, std::int32_t split_index, std::vector< std::uint32_t > const &category_list)
Create a categorical split.
std::int32_t SplitIndex(std::int32_t nid) const
Feature index defining the node's split rule.
Definition tree.h:325
void PredictLeafIndexInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsLeaf(std::int32_t nid) const
Whether the node is a leaf node.
Definition tree.h:333
bool IsNumericSplitNode(std::int32_t nid) const
Whether the node is a numeric split node.
Definition tree.h:509
void WalkTree(Func func) const
Iterate through all nodes in this tree.
Definition tree.h:241
double SumSquaredLeafValues() const
Sum of squared values for all leaves in a tree.
Definition tree.h:453
std::int32_t DefaultChild(std::int32_t nid) const
Index of the node's "default" child (potentially used in the case of a missing feature at prediction ...
Definition tree.h:317
void AddValueToLeaves(double constant_value)
Add a constant value to every leaf of a tree. If leaves are multi-dimensional, constant_value will be...
Definition tree.h:206
int AllocNode()
Allocate a new node and return the node's ID.
void ExpandNode(std::int32_t nid, int split_index, double split_value, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a numeric split rule.
std::vector< double > LeafVector(std::int32_t nid) const
Get vector-valued parameters of a node (typically leaf)
Definition tree.h:417
std::int32_t Parent(std::int32_t nid) const
Index of the node's parent.
Definition tree.h:293
void PredictLeafIndexInplace(ForestDataset *dataset, std::vector< int32_t > &output, int32_t offset, int32_t max_leaf)
Obtain a 0-based leaf index for each observation in a ForestDataset. Internally, trees are stored as ...
bool IsCategoricalSplitNode(std::int32_t nid) const
Whether the node is a numeric split node.
Definition tree.h:517
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition tree.h:524
double LeafValue(std::int32_t nid, std::int32_t dim_id) const
Get parameter value of a node (typically though not necessarily a leaf node) at a given output dimens...
Definition tree.h:366
void from_json(const json &tree_json)
Load from JSON.
std::int32_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition tree.h:606
TreeNodeType NodeType(std::int32_t nid) const
Get the type of a node (i.e. numeric split, categorical split, leaf)
Definition tree.h:501
double SumSquaredNodeValues(std::int32_t nid) const
Sum of squared parameter values for a given node (typically though not necessarily a leaf node)
Definition tree.h:433
void DeleteNode(std::int32_t nid)
Deletes node indexed by node ID.
std::int32_t MaxLeafDepth() const
Get maximum depth of all of the leaf nodes.
Definition tree.h:383
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, double left_value, double right_value)
Expand a node based on a generic split rule.
void ExpandNode(std::int32_t nid, int split_index, TreeSplit &split, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a generic split rule.
std::int32_t OutputDimension() const
Dimension of tree output.
Definition tree.h:278
bool HasVectorOutput() const
Whether or not a tree has vector output.
Definition tree.h:271
void MultiplyLeavesByValue(double constant_multiple)
Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional,...
Definition tree.h:223
std::int32_t NumDeletedNodes() const noexcept
Get the total number of deleted nodes in this tree.
Definition tree.h:611
void SetChildren(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify two child nodes of the node and the corresponding parent node of the child nodes.
Definition tree.h:645
std::vector< std::uint32_t > CategoryList(std::int32_t nid) const
Get list of all categories belonging to the left child node. Categories are integers ranging from 0 t...
Definition tree.h:484
void SetParent(std::int32_t child_node, std::int32_t parent_node)
Identify parent node.
Definition tree.h:655
void SetLeaf(std::int32_t nid, double value)
Set the leaf value of the node.
bool IsRoot()
Whether or not a tree is a "stump" consisting of a single root node.
Definition tree.h:111
double LeafValue(std::int32_t nid) const
Get parameter value of a node (typically though not necessarily a leaf node)
Definition tree.h:357
void SetParents(std::int32_t nid, std::int32_t left_child, std::int32_t right_child)
Identify parent node of the left and right node ids.
Definition tree.h:665
std::vector< std::int32_t > const & GetInternalNodes() const
Get indices of all internal nodes.
Definition tree.h:557
std::vector< std::int32_t > const & GetLeafParents() const
Get indices of all leaf parent nodes.
Definition tree.h:571
std::vector< std::int32_t > GetNodes()
Get indices of all valid (non-deleted) nodes.
Definition tree.h:578
bool HasLeafVector(std::int32_t nid) const
Tests whether the leaf node has a non-empty leaf vector.
Definition tree.h:465
void Init(int output_dimension=1, bool is_log_scale=false)
Initialize the tree with a single root node.
void CollapseToLeaf(std::int32_t nid, std::vector< double > value_vector)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:188
double Threshold(std::int32_t nid) const
Get split threshold of the node.
Definition tree.h:473
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, double left_value, double right_value)
Expand a node based on a categorical split rule.
bool IsRoot(std::int32_t nid) const
Whether the node is root.
Definition tree.h:341
void Reset()
Reset tree to empty vectors and default values of boolean / integer variables.
bool IsDeleted(std::int32_t nid) const
Whether the node has been deleted.
Definition tree.h:349
void ExpandNode(std::int32_t nid, int split_index, double split_value, double left_value, double right_value)
Expand a node based on a numeric split rule.
void SetRightChild(std::int32_t nid, std::int32_t right_child)
Identify right child node.
Definition tree.h:635
std::vector< std::int32_t > const & GetLeaves() const
Get indices of all leaf nodes.
Definition tree.h:564
void ExpandNode(std::int32_t nid, int split_index, std::vector< std::uint32_t > const &categorical_indices, std::vector< double > left_value_vector, std::vector< double > right_value_vector)
Expand a node based on a categorical split rule.
bool IsLogScale() const
Whether or not tree parameters should be exponentiated at prediction time.
Definition tree.h:285
void CollapseToLeaf(std::int32_t nid, double value)
Collapse an internal node to a leaf node, deleting its children from the tree.
Definition tree.h:149
void SetLeafVector(std::int32_t nid, std::vector< double > const &leaf_vector)
Set the leaf vector of the node; useful for multi-output trees.
std::int32_t GetDepth(std::int32_t nid) const
Get the depth of a node.
Definition tree.h:594
json to_json()
Convert tree to JSON and return JSON in-memory.
bool SplitTrueNumeric(double fvalue, double threshold)
Determine whether an observation produces a "true" value in a numeric split node.
Definition tree.h:835
int NextNodeCategorical(double fvalue, std::vector< std::uint32_t > const &category_list, int left_child, int right_child)
Return left or right node id based on a categorical split.
Definition tree.h:879
bool SplitTrueCategorical(double fvalue, std::vector< std::uint32_t > const &category_list)
Determine whether an observation produces a "true" value in a categorical split node.
Definition tree.h:845
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:890
int NextNodeNumeric(double fvalue, double threshold, int left_child, int right_child)
Return left or right node id based on a numeric split.
Definition tree.h:869
bool operator==(const Tree &lhs, const Tree &rhs)
Comparison operator for trees.
Definition tree.h:804
bool RowSplitLeft(Eigen::MatrixXd &covariates, int row, int split_index, double split_value)
Determine whether a given observation is "true" at a split proposed by split_index and split_value.
Definition tree.h:941
Definition category_tracker.h:36
std::string TreeNodeTypeToString(TreeNodeType type)
Get string representation of TreeNodeType.
TreeNodeType
Tree node type.
Definition tree.h:24
TreeNodeType TreeNodeTypeFromString(std::string const &name)
Get NodeType from string.