StochTree 0.1.1
Loading...
Searching...
No Matches
ensemble.h
1
10#ifndef STOCHTREE_ENSEMBLE_H_
11#define STOCHTREE_ENSEMBLE_H_
12
#include <stochtree/data.h>
#include <stochtree/tree.h>
#include <nlohmann/json.hpp>

#include <algorithm>
#include <cmath>
#include <deque>
#include <memory>
#include <optional>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
22
// JSON type alias used by to_json / from_json below.
// NOTE(review): declared at global scope (before `namespace StochTree`), so every file that
// includes this header sees an unqualified `json` name.
23using json = nlohmann::json;
24
25namespace StochTree {
26
38 public:
47 TreeEnsemble(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) {
48 // Initialize trees in the ensemble
49 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees);
50 for (int i = 0; i < num_trees; i++) {
51 trees_[i].reset(new Tree());
52 trees_[i]->Init(output_dimension, is_exponentiated);
53 }
54 // Store ensemble configurations
55 num_trees_ = num_trees;
56 output_dimension_ = output_dimension;
57 is_leaf_constant_ = is_leaf_constant;
58 is_exponentiated_ = is_exponentiated;
59 }
60
67 // Unpack ensemble configurations
68 num_trees_ = ensemble.num_trees_;
69 output_dimension_ = ensemble.output_dimension_;
70 is_leaf_constant_ = ensemble.is_leaf_constant_;
71 is_exponentiated_ = ensemble.is_exponentiated_;
72 // Initialize trees in the ensemble
73 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
74 for (int i = 0; i < num_trees_; i++) {
75 trees_[i].reset(new Tree());
76 }
77 // Clone trees in the ensemble
78 for (int j = 0; j < num_trees_; j++) {
79 Tree* tree = ensemble.GetTree(j);
80 this->CloneFromExistingTree(j, tree);
81 }
82 }
83
84 ~TreeEnsemble() {}
85
91 void MergeForest(TreeEnsemble& ensemble) {
92 // Unpack ensemble configurations
93 int old_num_trees = num_trees_;
94 num_trees_ += ensemble.num_trees_;
95 CHECK_EQ(output_dimension_, ensemble.output_dimension_);
96 CHECK_EQ(is_leaf_constant_, ensemble.is_leaf_constant_);
97 CHECK_EQ(is_exponentiated_, ensemble.is_exponentiated_);
98 // Resize tree vector and reset new trees
99 trees_.resize(num_trees_);
100 for (int i = old_num_trees; i < num_trees_; i++) {
101 trees_[i].reset(new Tree());
102 }
103 // Clone trees in the input ensemble
104 for (int j = 0; j < ensemble.num_trees_; j++) {
105 Tree* tree = ensemble.GetTree(j);
106 this->CloneFromExistingTree(old_num_trees + j, tree);
107 }
108 }
109
115 void AddValueToLeaves(double constant_value) {
116 for (int j = 0; j < num_trees_; j++) {
117 Tree* tree = GetTree(j);
118 tree->AddValueToLeaves(constant_value);
119 }
120 }
121
127 void MultiplyLeavesByValue(double constant_multiple) {
128 for (int j = 0; j < num_trees_; j++) {
129 Tree* tree = GetTree(j);
130 tree->MultiplyLeavesByValue(constant_multiple);
131 }
132 }
133
140 inline Tree* GetTree(int i) {
141 return trees_[i].get();
142 }
143
147 inline void ResetRoot() {
148 for (int i = 0; i < num_trees_; i++) {
149 ResetInitTree(i);
150 }
151 }
152
159 inline void ResetTree(int i) {
160 trees_[i].reset(new Tree());
161 }
162
169 inline void ResetInitTree(int i) {
170 trees_[i].reset(new Tree());
171 trees_[i]->Init(output_dimension_, is_exponentiated_);
172 }
173
180 inline void CloneFromExistingTree(int i, Tree* tree) {
181 return trees_[i]->CloneFromTree(tree);
182 }
183
189 inline void ReconstituteFromForest(TreeEnsemble& ensemble) {
190 // Delete old tree pointers
191 trees_.clear();
192 // Unpack ensemble configurations
193 num_trees_ = ensemble.num_trees_;
194 output_dimension_ = ensemble.output_dimension_;
195 is_leaf_constant_ = ensemble.is_leaf_constant_;
196 is_exponentiated_ = ensemble.is_exponentiated_;
197 // Initialize trees in the ensemble
198 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
199 for (int i = 0; i < num_trees_; i++) {
200 trees_[i].reset(new Tree());
201 }
202 // Clone trees in the ensemble
203 for (int j = 0; j < num_trees_; j++) {
204 Tree* tree = ensemble.GetTree(j);
205 this->CloneFromExistingTree(j, tree);
206 }
207 }
208
209 std::vector<double> Predict(ForestDataset& dataset) {
210 data_size_t n = dataset.NumObservations();
211 std::vector<double> output(n);
212 PredictInplace(dataset, output, 0);
213 return output;
214 }
215
216 std::vector<double> PredictRaw(ForestDataset& dataset) {
217 data_size_t n = dataset.NumObservations();
218 data_size_t total_output_size = n * output_dimension_;
219 std::vector<double> output(total_output_size);
220 PredictRawInplace(dataset, output, 0);
221 return output;
222 }
223
224 inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output, data_size_t offset = 0) {
225 PredictInplace(dataset, output, 0, trees_.size(), offset);
226 }
227
228 inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output,
229 int tree_begin, int tree_end, data_size_t offset = 0) {
230 if (is_leaf_constant_) {
231 PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset);
232 } else {
233 CHECK(dataset.HasBasis());
234 PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
235 }
236 }
237
238 inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector<double> &output, data_size_t offset = 0) {
239 PredictInplace(covariates, basis, output, 0, trees_.size(), offset);
240 }
241
242 inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector<double> &output,
243 int tree_begin, int tree_end, data_size_t offset = 0) {
244 double pred;
245 CHECK_EQ(covariates.rows(), basis.rows());
246 CHECK_EQ(output_dimension_, trees_[0]->OutputDimension());
247 CHECK_EQ(output_dimension_, basis.cols());
248 data_size_t n = covariates.rows();
249 data_size_t total_output_size = n;
250 if (output.size() < total_output_size + offset) {
251 Log::Fatal("Mismatched size of prediction vector and training data");
252 }
253 for (data_size_t i = 0; i < n; i++) {
254 pred = 0.0;
255 for (size_t j = tree_begin; j < tree_end; j++) {
256 auto &tree = *trees_[j];
257 std::int32_t nidx = EvaluateTree(tree, covariates, i);
258 for (int32_t k = 0; k < output_dimension_; k++) {
259 pred += tree.LeafValue(nidx, k) * basis(i, k);
260 }
261 }
262 if (is_exponentiated_) output[i + offset] = std::exp(pred);
263 else output[i + offset] = pred;
264 }
265 }
266
267 inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector<double> &output, data_size_t offset = 0) {
268 PredictInplace(covariates, output, 0, trees_.size(), offset);
269 }
270
271 inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector<double> &output, int tree_begin, int tree_end, data_size_t offset = 0) {
272 double pred;
273 data_size_t n = covariates.rows();
274 data_size_t total_output_size = n;
275 if (output.size() < total_output_size + offset) {
276 Log::Fatal("Mismatched size of prediction vector and training data");
277 }
278 for (data_size_t i = 0; i < n; i++) {
279 pred = 0.0;
280 for (size_t j = tree_begin; j < tree_end; j++) {
281 auto &tree = *trees_[j];
282 std::int32_t nidx = EvaluateTree(tree, covariates, i);
283 pred += tree.LeafValue(nidx, 0);
284 }
285 if (is_exponentiated_) output[i + offset] = std::exp(pred);
286 else output[i + offset] = pred;
287 }
288 }
289
290 inline void PredictRawInplace(ForestDataset& dataset, std::vector<double> &output, data_size_t offset = 0) {
291 PredictRawInplace(dataset, output, 0, trees_.size(), offset);
292 }
293
294 inline void PredictRawInplace(ForestDataset& dataset, std::vector<double> &output,
295 int tree_begin, int tree_end, data_size_t offset = 0) {
296 double pred;
297 Eigen::MatrixXd covariates = dataset.GetCovariates();
298 CHECK_EQ(output_dimension_, trees_[0]->OutputDimension());
299 data_size_t n = covariates.rows();
300 data_size_t total_output_size = n * output_dimension_;
301 if (output.size() < total_output_size + offset) {
302 Log::Fatal("Mismatched size of raw prediction vector and training data");
303 }
304 for (data_size_t i = 0; i < n; i++) {
305 for (int32_t k = 0; k < output_dimension_; k++) {
306 pred = 0.0;
307 for (size_t j = tree_begin; j < tree_end; j++) {
308 auto &tree = *trees_[j];
309 int32_t nidx = EvaluateTree(tree, covariates, i);
310 pred += tree.LeafValue(nidx, k);
311 }
312 output[i*output_dimension_ + k + offset] = pred;
313 }
314 }
315 }
316
317 inline int32_t NumTrees() {
318 return num_trees_;
319 }
320
321 inline int32_t NumLeaves() {
322 int32_t result = 0;
323 for (int i = 0; i < num_trees_; i++) {
324 result += trees_[i]->NumLeaves();
325 }
326 return result;
327 }
328
329 inline double SumLeafSquared() {
330 double result = 0.;
331 for (int i = 0; i < num_trees_; i++) {
332 result += trees_[i]->SumSquaredLeafValues();
333 }
334 return result;
335 }
336
337 inline int32_t OutputDimension() {
338 return output_dimension_;
339 }
340
341 inline bool IsLeafConstant() {
342 return is_leaf_constant_;
343 }
344
345 inline bool IsExponentiated() {
346 return is_exponentiated_;
347 }
348
349 inline int32_t TreeMaxDepth(int tree_num) {
350 return trees_[tree_num]->MaxLeafDepth();
351 }
352
353 inline double AverageMaxDepth() {
354 double numerator = 0.;
355 double denominator = 0.;
356 for (int i = 0; i < num_trees_; i++) {
357 numerator += static_cast<double>(TreeMaxDepth(i));
358 denominator += 1.;
359 }
360 return numerator / denominator;
361 }
362
363 inline bool AllRoots() {
364 for (int i = 0; i < num_trees_; i++) {
365 if (!trees_[i]->IsRoot()) {
366 return false;
367 }
368 }
369 return true;
370 }
371
372 inline void SetLeafValue(double leaf_value) {
373 CHECK_EQ(output_dimension_, 1);
374 for (int i = 0; i < num_trees_; i++) {
375 CHECK(trees_[i]->IsRoot());
376 trees_[i]->SetLeaf(0, leaf_value);
377 }
378 }
379
380 inline void SetLeafVector(std::vector<double>& leaf_vector) {
381 CHECK_EQ(output_dimension_, leaf_vector.size());
382 for (int i = 0; i < num_trees_; i++) {
383 CHECK(trees_[i]->IsRoot());
384 trees_[i]->SetLeafVector(0, leaf_vector);
385 }
386 }
387
394 int max_leaf = 0;
395 for (int j = 0; j < num_trees_; j++) {
396 auto &tree = *trees_[j];
397 max_leaf += tree.NumLeaves();
398 }
399 return max_leaf;
400 }
401
419 void PredictLeafIndicesInplace(ForestDataset* dataset, std::vector<int32_t>& output, int num_trees, data_size_t n) {
420 PredictLeafIndicesInplace(dataset->GetCovariates(), output, num_trees, n);
421 }
422
440 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
441 CHECK_GE(output.size(), num_trees*n);
442 int offset = 0;
443 int max_leaf = 0;
444 for (int j = 0; j < num_trees; j++) {
445 auto &tree = *trees_[j];
446 int num_leaves = tree.NumLeaves();
447 tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
448 offset += n;
449 max_leaf += num_leaves;
450 }
451 }
452
471 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
472 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
473 int column_ind, int num_trees, data_size_t n) {
474 CHECK_GE(output.size(), num_trees*n);
475 int offset = 0;
476 int max_leaf = 0;
477 for (int j = 0; j < num_trees; j++) {
478 auto &tree = *trees_[j];
479 int num_leaves = tree.NumLeaves();
480 tree.PredictLeafIndexInplace(covariates, output, column_ind, offset, max_leaf);
481 offset += n;
482 max_leaf += num_leaves;
483 }
484 }
485
503 void PredictLeafIndicesInplace(Eigen::MatrixXd& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
504 CHECK_GE(output.size(), num_trees*n);
505 int offset = 0;
506 int max_leaf = 0;
507 for (int j = 0; j < num_trees; j++) {
508 auto &tree = *trees_[j];
509 int num_leaves = tree.NumLeaves();
510 tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
511 offset += n;
512 max_leaf += num_leaves;
513 }
514 }
515
520 std::vector<int32_t> PredictLeafIndices(ForestDataset* dataset) {
521 int num_trees = num_trees_;
522 data_size_t n = dataset->NumObservations();
523 std::vector<int32_t> output(n*num_trees);
524 PredictLeafIndicesInplace(dataset, output, num_trees, n);
525 return output;
526 }
527
529 json to_json() {
530 json result_obj;
531 result_obj.emplace("num_trees", this->num_trees_);
532 result_obj.emplace("output_dimension", this->output_dimension_);
533 result_obj.emplace("is_leaf_constant", this->is_leaf_constant_);
534 result_obj.emplace("is_exponentiated", this->is_exponentiated_);
535
536 std::string tree_label;
537 for (int i = 0; i < trees_.size(); i++) {
538 tree_label = "tree_" + std::to_string(i);
539 result_obj.emplace(tree_label, trees_[i]->to_json());
540 }
541
542 return result_obj;
543 }
544
546 void from_json(const json& ensemble_json) {
547 this->num_trees_ = ensemble_json.at("num_trees");
548 this->output_dimension_ = ensemble_json.at("output_dimension");
549 this->is_leaf_constant_ = ensemble_json.at("is_leaf_constant");
550 this->is_exponentiated_ = ensemble_json.at("is_exponentiated");
551
552 std::string tree_label;
553 trees_.clear();
554 trees_.resize(this->num_trees_);
555 for (int i = 0; i < this->num_trees_; i++) {
556 tree_label = "tree_" + std::to_string(i);
557 trees_[i] = std::make_unique<Tree>();
558 trees_[i]->from_json(ensemble_json.at(tree_label));
559 }
560 }
561
562 private:
// Owning handles for the trees; entry i is the i-th tree of the ensemble
563 std::vector<std::unique_ptr<Tree>> trees_;
// Number of trees; the methods that (re)build trees_ set it to match trees_'s length
564 int num_trees_;
// Number of components read from each leaf (LeafValue(nidx, k) for k < output_dimension_)
565 int output_dimension_;
// If false, prediction weights leaf values by a basis matrix (PredictInplace basis overload)
566 bool is_leaf_constant_;
// If true, PredictInplace stores exp() of the summed prediction; also passed to Tree::Init
567 bool is_exponentiated_;
568};
569
// end of forest_group
571
572} // namespace StochTree
573
574#endif // STOCHTREE_ENSEMBLE_H_
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:272
data_size_t NumObservations()
Number of observations (rows) in the dataset.
Definition data.h:354
Eigen::MatrixXd & GetCovariates()
Return a reference to the raw Eigen::MatrixXd storing the covariate data.
Definition data.h:384
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Tree * GetTree(int i)
Return a pointer to a tree in the forest.
Definition ensemble.h:140
void ResetInitTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:169
int GetMaxLeafIndex()
Obtain a 0-based "maximum" leaf index for an ensemble, which is equivalent to the sum of the number of leaves in each tree.
Definition ensemble.h:393
TreeEnsemble(TreeEnsemble &ensemble)
Initialize an ensemble based on the state of an existing ensemble.
Definition ensemble.h:66
void ResetTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:159
void ReconstituteFromForest(TreeEnsemble &ensemble)
Reset an ensemble to clone another ensemble.
Definition ensemble.h:189
void AddValueToLeaves(double constant_value)
Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional,...
Definition ensemble.h:115
json to_json()
Save to JSON.
Definition ensemble.h:529
void MergeForest(TreeEnsemble &ensemble)
Combine two forests into a single forest by merging their trees.
Definition ensemble.h:91
void PredictLeafIndicesInplace(ForestDataset *dataset, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:419
void PredictLeafIndicesInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:503
void ResetRoot()
Reset a TreeEnsemble to all single-node "root" trees.
Definition ensemble.h:147
void from_json(const json &ensemble_json)
Load from JSON.
Definition ensemble.h:546
std::vector< int32_t > PredictLeafIndices(ForestDataset *dataset)
Same as PredictLeafIndicesInplace but assumes responsibility for allocating and returning the output vector.
Definition ensemble.h:520
void PredictLeafIndicesInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:440
void MultiplyLeavesByValue(double constant_multiple)
Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional,...
Definition ensemble.h:127
TreeEnsemble(int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Initialize a new TreeEnsemble.
Definition ensemble.h:47
void CloneFromExistingTree(int i, Tree *tree)
Clone a single tree in an ensemble from an existing tree, overwriting current tree.
Definition ensemble.h:180
void PredictLeafIndicesInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, Eigen::Map< Eigen::Matrix< int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &output, int column_ind, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:471
Decision tree data structure.
Definition tree.h:69
void AddValueToLeaves(double constant_value)
Add a constant value to every leaf of a tree. If leaves are multi-dimensional, constant_value will be...
Definition tree.h:209
void MultiplyLeavesByValue(double constant_multiple)
Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional,...
Definition tree.h:226
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:893
Definition category_tracker.h:40