StochTree 0.0.1
Loading...
Searching...
No Matches
container.h
1
6#ifndef STOCHTREE_CONTAINER_H_
7#define STOCHTREE_CONTAINER_H_
8
#include <algorithm>
#include <deque>
#include <fstream>
#include <memory>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <nlohmann/json.hpp>

#include <stochtree/data.h>
#include <stochtree/ensemble.h>
#include <stochtree/tree.h>
20
21namespace StochTree {
22
29 public:
38 ForestContainer(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
48 ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
55 void DeleteSample(int sample_num);
61 void AddSample(TreeEnsemble& forest);
67 void InitializeRoot(double leaf_value);
73 void InitializeRoot(std::vector<double>& leaf_vector);
79 void AddSamples(int num_samples);
86 void CopyFromPreviousSample(int new_sample_id, int previous_sample_id);
97 std::vector<double> Predict(ForestDataset& dataset);
111 std::vector<double> PredictRaw(ForestDataset& dataset);
112 std::vector<double> PredictRaw(ForestDataset& dataset, int forest_num);
113 std::vector<double> PredictRawSingleTree(ForestDataset& dataset, int forest_num, int tree_num);
114 void PredictInPlace(ForestDataset& dataset, std::vector<double>& output);
115 void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output);
116 void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output);
117 void PredictRawSingleTreeInPlace(ForestDataset& dataset, int forest_num, int tree_num, std::vector<double>& output);
118 void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
119 Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
120 std::vector<int>& forest_indices, int num_trees, data_size_t n);
121
122 inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();}
123 inline int32_t NumSamples() {return num_samples_;}
124 inline int32_t NumTrees() {return num_trees_;}
125 inline int32_t NumTrees(int ensemble_num) {return forests_[ensemble_num]->NumTrees();}
126 inline int32_t NumLeaves(int ensemble_num) {return forests_[ensemble_num]->NumLeaves();}
127 inline int32_t EnsembleTreeMaxDepth(int ensemble_num, int tree_num) {return forests_[ensemble_num]->TreeMaxDepth(tree_num);}
128 inline double EnsembleAverageMaxDepth(int ensemble_num) {return forests_[ensemble_num]->AverageMaxDepth();}
129 inline double AverageMaxDepth() {
130 double numerator = 0.;
131 double denominator = 0.;
132 for (int i = 0; i < num_samples_; i++) {
133 for (int j = 0; j < num_trees_; j++) {
134 numerator += static_cast<double>(forests_[i]->TreeMaxDepth(j));
135 denominator += 1.;
136 }
137 }
138 return numerator / denominator;
139 }
140 inline int32_t OutputDimension() {return output_dimension_;}
141 inline int32_t OutputDimension(int ensemble_num) {return forests_[ensemble_num]->OutputDimension();}
142 inline bool IsLeafConstant() {return is_leaf_constant_;}
143 inline bool IsLeafConstant(int ensemble_num) {return forests_[ensemble_num]->IsLeafConstant();}
144 inline bool IsExponentiated() {return is_exponentiated_;}
145 inline bool IsExponentiated(int ensemble_num) {return forests_[ensemble_num]->IsExponentiated();}
146 inline bool AllRoots(int ensemble_num) {return forests_[ensemble_num]->AllRoots();}
147 inline void SetLeafValue(int ensemble_num, double leaf_value) {forests_[ensemble_num]->SetLeafValue(leaf_value);}
148 inline void SetLeafVector(int ensemble_num, std::vector<double>& leaf_vector) {forests_[ensemble_num]->SetLeafVector(leaf_vector);}
149 inline void IncrementSampleCount() {num_samples_++;}
150
151 void SaveToJsonFile(std::string filename) {
152 nlohmann::json model_json = this->to_json();
153 std::ofstream output_file(filename);
154 output_file << model_json << std::endl;
155 }
156
157 void LoadFromJsonFile(std::string filename) {
158 std::ifstream f(filename);
159 nlohmann::json file_tree_json = nlohmann::json::parse(f);
160 this->Reset();
161 this->from_json(file_tree_json);
162 }
163
164 std::string DumpJsonString() {
165 nlohmann::json model_json = this->to_json();
166 return model_json.dump();
167 }
168
169 void LoadFromJsonString(std::string& json_string) {
170 nlohmann::json file_tree_json = nlohmann::json::parse(json_string);
171 this->Reset();
172 this->from_json(file_tree_json);
173 }
174
175 void Reset() {
176 forests_.clear();
177 num_samples_ = 0;
178 num_trees_ = 0;
179 output_dimension_ = 0;
180 is_leaf_constant_ = 0;
181 initialized_ = false;
182 }
183
  /*! \brief Save to JSON (used by SaveToJsonFile() and DumpJsonString()). */
  nlohmann::json to_json();
  /*!
   * \brief Load from JSON.
   *
   * The LoadFromJson* helpers call Reset() before this, so it is expected to
   * populate an empty container.
   */
  void from_json(const nlohmann::json& forest_container_json);
  /*!
   * \brief Append to a forest container from JSON; requires that the ensemble
   *        already contains a nonzero number of forests (presumably — the
   *        original brief is truncated; confirm in container.cpp).
   */
  void append_from_json(const nlohmann::json& forest_container_json);
190
191 private:
192 std::vector<std::unique_ptr<TreeEnsemble>> forests_;
193 int num_samples_;
194 int num_trees_;
195 int output_dimension_;
196 bool is_exponentiated_{false};
197 bool is_leaf_constant_;
198 bool initialized_{false};
199};
200} // namespace StochTree
201
202#endif // STOCHTREE_CONTAINER_H_
Container of TreeEnsemble forest objects. This is the primary (in-memory) storage interface for multi...
Definition container.h:28
ForestContainer(int num_samples, int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Construct a new ForestContainer object.
std::vector< double > Predict(ForestDataset &dataset)
Predict from every forest in the container on every observation in the provided dataset....
void InitializeRoot(std::vector< double > &leaf_vector)
Initialize a "root" forest of multivariate trees as the first element of the container,...
void DeleteSample(int sample_num)
Remove a forest from a container of forest samples and delete the corresponding object,...
std::vector< double > PredictRaw(ForestDataset &dataset)
Predict from every forest in the container on every observation in the provided dataset....
void append_from_json(const nlohmann::json &forest_container_json)
Append to a forest container from JSON; requires that the ensemble already contains a nonzero number ...
void CopyFromPreviousSample(int new_sample_id, int previous_sample_id)
Copy the forest stored at previous_sample_id to the forest stored at new_sample_id.
void from_json(const nlohmann::json &forest_container_json)
Load from JSON.
void InitializeRoot(double leaf_value)
Initialize a "root" forest of univariate trees as the first element of the container,...
nlohmann::json to_json()
Save to JSON.
void AddSamples(int num_samples)
Pre-allocate space for num_samples additional forests in the container.
ForestContainer(int num_trees, int output_dimension=1, bool is_leaf_constant=true, bool is_exponentiated=false)
Construct a new ForestContainer object.
void AddSample(TreeEnsemble &forest)
Add a new forest to the container by copying forest.
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:272
Class storing a "forest," or an ensemble of decision trees.
Definition ensemble.h:37
Definition category_tracker.h:40