StochTree 0.1.1
Loading...
Searching...
No Matches
ensemble.h
1
10#ifndef STOCHTREE_ENSEMBLE_H_
11#define STOCHTREE_ENSEMBLE_H_
12
13#include <stochtree/data.h>
14#include <stochtree/tree.h>
15#include <nlohmann/json.hpp>
16
17using json = nlohmann::json;
18
19namespace StochTree {
20
32 public:
41 TreeEnsemble(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) {
42 // Initialize trees in the ensemble
43 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees);
44 for (int i = 0; i < num_trees; i++) {
45 trees_[i].reset(new Tree());
46 trees_[i]->Init(output_dimension, is_exponentiated);
47 }
48 // Store ensemble configurations
49 num_trees_ = num_trees;
50 output_dimension_ = output_dimension;
51 is_leaf_constant_ = is_leaf_constant;
52 is_exponentiated_ = is_exponentiated;
53 }
54
61 // Unpack ensemble configurations
62 num_trees_ = ensemble.num_trees_;
63 output_dimension_ = ensemble.output_dimension_;
64 is_leaf_constant_ = ensemble.is_leaf_constant_;
65 is_exponentiated_ = ensemble.is_exponentiated_;
66 // Initialize trees in the ensemble
67 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
68 for (int i = 0; i < num_trees_; i++) {
69 trees_[i].reset(new Tree());
70 }
71 // Clone trees in the ensemble
72 for (int j = 0; j < num_trees_; j++) {
73 Tree* tree = ensemble.GetTree(j);
74 this->CloneFromExistingTree(j, tree);
75 }
76 }
77
78 ~TreeEnsemble() {}
79
85 void MergeForest(TreeEnsemble& ensemble) {
86 // Unpack ensemble configurations
87 int old_num_trees = num_trees_;
88 num_trees_ += ensemble.num_trees_;
89 CHECK_EQ(output_dimension_, ensemble.output_dimension_);
90 CHECK_EQ(is_leaf_constant_, ensemble.is_leaf_constant_);
91 CHECK_EQ(is_exponentiated_, ensemble.is_exponentiated_);
92 // Resize tree vector and reset new trees
93 trees_.resize(num_trees_);
94 for (int i = old_num_trees; i < num_trees_; i++) {
95 trees_[i].reset(new Tree());
96 }
97 // Clone trees in the input ensemble
98 for (int j = 0; j < ensemble.num_trees_; j++) {
99 Tree* tree = ensemble.GetTree(j);
100 this->CloneFromExistingTree(old_num_trees + j, tree);
101 }
102 }
103
109 void AddValueToLeaves(double constant_value) {
110 for (int j = 0; j < num_trees_; j++) {
111 Tree* tree = GetTree(j);
112 tree->AddValueToLeaves(constant_value);
113 }
114 }
115
121 void MultiplyLeavesByValue(double constant_multiple) {
122 for (int j = 0; j < num_trees_; j++) {
123 Tree* tree = GetTree(j);
124 tree->MultiplyLeavesByValue(constant_multiple);
125 }
126 }
127
134 inline Tree* GetTree(int i) {
135 return trees_[i].get();
136 }
137
141 inline void ResetRoot() {
142 for (int i = 0; i < num_trees_; i++) {
143 ResetInitTree(i);
144 }
145 }
146
153 inline void ResetTree(int i) {
154 trees_[i].reset(new Tree());
155 }
156
163 inline void ResetInitTree(int i) {
164 trees_[i].reset(new Tree());
165 trees_[i]->Init(output_dimension_, is_exponentiated_);
166 }
167
174 inline void CloneFromExistingTree(int i, Tree* tree) {
175 return trees_[i]->CloneFromTree(tree);
176 }
177
183 inline void ReconstituteFromForest(TreeEnsemble& ensemble) {
184 // Delete old tree pointers
185 trees_.clear();
186 // Unpack ensemble configurations
187 num_trees_ = ensemble.num_trees_;
188 output_dimension_ = ensemble.output_dimension_;
189 is_leaf_constant_ = ensemble.is_leaf_constant_;
190 is_exponentiated_ = ensemble.is_exponentiated_;
191 // Initialize trees in the ensemble
192 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
193 for (int i = 0; i < num_trees_; i++) {
194 trees_[i].reset(new Tree());
195 }
196 // Clone trees in the ensemble
197 for (int j = 0; j < num_trees_; j++) {
198 Tree* tree = ensemble.GetTree(j);
199 this->CloneFromExistingTree(j, tree);
200 }
201 }
202
203 std::vector<double> Predict(ForestDataset& dataset) {
204 data_size_t n = dataset.NumObservations();
205 std::vector<double> output(n);
206 PredictInplace(dataset, output, 0);
207 return output;
208 }
209
210 std::vector<double> PredictRaw(ForestDataset& dataset) {
211 data_size_t n = dataset.NumObservations();
212 data_size_t total_output_size = n * output_dimension_;
213 std::vector<double> output(total_output_size);
214 PredictRawInplace(dataset, output, 0);
215 return output;
216 }
217
218 inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output, data_size_t offset = 0) {
219 PredictInplace(dataset, output, 0, trees_.size(), offset);
220 }
221
222 inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output,
223 int tree_begin, int tree_end, data_size_t offset = 0) {
224 if (is_leaf_constant_) {
225 PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset);
226 } else {
227 CHECK(dataset.HasBasis());
228 PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
229 }
230 }
231
232 inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector<double> &output, data_size_t offset = 0) {
233 PredictInplace(covariates, basis, output, 0, trees_.size(), offset);
234 }
235
236 inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector<double> &output,
237 int tree_begin, int tree_end, data_size_t offset = 0) {
238 double pred;
239 CHECK_EQ(covariates.rows(), basis.rows());
240 CHECK_EQ(output_dimension_, trees_[0]->OutputDimension());
241 CHECK_EQ(output_dimension_, basis.cols());
242 data_size_t n = covariates.rows();
243 data_size_t total_output_size = n;
244 if (output.size() < total_output_size + offset) {
245 Log::Fatal("Mismatched size of prediction vector and training data");
246 }
247 for (data_size_t i = 0; i < n; i++) {
248 pred = 0.0;
249 for (size_t j = tree_begin; j < tree_end; j++) {
250 auto &tree = *trees_[j];
251 std::int32_t nidx = EvaluateTree(tree, covariates, i);
252 for (int32_t k = 0; k < output_dimension_; k++) {
253 pred += tree.LeafValue(nidx, k) * basis(i, k);
254 }
255 }
256 if (is_exponentiated_) output[i + offset] = std::exp(pred);
257 else output[i + offset] = pred;
258 }
259 }
260
261 inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector<double> &output, data_size_t offset = 0) {
262 PredictInplace(covariates, output, 0, trees_.size(), offset);
263 }
264
265 inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector<double> &output, int tree_begin, int tree_end, data_size_t offset = 0) {
266 double pred;
267 data_size_t n = covariates.rows();
268 data_size_t total_output_size = n;
269 if (output.size() < total_output_size + offset) {
270 Log::Fatal("Mismatched size of prediction vector and training data");
271 }
272 for (data_size_t i = 0; i < n; i++) {
273 pred = 0.0;
274 for (size_t j = tree_begin; j < tree_end; j++) {
275 auto &tree = *trees_[j];
276 std::int32_t nidx = EvaluateTree(tree, covariates, i);
277 pred += tree.LeafValue(nidx, 0);
278 }
279 if (is_exponentiated_) output[i + offset] = std::exp(pred);
280 else output[i + offset] = pred;
281 }
282 }
283
284 inline void PredictRawInplace(ForestDataset& dataset, std::vector<double> &output, data_size_t offset = 0) {
285 PredictRawInplace(dataset, output, 0, trees_.size(), offset);
286 }
287
288 inline void PredictRawInplace(ForestDataset& dataset, std::vector<double> &output,
289 int tree_begin, int tree_end, data_size_t offset = 0) {
290 double pred;
291 Eigen::MatrixXd covariates = dataset.GetCovariates();
292 CHECK_EQ(output_dimension_, trees_[0]->OutputDimension());
293 data_size_t n = covariates.rows();
294 data_size_t total_output_size = n * output_dimension_;
295 if (output.size() < total_output_size + offset) {
296 Log::Fatal("Mismatched size of raw prediction vector and training data");
297 }
298 for (data_size_t i = 0; i < n; i++) {
299 for (int32_t k = 0; k < output_dimension_; k++) {
300 pred = 0.0;
301 for (size_t j = tree_begin; j < tree_end; j++) {
302 auto &tree = *trees_[j];
303 int32_t nidx = EvaluateTree(tree, covariates, i);
304 pred += tree.LeafValue(nidx, k);
305 }
306 output[i*output_dimension_ + k + offset] = pred;
307 }
308 }
309 }
310
311 inline int32_t NumTrees() {
312 return num_trees_;
313 }
314
315 inline int32_t NumLeaves() {
316 int32_t result = 0;
317 for (int i = 0; i < num_trees_; i++) {
318 result += trees_[i]->NumLeaves();
319 }
320 return result;
321 }
322
323 inline double SumLeafSquared() {
324 double result = 0.;
325 for (int i = 0; i < num_trees_; i++) {
326 result += trees_[i]->SumSquaredLeafValues();
327 }
328 return result;
329 }
330
331 inline int32_t OutputDimension() {
332 return output_dimension_;
333 }
334
335 inline bool IsLeafConstant() {
336 return is_leaf_constant_;
337 }
338
339 inline bool IsExponentiated() {
340 return is_exponentiated_;
341 }
342
343 inline int32_t TreeMaxDepth(int tree_num) {
344 return trees_[tree_num]->MaxLeafDepth();
345 }
346
347 inline double AverageMaxDepth() {
348 double numerator = 0.;
349 double denominator = 0.;
350 for (int i = 0; i < num_trees_; i++) {
351 numerator += static_cast<double>(TreeMaxDepth(i));
352 denominator += 1.;
353 }
354 return numerator / denominator;
355 }
356
357 inline bool AllRoots() {
358 for (int i = 0; i < num_trees_; i++) {
359 if (!trees_[i]->IsRoot()) {
360 return false;
361 }
362 }
363 return true;
364 }
365
366 inline void SetLeafValue(double leaf_value) {
367 CHECK_EQ(output_dimension_, 1);
368 for (int i = 0; i < num_trees_; i++) {
369 CHECK(trees_[i]->IsRoot());
370 trees_[i]->SetLeaf(0, leaf_value);
371 }
372 }
373
374 inline void SetLeafVector(std::vector<double>& leaf_vector) {
375 CHECK_EQ(output_dimension_, leaf_vector.size());
376 for (int i = 0; i < num_trees_; i++) {
377 CHECK(trees_[i]->IsRoot());
378 trees_[i]->SetLeafVector(0, leaf_vector);
379 }
380 }
381
388 int max_leaf = 0;
389 for (int j = 0; j < num_trees_; j++) {
390 auto &tree = *trees_[j];
391 max_leaf += tree.NumLeaves();
392 }
393 return max_leaf;
394 }
395
413 void PredictLeafIndicesInplace(ForestDataset* dataset, std::vector<int32_t>& output, int num_trees, data_size_t n) {
414 PredictLeafIndicesInplace(dataset->GetCovariates(), output, num_trees, n);
415 }
416
434 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
435 CHECK_GE(output.size(), num_trees*n);
436 int offset = 0;
437 int max_leaf = 0;
438 for (int j = 0; j < num_trees; j++) {
439 auto &tree = *trees_[j];
440 int num_leaves = tree.NumLeaves();
441 tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
442 offset += n;
443 max_leaf += num_leaves;
444 }
445 }
446
465 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
466 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
467 int column_ind, int num_trees, data_size_t n) {
468 CHECK_GE(output.size(), num_trees*n);
469 int offset = 0;
470 int max_leaf = 0;
471 for (int j = 0; j < num_trees; j++) {
472 auto &tree = *trees_[j];
473 int num_leaves = tree.NumLeaves();
474 tree.PredictLeafIndexInplace(covariates, output, column_ind, offset, max_leaf);
475 offset += n;
476 max_leaf += num_leaves;
477 }
478 }
479
497 void PredictLeafIndicesInplace(Eigen::MatrixXd& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
498 CHECK_GE(output.size(), num_trees*n);
499 int offset = 0;
500 int max_leaf = 0;
501 for (int j = 0; j < num_trees; j++) {
502 auto &tree = *trees_[j];
503 int num_leaves = tree.NumLeaves();
504 tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
505 offset += n;
506 max_leaf += num_leaves;
507 }
508 }
509
514 std::vector<int32_t> PredictLeafIndices(ForestDataset* dataset) {
515 int num_trees = num_trees_;
516 data_size_t n = dataset->NumObservations();
517 std::vector<int32_t> output(n*num_trees);
518 PredictLeafIndicesInplace(dataset, output, num_trees, n);
519 return output;
520 }
521
523 json to_json() {
524 json result_obj;
525 result_obj.emplace("num_trees", this->num_trees_);
526 result_obj.emplace("output_dimension", this->output_dimension_);
527 result_obj.emplace("is_leaf_constant", this->is_leaf_constant_);
528 result_obj.emplace("is_exponentiated", this->is_exponentiated_);
529
530 std::string tree_label;
531 for (int i = 0; i < trees_.size(); i++) {
532 tree_label = "tree_" + std::to_string(i);
533 result_obj.emplace(tree_label, trees_[i]->to_json());
534 }
535
536 return result_obj;
537 }
538
540 void from_json(const json& ensemble_json) {
541 this->num_trees_ = ensemble_json.at("num_trees");
542 this->output_dimension_ = ensemble_json.at("output_dimension");
543 this->is_leaf_constant_ = ensemble_json.at("is_leaf_constant");
544 this->is_exponentiated_ = ensemble_json.at("is_exponentiated");
545
546 std::string tree_label;
547 trees_.clear();
548 trees_.resize(this->num_trees_);
549 for (int i = 0; i < this->num_trees_; i++) {
550 tree_label = "tree_" + std::to_string(i);
551 trees_[i] = std::make_unique<Tree>();
552 trees_[i]->from_json(ensemble_json.at(tree_label));
553 }
554 }
555
556 private:
557 std::vector<std::unique_ptr<Tree>> trees_;
558 int num_trees_;
559 int output_dimension_;
560 bool is_leaf_constant_;
561 bool is_exponentiated_;
562};
563
// end of forest_group
565
566} // namespace StochTree
567
568#endif // STOCHTREE_ENSEMBLE_H_
API for loading and accessing data used to sample tree ensembles The covariates / bases / weights use...
Definition data.h:271
data_size_t NumObservations()
Number of observations (rows) in the dataset.
Definition data.h:353
Eigen::MatrixXd & GetCovariates()
Return a reference to the raw Eigen::MatrixXd storing the covariate data.
Definition data.h:383
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Tree * GetTree(int i)
Return a pointer to a tree in the forest.
Definition ensemble.h:134
void ResetInitTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:163
int GetMaxLeafIndex()
Obtain a 0-based "maximum" leaf index for an ensemble, which is equivalent to the sum of the number o...
Definition ensemble.h:387
TreeEnsemble(TreeEnsemble &ensemble)
Initialize an ensemble based on the state of an existing ensemble.
Definition ensemble.h:60
void ResetTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:153
void ReconstituteFromForest(TreeEnsemble &ensemble)
Reset an ensemble to clone another ensemble.
Definition ensemble.h:183
void AddValueToLeaves(double constant_value)
Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional,...
Definition ensemble.h:109
json to_json()
Save to JSON.
Definition ensemble.h:523
void MergeForest(TreeEnsemble &ensemble)
Combine two forests into a single forest by merging their trees.
Definition ensemble.h:85
void PredictLeafIndicesInplace(ForestDataset *dataset, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:413
void PredictLeafIndicesInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:497
void ResetRoot()
Reset a TreeEnsemble to all single-node "root" trees.
Definition ensemble.h:141
void from_json(const json &ensemble_json)
Load from JSON.
Definition ensemble.h:540
std::vector< int32_t > PredictLeafIndices(ForestDataset *dataset)
Same as PredictLeafIndicesInplace but assumes responsibility for allocating and returning output vect...
Definition ensemble.h:514
void PredictLeafIndicesInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:434
void MultiplyLeavesByValue(double constant_multiple)
Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional,...
Definition ensemble.h:121
TreeEnsemble(int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Initialize a new TreeEnsemble.
Definition ensemble.h:41
void CloneFromExistingTree(int i, Tree *tree)
Clone a single tree in an ensemble from an existing tree, overwriting current tree.
Definition ensemble.h:174
void PredictLeafIndicesInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, Eigen::Map< Eigen::Matrix< int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &output, int column_ind, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:465
Decision tree data structure.
Definition tree.h:66
void AddValueToLeaves(double constant_value)
Add a constant value to every leaf of a tree. If leaves are multi-dimensional, constant_value will be...
Definition tree.h:206
void MultiplyLeavesByValue(double constant_multiple)
Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional,...
Definition tree.h:223
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:890
Definition category_tracker.h:36