StochTree 0.1.1 — `container.h` source listing
6#ifndef STOCHTREE_CONTAINER_H_
7#define STOCHTREE_CONTAINER_H_
8
#include <stochtree/data.h>
#include <stochtree/ensemble.h>
#include <nlohmann/json.hpp>
#include <stochtree/tree.h>

#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
15
16namespace StochTree {
17
24 public:
33 ForestContainer(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
43 ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
51 void MergeForests(int inbound_forest_index, int outbound_forest_index) {
52 forests_[inbound_forest_index]->MergeForest(*forests_[outbound_forest_index]);
53 }
60 void AddToForest(int forest_index, double constant_value) {
61 forests_[forest_index]->AddValueToLeaves(constant_value);
62 }
69 void MultiplyForest(int forest_index, double constant_multiple) {
70 forests_[forest_index]->MultiplyLeavesByValue(constant_multiple);
71 }
77 void DeleteSample(int sample_num);
83 void AddSample(TreeEnsemble& forest);
89 void InitializeRoot(double leaf_value);
95 void InitializeRoot(std::vector<double>& leaf_vector);
101 void AddSamples(int num_samples);
108 void CopyFromPreviousSample(int new_sample_id, int previous_sample_id);
119 std::vector<double> Predict(ForestDataset& dataset);
133 std::vector<double> PredictRaw(ForestDataset& dataset);
134 std::vector<double> PredictRaw(ForestDataset& dataset, int forest_num);
135 std::vector<double> PredictRawSingleTree(ForestDataset& dataset, int forest_num, int tree_num);
136 void PredictInPlace(ForestDataset& dataset, std::vector<double>& output);
137 void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output);
138 void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output);
139 void PredictRawSingleTreeInPlace(ForestDataset& dataset, int forest_num, int tree_num, std::vector<double>& output);
140 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
141 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
142 std::vector<int>& forest_indices, int num_trees, data_size_t n);
143
144 inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();}
145 inline int32_t NumSamples() {return num_samples_;}
146 inline int32_t NumTrees() {return num_trees_;}
147 inline int32_t NumTrees(int ensemble_num) {return forests_[ensemble_num]->NumTrees();}
148 inline int32_t NumLeaves(int ensemble_num) {return forests_[ensemble_num]->NumLeaves();}
149 inline int32_t EnsembleTreeMaxDepth(int ensemble_num, int tree_num) {return forests_[ensemble_num]->TreeMaxDepth(tree_num);}
150 inline double EnsembleAverageMaxDepth(int ensemble_num) {return forests_[ensemble_num]->AverageMaxDepth();}
151 inline double AverageMaxDepth() {
152 double numerator = 0.;
153 double denominator = 0.;
154 for (int i = 0; i < num_samples_; i++) {
155 for (int j = 0; j < num_trees_; j++) {
156 numerator += static_cast<double>(forests_[i]->TreeMaxDepth(j));
157 denominator += 1.;
158 }
159 }
160 return numerator / denominator;
161 }
162 inline int32_t OutputDimension() {return output_dimension_;}
163 inline int32_t OutputDimension(int ensemble_num) {return forests_[ensemble_num]->OutputDimension();}
164 inline bool IsLeafConstant() {return is_leaf_constant_;}
165 inline bool IsLeafConstant(int ensemble_num) {return forests_[ensemble_num]->IsLeafConstant();}
166 inline bool IsExponentiated() {return is_exponentiated_;}
167 inline bool IsExponentiated(int ensemble_num) {return forests_[ensemble_num]->IsExponentiated();}
168 inline bool AllRoots(int ensemble_num) {return forests_[ensemble_num]->AllRoots();}
169 inline void SetLeafValue(int ensemble_num, double leaf_value) {forests_[ensemble_num]->SetLeafValue(leaf_value);}
170 inline void SetLeafVector(int ensemble_num, std::vector<double>& leaf_vector) {forests_[ensemble_num]->SetLeafVector(leaf_vector);}
171 inline void IncrementSampleCount() {num_samples_++;}
172
173 void SaveToJsonFile(std::string filename) {
174 nlohmann::json model_json = this->to_json();
175 std::ofstream output_file(filename);
176 output_file << model_json << std::endl;
177 }
178
179 void LoadFromJsonFile(std::string filename) {
180 std::ifstream f(filename);
181 nlohmann::json file_tree_json = nlohmann::json::parse(f);
182 this->Reset();
183 this->from_json(file_tree_json);
184 }
185
186 std::string DumpJsonString() {
187 nlohmann::json model_json = this->to_json();
188 return model_json.dump();
189 }
190
191 void LoadFromJsonString(std::string& json_string) {
192 nlohmann::json file_tree_json = nlohmann::json::parse(json_string);
193 this->Reset();
194 this->from_json(file_tree_json);
195 }
196
197 void Reset() {
198 forests_.clear();
199 num_samples_ = 0;
200 num_trees_ = 0;
201 output_dimension_ = 0;
202 is_leaf_constant_ = 0;
203 initialized_ = false;
204 }
205
207 nlohmann::json to_json();
209 void from_json(const nlohmann::json& forest_container_json);
211 void append_from_json(const nlohmann::json& forest_container_json);
212
213 private:
214 std::vector<std::unique_ptr<TreeEnsemble>> forests_;
215 int num_samples_;
216 int num_trees_;
217 int output_dimension_;
218 bool is_exponentiated_{false};
219 bool is_leaf_constant_;
220 bool initialized_{false};
221};
222} // namespace StochTree
223
224#endif // STOCHTREE_CONTAINER_H_
Container of TreeEnsemble forest objects. This is the primary (in-memory) storage interface for multi...
Definition container.h:23
ForestContainer(int num_samples, int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Construct a new ForestContainer object.
std::vector< double > Predict(ForestDataset &dataset)
Predict from every forest in the container on every observation in the provided dataset....
void MultiplyForest(int forest_index, double constant_multiple)
Multiply every leaf of every tree of a specified forest by a constant value.
Definition container.h:69
void InitializeRoot(std::vector< double > &leaf_vector)
Initialize a "root" forest of multivariate trees as the first element of the container,...
void DeleteSample(int sample_num)
Remove a forest from a container of forest samples and delete the corresponding object,...
std::vector< double > PredictRaw(ForestDataset &dataset)
Predict from every forest in the container on every observation in the provided dataset....
void MergeForests(int inbound_forest_index, int outbound_forest_index)
Combine two forests into a single forest by merging their trees.
Definition container.h:51
void append_from_json(const nlohmann::json &forest_container_json)
Append to a forest container from JSON; requires that the ensemble already contains a nonzero number ...
void CopyFromPreviousSample(int new_sample_id, int previous_sample_id)
Copy the forest stored at previous_sample_id to the forest stored at new_sample_id.
void from_json(const nlohmann::json &forest_container_json)
Load from JSON.
void InitializeRoot(double leaf_value)
Initialize a "root" forest of univariate trees as the first element of the container,...
nlohmann::json to_json()
Save to JSON.
void AddSamples(int num_samples)
Pre-allocate space for num_samples additional forests in the container.
ForestContainer(int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Construct a new ForestContainer object.
void AddToForest(int forest_index, double constant_value)
Add a constant value to every leaf of every tree of a specified forest.
Definition container.h:60
void AddSample(TreeEnsemble &forest)
Add a new forest to the container by copying forest.
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:271
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:31
Definition category_tracker.h:36