StochTree 0.1.1
Loading...
Searching...
No Matches
random_effects.h
1
5#ifndef STOCHTREE_RANDOM_EFFECTS_H_
6#define STOCHTREE_RANDOM_EFFECTS_H_
7
8#include <stochtree/category_tracker.h>
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/ensemble.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/normal_sampler.h>
15#include <stochtree/partition_tracker.h>
16#include <stochtree/prior.h>
17#include <nlohmann/json.hpp>
18#include <Eigen/Dense>
19
20#include <fstream>
21#include <map>
22#include <memory>
23#include <random>
24#include <string>
25#include <vector>
26
27namespace StochTree {
28
30class LabelMapper;
31class MultivariateRegressionRandomEffectsModel;
32class RandomEffectsContainer;
33
36 public:
37 RandomEffectsTracker(std::vector<int32_t>& group_indices);
39 inline data_size_t GetCategoryId(int observation_num) {return sample_category_mapper_->GetCategoryId(observation_num);}
40 inline data_size_t CategoryBegin(int category_id) {return category_sample_tracker_->CategoryBegin(category_id);}
41 inline data_size_t CategoryEnd(int category_id) {return category_sample_tracker_->CategoryEnd(category_id);}
42 inline data_size_t CategorySize(int category_id) {return category_sample_tracker_->CategorySize(category_id);}
43 inline int32_t NumCategories() {return num_categories_;}
44 inline int32_t CategoryNumber(int32_t category_id) {return category_sample_tracker_->CategoryNumber(category_id);}
45 SampleCategoryMapper* GetSampleCategoryMapper() {return sample_category_mapper_.get();}
46 CategorySampleTracker* GetCategorySampleTracker() {return category_sample_tracker_.get();}
47 std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(int category_id);
48 std::vector<data_size_t>::iterator UnsortedNodeEndIterator(int category_id);
49 std::map<int32_t, int32_t>& GetLabelMap() {return category_sample_tracker_->GetLabelMap();}
50 std::vector<int32_t>& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();}
51 std::vector<data_size_t>& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);}
52 std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
53 double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);}
54 void SetPrediction(data_size_t observation_num, double pred) {rfx_predictions_.at(observation_num) = pred;}
57 RandomEffectsDataset& rfx_dataset, ColumnVector& residual);
62 RandomEffectsDataset& rfx_dataset, ColumnVector& residual);
63
64 private:
66 std::unique_ptr<SampleCategoryMapper> sample_category_mapper_;
68 std::unique_ptr<CategorySampleTracker> category_sample_tracker_;
70 std::vector<double> rfx_predictions_;
72 int num_categories_;
73 int num_observations_;
74};
75
78 public:
79 LabelMapper() {}
80 LabelMapper(std::map<int32_t, int32_t> label_map) {
81 label_map_ = label_map;
82 for (const auto& [key, value] : label_map) keys_.push_back(key);
83 }
84 ~LabelMapper() {}
85 void LoadFromLabelMap(std::map<int32_t, int32_t> label_map) {
86 label_map_ = label_map;
87 for (const auto& [key, value] : label_map) keys_.push_back(key);
88 }
89 bool ContainsLabel(int32_t category_id) {
90 auto pos = label_map_.find(category_id);
91 return pos != label_map_.end();
92 }
93 int32_t CategoryNumber(int32_t category_id) {
94 return label_map_[category_id];
95 }
96 void SaveToJsonFile(std::string filename) {
97 nlohmann::json model_json = this->to_json();
98 std::ofstream output_file(filename);
99 output_file << model_json << std::endl;
100 }
101 void LoadFromJsonFile(std::string filename) {
102 std::ifstream f(filename);
103 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(f);
104 this->Reset();
105 this->from_json(rfx_label_mapper_json);
106 }
107 std::string DumpJsonString() {
108 nlohmann::json model_json = this->to_json();
109 return model_json.dump();
110 }
111 void LoadFromJsonString(std::string& json_string) {
112 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(json_string);
113 this->Reset();
114 this->from_json(rfx_label_mapper_json);
115 }
116 std::vector<int32_t>& Keys() {return keys_;}
117 std::map<int32_t, int32_t>& Map() {return label_map_;}
118 void Reset() {label_map_.clear(); keys_.clear();}
119 nlohmann::json to_json();
120 void from_json(const nlohmann::json& rfx_label_mapper_json);
121 private:
122 std::map<int32_t, int32_t> label_map_;
123 std::vector<int32_t> keys_;
124};
125
128 public:
129 MultivariateRegressionRandomEffectsModel(int num_components, int num_groups) {
130 normal_sampler_ = MultivariateNormalSampler();
131 ig_sampler_ = InverseGammaSampler();
132 num_components_ = num_components;
133 num_groups_ = num_groups;
134 working_parameter_ = Eigen::VectorXd(num_components_);
135 group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
136 group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
137 working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
138 }
140
142 void ResetFromSample(RandomEffectsContainer& rfx_container, int sample_num);
143
145 void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
146 void SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
147 void SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
148 void SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
149
151 void SetWorkingParameter(Eigen::VectorXd& working_parameter) {
152 working_parameter_ = working_parameter;
153 }
154 void SetGroupParameters(Eigen::MatrixXd& group_parameters) {
155 group_parameters_ = group_parameters;
156 }
157 void SetGroupParameter(Eigen::VectorXd& group_parameter, int32_t group_id) {
158 group_parameters_(Eigen::all, group_id) = group_parameter;
159 }
160 void SetWorkingParameterCovariance(Eigen::MatrixXd& working_parameter_covariance) {
161 working_parameter_covariance_ = working_parameter_covariance;
162 }
163 void SetGroupParameterCovariance(Eigen::MatrixXd& group_parameter_covariance) {
164 group_parameter_covariance_ = group_parameter_covariance;
165 }
166 void SetGroupParameterVarianceComponent(double value, int32_t component_id) {
167 group_parameter_covariance_(component_id, component_id) = value;
168 }
169 void SetVariancePriorShape(double value) {
170 variance_prior_shape_ = value;
171 }
172 void SetVariancePriorScale(double value) {
173 variance_prior_scale_ = value;
174 }
175
177 Eigen::VectorXd& GetWorkingParameter() {
178 return working_parameter_;
179 }
180 Eigen::MatrixXd& GetGroupParameters() {
181 return group_parameters_;
182 }
183 Eigen::MatrixXd& GetWorkingParameterCovariance() {
184 return working_parameter_covariance_;
185 }
186 Eigen::MatrixXd& GetGroupParameterCovariance() {
187 return group_parameter_covariance_;
188 }
189 double GetVariancePriorShape() {
190 return variance_prior_shape_;
191 }
192 double GetVariancePriorScale() {
193 return variance_prior_scale_;
194 }
195 inline int32_t NumComponents() {return num_components_;}
196 inline int32_t NumGroups() {return num_groups_;}
197
198 std::vector<double> Predict(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker) {
199 std::vector<double> output(dataset.NumObservations());
200 PredictInplace(dataset, tracker, output);
201 return output;
202 }
203
204 void PredictInplace(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, std::vector<double>& output) {
205 Eigen::MatrixXd X = dataset.GetBasis();
206 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
207 CHECK_EQ(X.rows(), group_labels.size());
208 int n = X.rows();
209 CHECK_EQ(n, output.size());
210 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
211 std::int32_t group_ind;
212 for (int i = 0; i < n; i++) {
213 group_ind = tracker.CategoryNumber(group_labels[i]);
214 output[i] = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
215 }
216 }
217
218 void AddCurrentPredictionToResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
219 data_size_t n = dataset.NumObservations();
220 CHECK_EQ(n, residual.NumRows());
221 double current_pred;
222 double new_resid;
223 for (data_size_t i = 0; i < n; i++) {
224 current_pred = tracker.GetPrediction(i);
225 new_resid = residual.GetElement(i) + current_pred;
226 residual.SetElement(i, new_resid);
227 }
228 }
229
230 void SubtractNewPredictionFromResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
231 Eigen::MatrixXd X = dataset.GetBasis();
232 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
233 CHECK_EQ(X.rows(), group_labels.size());
234 int n = X.rows();
235 double new_pred;
236 double new_resid;
237 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
238 std::int32_t group_ind;
239 for (int i = 0; i < n; i++) {
240 group_ind = tracker.CategoryNumber(group_labels[i]);
241 new_pred = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
242 new_resid = residual.GetElement(i) - new_pred;
243 residual.SetElement(i, new_resid);
244 tracker.SetPrediction(i, new_pred);
245 }
246 }
247
249 Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
251 Eigen::MatrixXd WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
253 Eigen::VectorXd GroupParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id);
255 Eigen::MatrixXd GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id);
257 double VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
259 double VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
260
261 private:
263 MultivariateNormalSampler normal_sampler_;
264 InverseGammaSampler ig_sampler_;
265
267 int num_components_;
268 int num_groups_;
269
273 Eigen::VectorXd working_parameter_;
274 Eigen::MatrixXd group_parameters_;
275
277 Eigen::MatrixXd group_parameter_covariance_;
278
280 Eigen::MatrixXd working_parameter_covariance_;
281
283 double variance_prior_shape_;
284 double variance_prior_scale_;
285};
286
288 public:
289 RandomEffectsContainer(int num_components, int num_groups) {
290 num_components_ = num_components;
291 num_groups_ = num_groups;
292 num_samples_ = 0;
293 }
295 num_components_ = 0;
296 num_groups_ = 0;
297 num_samples_ = 0;
298 }
300 void SaveToJsonFile(std::string filename) {
301 nlohmann::json model_json = this->to_json();
302 std::ofstream output_file(filename);
303 output_file << model_json << std::endl;
304 }
305 void LoadFromJsonFile(std::string filename) {
306 std::ifstream f(filename);
307 nlohmann::json rfx_container_json = nlohmann::json::parse(f);
308 this->Reset();
309 this->from_json(rfx_container_json);
310 }
311 std::string DumpJsonString() {
312 nlohmann::json model_json = this->to_json();
313 return model_json.dump();
314 }
315 void LoadFromJsonString(std::string& json_string) {
316 nlohmann::json rfx_container_json = nlohmann::json::parse(json_string);
317 this->Reset();
318 this->from_json(rfx_container_json);
319 }
320 void AddSample(MultivariateRegressionRandomEffectsModel& model);
321 void DeleteSample(int sample_num);
322 void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output);
323 inline int NumSamples() {return num_samples_;}
324 inline int NumComponents() {return num_components_;}
325 inline int NumGroups() {return num_groups_;}
326 inline void SetNumSamples(int num_samples) {num_samples_ = num_samples;}
327 inline void SetNumComponents(int num_components) {num_components_ = num_components;}
328 inline void SetNumGroups(int num_groups) {num_groups_ = num_groups;}
329 void Reset() {
330 num_samples_ = 0;
331 num_components_ = 0;
332 num_groups_ = 0;
333 beta_.clear();
334 alpha_.clear();
335 xi_.clear();
336 sigma_xi_.clear();
337 }
338 std::vector<double>& GetBeta() {return beta_;}
339 std::vector<double>& GetAlpha() {return alpha_;}
340 std::vector<double>& GetXi() {return xi_;}
341 std::vector<double>& GetSigma() {return sigma_xi_;}
342 nlohmann::json to_json();
343 void from_json(const nlohmann::json& rfx_container_json);
344 void append_from_json(const nlohmann::json& rfx_container_json);
345 private:
346 int num_samples_;
347 int num_components_;
348 int num_groups_;
349 std::vector<double> beta_;
350 std::vector<double> alpha_;
351 std::vector<double> xi_;
352 std::vector<double> sigma_xi_;
353 void AddAlpha(MultivariateRegressionRandomEffectsModel& model);
355 void AddSigma(MultivariateRegressionRandomEffectsModel& model);
356};
357
358} // namespace StochTree
359
360#endif // STOCHTREE_RANDOM_EFFECTS_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:193
Definition ig_sampler.h:9
Standalone container for the map from category IDs to 0-based indices.
Definition random_effects.h:77
Definition normal_sampler.h:24
Posterior computation and sampling and state storage for random effects model with a group-level mult...
Definition random_effects.h:127
Eigen::VectorXd GroupParameterMean(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t group_id)
Compute the posterior mean of a group parameter, conditional on the working parameter and the varianc...
void ResetFromSample(RandomEffectsContainer &rfx_container, int sample_num)
Reconstruction from serialized model parameter samples.
Eigen::VectorXd & GetWorkingParameter()
Getters.
Definition random_effects.h:177
double VarianceComponentScale(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t component_id)
Compute the posterior scale of the group variance component, conditional on the working and group par...
Eigen::MatrixXd GroupParameterVariance(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t group_id)
Compute the posterior covariance of a group parameter, conditional on the working parameter and the v...
void SampleRandomEffects(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &tracker, double global_variance, std::mt19937 &gen)
Samplers.
Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance)
Compute the posterior mean of the working parameter, conditional on the group parameters and the vari...
Eigen::MatrixXd WorkingParameterVariance(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance)
Compute the posterior covariance of the working parameter, conditional on the group parameters and th...
double VarianceComponentShape(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t component_id)
Compute the posterior shape of the group variance component, conditional on the working and group par...
void SetWorkingParameter(Eigen::VectorXd &working_parameter)
Setters.
Definition random_effects.h:151
Definition random_effects.h:287
API for loading and accessing data used to sample (additive) random effects.
Definition data.h:485
Wrapper around data structures for random effects sampling algorithms.
Definition random_effects.h:35
void RootReset(MultivariateRegressionRandomEffectsModel &rfx_model, RandomEffectsDataset &rfx_dataset, ColumnVector &residual)
Resets RFX tracker to initial default. Assumes tracker already exists in main memory....
void ResetFromSample(MultivariateRegressionRandomEffectsModel &rfx_model, RandomEffectsDataset &rfx_dataset, ColumnVector &residual)
Resets RFX tracker based on a specific sample. Assumes tracker already exists in main memory.
Class storing sample-node map for each tree in an ensemble TODO: Add run-time checks for categories w...
Definition category_tracker.h:41
Definition category_tracker.h:36