StochTree 0.1.1
Loading...
Searching...
No Matches
container.h
1
6#ifndef STOCHTREE_CONTAINER_H_
7#define STOCHTREE_CONTAINER_H_
8
#include <stochtree/data.h>
#include <stochtree/ensemble.h>
#include <nlohmann/json.hpp>
#include <stochtree/tree.h>

#include <algorithm>
#include <deque>
#include <fstream>
#include <memory>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
20
21namespace StochTree {
22
29 public:
38 ForestContainer(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
48 ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
56 void MergeForests(int inbound_forest_index, int outbound_forest_index) {
57 forests_[inbound_forest_index]->MergeForest(*forests_[outbound_forest_index]);
58 }
65 void AddToForest(int forest_index, double constant_value) {
66 forests_[forest_index]->AddValueToLeaves(constant_value);
67 }
74 void MultiplyForest(int forest_index, double constant_multiple) {
75 forests_[forest_index]->MultiplyLeavesByValue(constant_multiple);
76 }
82 void DeleteSample(int sample_num);
88 void AddSample(TreeEnsemble& forest);
94 void InitializeRoot(double leaf_value);
100 void InitializeRoot(std::vector<double>& leaf_vector);
106 void AddSamples(int num_samples);
113 void CopyFromPreviousSample(int new_sample_id, int previous_sample_id);
124 std::vector<double> Predict(ForestDataset& dataset);
138 std::vector<double> PredictRaw(ForestDataset& dataset);
139 std::vector<double> PredictRaw(ForestDataset& dataset, int forest_num);
140 std::vector<double> PredictRawSingleTree(ForestDataset& dataset, int forest_num, int tree_num);
141 void PredictInPlace(ForestDataset& dataset, std::vector<double>& output);
142 void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output);
143 void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output);
144 void PredictRawSingleTreeInPlace(ForestDataset& dataset, int forest_num, int tree_num, std::vector<double>& output);
145 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
146 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
147 std::vector<int>& forest_indices, int num_trees, data_size_t n);
148
149 inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();}
150 inline int32_t NumSamples() {return num_samples_;}
151 inline int32_t NumTrees() {return num_trees_;}
152 inline int32_t NumTrees(int ensemble_num) {return forests_[ensemble_num]->NumTrees();}
153 inline int32_t NumLeaves(int ensemble_num) {return forests_[ensemble_num]->NumLeaves();}
154 inline int32_t EnsembleTreeMaxDepth(int ensemble_num, int tree_num) {return forests_[ensemble_num]->TreeMaxDepth(tree_num);}
155 inline double EnsembleAverageMaxDepth(int ensemble_num) {return forests_[ensemble_num]->AverageMaxDepth();}
156 inline double AverageMaxDepth() {
157 double numerator = 0.;
158 double denominator = 0.;
159 for (int i = 0; i < num_samples_; i++) {
160 for (int j = 0; j < num_trees_; j++) {
161 numerator += static_cast<double>(forests_[i]->TreeMaxDepth(j));
162 denominator += 1.;
163 }
164 }
165 return numerator / denominator;
166 }
167 inline int32_t OutputDimension() {return output_dimension_;}
168 inline int32_t OutputDimension(int ensemble_num) {return forests_[ensemble_num]->OutputDimension();}
169 inline bool IsLeafConstant() {return is_leaf_constant_;}
170 inline bool IsLeafConstant(int ensemble_num) {return forests_[ensemble_num]->IsLeafConstant();}
171 inline bool IsExponentiated() {return is_exponentiated_;}
172 inline bool IsExponentiated(int ensemble_num) {return forests_[ensemble_num]->IsExponentiated();}
173 inline bool AllRoots(int ensemble_num) {return forests_[ensemble_num]->AllRoots();}
174 inline void SetLeafValue(int ensemble_num, double leaf_value) {forests_[ensemble_num]->SetLeafValue(leaf_value);}
175 inline void SetLeafVector(int ensemble_num, std::vector<double>& leaf_vector) {forests_[ensemble_num]->SetLeafVector(leaf_vector);}
176 inline void IncrementSampleCount() {num_samples_++;}
177
178 void SaveToJsonFile(std::string filename) {
179 nlohmann::json model_json = this->to_json();
180 std::ofstream output_file(filename);
181 output_file << model_json << std::endl;
182 }
183
184 void LoadFromJsonFile(std::string filename) {
185 std::ifstream f(filename);
186 nlohmann::json file_tree_json = nlohmann::json::parse(f);
187 this->Reset();
188 this->from_json(file_tree_json);
189 }
190
191 std::string DumpJsonString() {
192 nlohmann::json model_json = this->to_json();
193 return model_json.dump();
194 }
195
196 void LoadFromJsonString(std::string& json_string) {
197 nlohmann::json file_tree_json = nlohmann::json::parse(json_string);
198 this->Reset();
199 this->from_json(file_tree_json);
200 }
201
202 void Reset() {
203 forests_.clear();
204 num_samples_ = 0;
205 num_trees_ = 0;
206 output_dimension_ = 0;
207 is_leaf_constant_ = 0;
208 initialized_ = false;
209 }
210
212 nlohmann::json to_json();
214 void from_json(const nlohmann::json& forest_container_json);
216 void append_from_json(const nlohmann::json& forest_container_json);
217
218 private:
219 std::vector<std::unique_ptr<TreeEnsemble>> forests_;
220 int num_samples_;
221 int num_trees_;
222 int output_dimension_;
223 bool is_exponentiated_{false};
224 bool is_leaf_constant_;
225 bool initialized_{false};
226};
227} // namespace StochTree
228
229#endif // STOCHTREE_CONTAINER_H_
Container of TreeEnsemble forest objects. This is the primary (in-memory) storage interface for multi...
Definition container.h:28
ForestContainer(int num_samples, int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Construct a new ForestContainer object.
std::vector< double > Predict(ForestDataset &dataset)
Predict from every forest in the container on every observation in the provided dataset....
void MultiplyForest(int forest_index, double constant_multiple)
Multiply every leaf of every tree of a specified forest by a constant value.
Definition container.h:74
void InitializeRoot(std::vector< double > &leaf_vector)
Initialize a "root" forest of multivariate trees as the first element of the container,...
void DeleteSample(int sample_num)
Remove a forest from a container of forest samples and delete the corresponding object,...
std::vector< double > PredictRaw(ForestDataset &dataset)
Predict from every forest in the container on every observation in the provided dataset....
void MergeForests(int inbound_forest_index, int outbound_forest_index)
Combine two forests into a single forest by merging their trees.
Definition container.h:56
void append_from_json(const nlohmann::json &forest_container_json)
Append to a forest container from JSON; requires that the ensemble already contain a nonzero number ...
void CopyFromPreviousSample(int new_sample_id, int previous_sample_id)
Copy the forest stored at previous_sample_id to the forest stored at new_sample_id.
void from_json(const nlohmann::json &forest_container_json)
Load from JSON.
void InitializeRoot(double leaf_value)
Initialize a "root" forest of univariate trees as the first element of the container,...
nlohmann::json to_json()
Save to JSON.
void AddSamples(int num_samples)
Pre-allocate space for num_samples additional forests in the container.
ForestContainer(int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Construct a new ForestContainer object.
void AddToForest(int forest_index, double constant_value)
Add a constant value to every leaf of every tree of a specified forest.
Definition container.h:65
void AddSample(TreeEnsemble &forest)
Add a new forest to the container by copying forest.
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:272
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Definition category_tracker.h:40