StochTree 0.0.1
Loading...
Searching...
No Matches
ensemble.h
1
10#ifndef STOCHTREE_ENSEMBLE_H_
11#define STOCHTREE_ENSEMBLE_H_
12
#include <stochtree/data.h>
#include <stochtree/tree.h>
#include <nlohmann/json.hpp>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <deque>
#include <memory>
#include <optional>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
22
23using json = nlohmann::json;
24
25namespace StochTree {
26
38 public:
47 TreeEnsemble(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) {
48 // Initialize trees in the ensemble
49 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees);
50 for (int i = 0; i < num_trees; i++) {
51 trees_[i].reset(new Tree());
52 trees_[i]->Init(output_dimension, is_exponentiated);
53 }
54 // Store ensemble configurations
55 num_trees_ = num_trees;
56 output_dimension_ = output_dimension;
57 is_leaf_constant_ = is_leaf_constant;
58 is_exponentiated_ = is_exponentiated;
59 }
60
67 // Unpack ensemble configurations
68 num_trees_ = ensemble.num_trees_;
69 output_dimension_ = ensemble.output_dimension_;
70 is_leaf_constant_ = ensemble.is_leaf_constant_;
71 is_exponentiated_ = ensemble.is_exponentiated_;
72 // Initialize trees in the ensemble
73 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
74 for (int i = 0; i < num_trees_; i++) {
75 trees_[i].reset(new Tree());
76 }
77 // Clone trees in the ensemble
78 for (int j = 0; j < num_trees_; j++) {
79 Tree* tree = ensemble.GetTree(j);
80 this->CloneFromExistingTree(j, tree);
81 }
82 }
83
84 ~TreeEnsemble() {}
85
92 inline Tree* GetTree(int i) {
93 return trees_[i].get();
94 }
95
99 inline void ResetRoot() {
100 for (int i = 0; i < num_trees_; i++) {
101 ResetInitTree(i);
102 }
103 }
104
111 inline void ResetTree(int i) {
112 trees_[i].reset(new Tree());
113 }
114
121 inline void ResetInitTree(int i) {
122 trees_[i].reset(new Tree());
123 trees_[i]->Init(output_dimension_, is_exponentiated_);
124 }
125
132 inline void CloneFromExistingTree(int i, Tree* tree) {
133 return trees_[i]->CloneFromTree(tree);
134 }
135
141 inline void ReconstituteFromForest(TreeEnsemble& ensemble) {
142 // Delete old tree pointers
143 trees_.clear();
144 // Unpack ensemble configurations
145 num_trees_ = ensemble.num_trees_;
146 output_dimension_ = ensemble.output_dimension_;
147 is_leaf_constant_ = ensemble.is_leaf_constant_;
148 is_exponentiated_ = ensemble.is_exponentiated_;
149 // Initialize trees in the ensemble
150 trees_ = std::vector<std::unique_ptr<Tree>>(num_trees_);
151 for (int i = 0; i < num_trees_; i++) {
152 trees_[i].reset(new Tree());
153 }
154 // Clone trees in the ensemble
155 for (int j = 0; j < num_trees_; j++) {
156 Tree* tree = ensemble.GetTree(j);
157 this->CloneFromExistingTree(j, tree);
158 }
159 }
160
161 std::vector<double> Predict(ForestDataset& dataset) {
162 data_size_t n = dataset.NumObservations();
163 std::vector<double> output(n);
164 PredictInplace(dataset, output, 0);
165 return output;
166 }
167
168 std::vector<double> PredictRaw(ForestDataset& dataset) {
169 data_size_t n = dataset.NumObservations();
170 data_size_t total_output_size = n * output_dimension_;
171 std::vector<double> output(total_output_size);
172 PredictRawInplace(dataset, output, 0);
173 return output;
174 }
175
176 inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output, data_size_t offset = 0) {
177 PredictInplace(dataset, output, 0, trees_.size(), offset);
178 }
179
180 inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output,
181 int tree_begin, int tree_end, data_size_t offset = 0) {
182 if (is_leaf_constant_) {
183 PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset);
184 } else {
185 CHECK(dataset.HasBasis());
186 PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
187 }
188 }
189
190 inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector<double> &output, data_size_t offset = 0) {
191 PredictInplace(covariates, basis, output, 0, trees_.size(), offset);
192 }
193
194 inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector<double> &output,
195 int tree_begin, int tree_end, data_size_t offset = 0) {
196 double pred;
197 CHECK_EQ(covariates.rows(), basis.rows());
198 CHECK_EQ(output_dimension_, trees_[0]->OutputDimension());
199 CHECK_EQ(output_dimension_, basis.cols());
200 data_size_t n = covariates.rows();
201 data_size_t total_output_size = n;
202 if (output.size() < total_output_size + offset) {
203 Log::Fatal("Mismatched size of prediction vector and training data");
204 }
205 for (data_size_t i = 0; i < n; i++) {
206 pred = 0.0;
207 for (size_t j = tree_begin; j < tree_end; j++) {
208 auto &tree = *trees_[j];
209 std::int32_t nidx = EvaluateTree(tree, covariates, i);
210 for (int32_t k = 0; k < output_dimension_; k++) {
211 pred += tree.LeafValue(nidx, k) * basis(i, k);
212 }
213 }
214 if (is_exponentiated_) output[i + offset] = std::exp(pred);
215 else output[i + offset] = pred;
216 }
217 }
218
219 inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector<double> &output, data_size_t offset = 0) {
220 PredictInplace(covariates, output, 0, trees_.size(), offset);
221 }
222
223 inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector<double> &output, int tree_begin, int tree_end, data_size_t offset = 0) {
224 double pred;
225 data_size_t n = covariates.rows();
226 data_size_t total_output_size = n;
227 if (output.size() < total_output_size + offset) {
228 Log::Fatal("Mismatched size of prediction vector and training data");
229 }
230 for (data_size_t i = 0; i < n; i++) {
231 pred = 0.0;
232 for (size_t j = tree_begin; j < tree_end; j++) {
233 auto &tree = *trees_[j];
234 std::int32_t nidx = EvaluateTree(tree, covariates, i);
235 pred += tree.LeafValue(nidx, 0);
236 }
237 if (is_exponentiated_) output[i + offset] = std::exp(pred);
238 else output[i + offset] = pred;
239 }
240 }
241
242 inline void PredictRawInplace(ForestDataset& dataset, std::vector<double> &output, data_size_t offset = 0) {
243 PredictRawInplace(dataset, output, 0, trees_.size(), offset);
244 }
245
246 inline void PredictRawInplace(ForestDataset& dataset, std::vector<double> &output,
247 int tree_begin, int tree_end, data_size_t offset = 0) {
248 double pred;
249 Eigen::MatrixXd covariates = dataset.GetCovariates();
250 CHECK_EQ(output_dimension_, trees_[0]->OutputDimension());
251 data_size_t n = covariates.rows();
252 data_size_t total_output_size = n * output_dimension_;
253 if (output.size() < total_output_size + offset) {
254 Log::Fatal("Mismatched size of raw prediction vector and training data");
255 }
256 for (data_size_t i = 0; i < n; i++) {
257 for (int32_t k = 0; k < output_dimension_; k++) {
258 pred = 0.0;
259 for (size_t j = tree_begin; j < tree_end; j++) {
260 auto &tree = *trees_[j];
261 int32_t nidx = EvaluateTree(tree, covariates, i);
262 pred += tree.LeafValue(nidx, k);
263 }
264 output[i*output_dimension_ + k + offset] = pred;
265 }
266 }
267 }
268
269 inline int32_t NumTrees() {
270 return num_trees_;
271 }
272
273 inline int32_t NumLeaves() {
274 int32_t result = 0;
275 for (int i = 0; i < num_trees_; i++) {
276 result += trees_[i]->NumLeaves();
277 }
278 return result;
279 }
280
281 inline double SumLeafSquared() {
282 double result = 0.;
283 for (int i = 0; i < num_trees_; i++) {
284 result += trees_[i]->SumSquaredLeafValues();
285 }
286 return result;
287 }
288
289 inline int32_t OutputDimension() {
290 return output_dimension_;
291 }
292
293 inline bool IsLeafConstant() {
294 return is_leaf_constant_;
295 }
296
297 inline bool IsExponentiated() {
298 return is_exponentiated_;
299 }
300
301 inline int32_t TreeMaxDepth(int tree_num) {
302 return trees_[tree_num]->MaxLeafDepth();
303 }
304
305 inline double AverageMaxDepth() {
306 double numerator = 0.;
307 double denominator = 0.;
308 for (int i = 0; i < num_trees_; i++) {
309 numerator += static_cast<double>(TreeMaxDepth(i));
310 denominator += 1.;
311 }
312 return numerator / denominator;
313 }
314
315 inline bool AllRoots() {
316 for (int i = 0; i < num_trees_; i++) {
317 if (!trees_[i]->IsRoot()) {
318 return false;
319 }
320 }
321 return true;
322 }
323
324 inline void SetLeafValue(double leaf_value) {
325 CHECK_EQ(output_dimension_, 1);
326 for (int i = 0; i < num_trees_; i++) {
327 CHECK(trees_[i]->IsRoot());
328 trees_[i]->SetLeaf(0, leaf_value);
329 }
330 }
331
332 inline void SetLeafVector(std::vector<double>& leaf_vector) {
333 CHECK_EQ(output_dimension_, leaf_vector.size());
334 for (int i = 0; i < num_trees_; i++) {
335 CHECK(trees_[i]->IsRoot());
336 trees_[i]->SetLeafVector(0, leaf_vector);
337 }
338 }
339
346 int max_leaf = 0;
347 for (int j = 0; j < num_trees_; j++) {
348 auto &tree = *trees_[j];
349 max_leaf += tree.NumLeaves();
350 }
351 return max_leaf;
352 }
353
371 void PredictLeafIndicesInplace(ForestDataset* dataset, std::vector<int32_t>& output, int num_trees, data_size_t n) {
372 PredictLeafIndicesInplace(dataset->GetCovariates(), output, num_trees, n);
373 }
374
392 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
393 CHECK_GE(output.size(), num_trees*n);
394 int offset = 0;
395 int max_leaf = 0;
396 for (int j = 0; j < num_trees; j++) {
397 auto &tree = *trees_[j];
398 int num_leaves = tree.NumLeaves();
399 tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
400 offset += n;
401 max_leaf += num_leaves;
402 }
403 }
404
423 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
424 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
425 int column_ind, int num_trees, data_size_t n) {
426 CHECK_GE(output.size(), num_trees*n);
427 int offset = 0;
428 int max_leaf = 0;
429 for (int j = 0; j < num_trees; j++) {
430 auto &tree = *trees_[j];
431 int num_leaves = tree.NumLeaves();
432 tree.PredictLeafIndexInplace(covariates, output, column_ind, offset, max_leaf);
433 offset += n;
434 max_leaf += num_leaves;
435 }
436 }
437
455 void PredictLeafIndicesInplace(Eigen::MatrixXd& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
456 CHECK_GE(output.size(), num_trees*n);
457 int offset = 0;
458 int max_leaf = 0;
459 for (int j = 0; j < num_trees; j++) {
460 auto &tree = *trees_[j];
461 int num_leaves = tree.NumLeaves();
462 tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
463 offset += n;
464 max_leaf += num_leaves;
465 }
466 }
467
472 std::vector<int32_t> PredictLeafIndices(ForestDataset* dataset) {
473 int num_trees = num_trees_;
474 data_size_t n = dataset->NumObservations();
475 std::vector<int32_t> output(n*num_trees);
476 PredictLeafIndicesInplace(dataset, output, num_trees, n);
477 return output;
478 }
479
481 json to_json() {
482 json result_obj;
483 result_obj.emplace("num_trees", this->num_trees_);
484 result_obj.emplace("output_dimension", this->output_dimension_);
485 result_obj.emplace("is_leaf_constant", this->is_leaf_constant_);
486 result_obj.emplace("is_exponentiated", this->is_exponentiated_);
487
488 std::string tree_label;
489 for (int i = 0; i < trees_.size(); i++) {
490 tree_label = "tree_" + std::to_string(i);
491 result_obj.emplace(tree_label, trees_[i]->to_json());
492 }
493
494 return result_obj;
495 }
496
498 void from_json(const json& ensemble_json) {
499 this->num_trees_ = ensemble_json.at("num_trees");
500 this->output_dimension_ = ensemble_json.at("output_dimension");
501 this->is_leaf_constant_ = ensemble_json.at("is_leaf_constant");
502 this->is_exponentiated_ = ensemble_json.at("is_exponentiated");
503
504 std::string tree_label;
505 trees_.clear();
506 trees_.resize(this->num_trees_);
507 for (int i = 0; i < this->num_trees_; i++) {
508 tree_label = "tree_" + std::to_string(i);
509 trees_[i] = std::make_unique<Tree>();
510 trees_[i]->from_json(ensemble_json.at(tree_label));
511 }
512 }
513
514 private:
// Trees of the ensemble; each entry exclusively owns its Tree.
515 std::vector<std::unique_ptr<Tree>> trees_;
// Tree count used as the loop bound by the aggregate and reset methods.
516 int num_trees_;
// Leaf output dimension; passed to Tree::Init and used to size raw predictions.
517 int output_dimension_;
// When true, PredictInplace predicts from covariates alone; otherwise a basis is required.
518 bool is_leaf_constant_;
// When true, PredictInplace writes std::exp of the summed leaf values; also passed to Tree::Init.
519 bool is_exponentiated_;
520};
521
// end of forest_group
523
524} // namespace StochTree
525
526#endif // STOCHTREE_ENSEMBLE_H_
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:272
data_size_t NumObservations()
Number of observations (rows) in the dataset.
Definition data.h:354
Eigen::MatrixXd & GetCovariates()
Return a reference to the raw Eigen::MatrixXd storing the covariate data.
Definition data.h:384
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Tree * GetTree(int i)
Return a pointer to a tree in the forest.
Definition ensemble.h:92
void ResetInitTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:121
int GetMaxLeafIndex()
Obtain a 0-based "maximum" leaf index for an ensemble, which is equivalent to the sum of the number of leaves in each of its trees.
Definition ensemble.h:345
TreeEnsemble(TreeEnsemble &ensemble)
Initialize an ensemble based on the state of an existing ensemble.
Definition ensemble.h:66
void ResetTree(int i)
Reset a single tree in an ensemble.
Definition ensemble.h:111
void ReconstituteFromForest(TreeEnsemble &ensemble)
Reset an ensemble to clone another ensemble.
Definition ensemble.h:141
json to_json()
Save to JSON.
Definition ensemble.h:481
void PredictLeafIndicesInplace(ForestDataset *dataset, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:371
void PredictLeafIndicesInplace(Eigen::MatrixXd &covariates, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:455
void ResetRoot()
Reset a TreeEnsemble to all single-node "root" trees.
Definition ensemble.h:99
void from_json(const json &ensemble_json)
Load from JSON.
Definition ensemble.h:498
std::vector< int32_t > PredictLeafIndices(ForestDataset *dataset)
Same as PredictLeafIndicesInplace but assumes responsibility for allocating and returning output vect...
Definition ensemble.h:472
void PredictLeafIndicesInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, std::vector< int32_t > &output, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:392
TreeEnsemble(int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Initialize a new TreeEnsemble.
Definition ensemble.h:47
void CloneFromExistingTree(int i, Tree *tree)
Clone a single tree in an ensemble from an existing tree, overwriting current tree.
Definition ensemble.h:132
void PredictLeafIndicesInplace(Eigen::Map< Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &covariates, Eigen::Map< Eigen::Matrix< int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor > > &output, int column_ind, int num_trees, data_size_t n)
Obtain a 0-based leaf index for every tree in an ensemble and for each observation in a ForestDataset...
Definition ensemble.h:423
Decision tree data structure.
Definition tree.h:69
int EvaluateTree(Tree const &tree, Eigen::MatrixXd &data, int row)
Definition tree.h:859
Definition category_tracker.h:40