StochTree 0.1.1
Loading...
Searching...
No Matches
random_effects.h
1
5#ifndef STOCHTREE_RANDOM_EFFECTS_H_
6#define STOCHTREE_RANDOM_EFFECTS_H_
7
8#include <stochtree/category_tracker.h>
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/ensemble.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/normal_sampler.h>
15#include <stochtree/partition_tracker.h>
16#include <stochtree/prior.h>
17#include <nlohmann/json.hpp>
18#include <Eigen/Dense>
19
20#include <cmath>
21#include <fstream>
22#include <map>
23#include <memory>
24#include <random>
25#include <set>
26#include <string>
27#include <type_traits>
28#include <vector>
29
30namespace StochTree {
31
33class LabelMapper;
34class MultivariateRegressionRandomEffectsModel;
35class RandomEffectsContainer;
36
39 public:
40 RandomEffectsTracker(std::vector<int32_t>& group_indices);
42 inline data_size_t GetCategoryId(int observation_num) {return sample_category_mapper_->GetCategoryId(observation_num);}
43 inline data_size_t CategoryBegin(int category_id) {return category_sample_tracker_->CategoryBegin(category_id);}
44 inline data_size_t CategoryEnd(int category_id) {return category_sample_tracker_->CategoryEnd(category_id);}
45 inline data_size_t CategorySize(int category_id) {return category_sample_tracker_->CategorySize(category_id);}
46 inline int32_t NumCategories() {return num_categories_;}
47 inline int32_t CategoryNumber(int32_t category_id) {return category_sample_tracker_->CategoryNumber(category_id);}
48 SampleCategoryMapper* GetSampleCategoryMapper() {return sample_category_mapper_.get();}
49 CategorySampleTracker* GetCategorySampleTracker() {return category_sample_tracker_.get();}
50 std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(int category_id);
51 std::vector<data_size_t>::iterator UnsortedNodeEndIterator(int category_id);
52 std::map<int32_t, int32_t>& GetLabelMap() {return category_sample_tracker_->GetLabelMap();}
53 std::vector<int32_t>& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();}
54 std::vector<data_size_t>& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);}
55 std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
56 double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);}
57 void SetPrediction(data_size_t observation_num, double pred) {rfx_predictions_.at(observation_num) = pred;}
60 RandomEffectsDataset& rfx_dataset, ColumnVector& residual);
65 RandomEffectsDataset& rfx_dataset, ColumnVector& residual);
66
67 private:
69 std::unique_ptr<SampleCategoryMapper> sample_category_mapper_;
71 std::unique_ptr<CategorySampleTracker> category_sample_tracker_;
73 std::vector<double> rfx_predictions_;
75 int num_categories_;
76 int num_observations_;
77};
78
81 public:
82 LabelMapper() {}
83 LabelMapper(std::map<int32_t, int32_t> label_map) {
84 label_map_ = label_map;
85 for (const auto& [key, value] : label_map) keys_.push_back(key);
86 }
87 ~LabelMapper() {}
88 void LoadFromLabelMap(std::map<int32_t, int32_t> label_map) {
89 label_map_ = label_map;
90 for (const auto& [key, value] : label_map) keys_.push_back(key);
91 }
92 bool ContainsLabel(int32_t category_id) {
93 auto pos = label_map_.find(category_id);
94 return pos != label_map_.end();
95 }
96 int32_t CategoryNumber(int32_t category_id) {
97 return label_map_[category_id];
98 }
99 void SaveToJsonFile(std::string filename) {
100 nlohmann::json model_json = this->to_json();
101 std::ofstream output_file(filename);
102 output_file << model_json << std::endl;
103 }
104 void LoadFromJsonFile(std::string filename) {
105 std::ifstream f(filename);
106 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(f);
107 this->Reset();
108 this->from_json(rfx_label_mapper_json);
109 }
110 std::string DumpJsonString() {
111 nlohmann::json model_json = this->to_json();
112 return model_json.dump();
113 }
114 void LoadFromJsonString(std::string& json_string) {
115 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(json_string);
116 this->Reset();
117 this->from_json(rfx_label_mapper_json);
118 }
119 std::vector<int32_t>& Keys() {return keys_;}
120 std::map<int32_t, int32_t>& Map() {return label_map_;}
121 void Reset() {label_map_.clear(); keys_.clear();}
122 nlohmann::json to_json();
123 void from_json(const nlohmann::json& rfx_label_mapper_json);
124 private:
125 std::map<int32_t, int32_t> label_map_;
126 std::vector<int32_t> keys_;
127};
128
131 public:
132 MultivariateRegressionRandomEffectsModel(int num_components, int num_groups) {
133 normal_sampler_ = MultivariateNormalSampler();
134 ig_sampler_ = InverseGammaSampler();
135 num_components_ = num_components;
136 num_groups_ = num_groups;
137 working_parameter_ = Eigen::VectorXd(num_components_);
138 group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
139 group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
140 working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
141 }
143
145 void ResetFromSample(RandomEffectsContainer& rfx_container, int sample_num);
146
148 void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
149 void SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
150 void SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
151 void SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
152
154 void SetWorkingParameter(Eigen::VectorXd& working_parameter) {
155 working_parameter_ = working_parameter;
156 }
157 void SetGroupParameters(Eigen::MatrixXd& group_parameters) {
158 group_parameters_ = group_parameters;
159 }
160 void SetGroupParameter(Eigen::VectorXd& group_parameter, int32_t group_id) {
161 group_parameters_(Eigen::all, group_id) = group_parameter;
162 }
163 void SetWorkingParameterCovariance(Eigen::MatrixXd& working_parameter_covariance) {
164 working_parameter_covariance_ = working_parameter_covariance;
165 }
166 void SetGroupParameterCovariance(Eigen::MatrixXd& group_parameter_covariance) {
167 group_parameter_covariance_ = group_parameter_covariance;
168 }
169 void SetGroupParameterVarianceComponent(double value, int32_t component_id) {
170 group_parameter_covariance_(component_id, component_id) = value;
171 }
172 void SetVariancePriorShape(double value) {
173 variance_prior_shape_ = value;
174 }
175 void SetVariancePriorScale(double value) {
176 variance_prior_scale_ = value;
177 }
178
180 Eigen::VectorXd& GetWorkingParameter() {
181 return working_parameter_;
182 }
183 Eigen::MatrixXd& GetGroupParameters() {
184 return group_parameters_;
185 }
186 Eigen::MatrixXd& GetWorkingParameterCovariance() {
187 return working_parameter_covariance_;
188 }
189 Eigen::MatrixXd& GetGroupParameterCovariance() {
190 return group_parameter_covariance_;
191 }
192 double GetVariancePriorShape() {
193 return variance_prior_shape_;
194 }
195 double GetVariancePriorScale() {
196 return variance_prior_scale_;
197 }
198 inline int32_t NumComponents() {return num_components_;}
199 inline int32_t NumGroups() {return num_groups_;}
200
201 std::vector<double> Predict(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker) {
202 std::vector<double> output(dataset.NumObservations());
203 PredictInplace(dataset, tracker, output);
204 return output;
205 }
206
207 void PredictInplace(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, std::vector<double>& output) {
208 Eigen::MatrixXd X = dataset.GetBasis();
209 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
210 CHECK_EQ(X.rows(), group_labels.size());
211 int n = X.rows();
212 CHECK_EQ(n, output.size());
213 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
214 std::int32_t group_ind;
215 for (int i = 0; i < n; i++) {
216 group_ind = tracker.CategoryNumber(group_labels[i]);
217 output[i] = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
218 }
219 }
220
221 void AddCurrentPredictionToResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
222 data_size_t n = dataset.NumObservations();
223 CHECK_EQ(n, residual.NumRows());
224 double current_pred;
225 double new_resid;
226 for (data_size_t i = 0; i < n; i++) {
227 current_pred = tracker.GetPrediction(i);
228 new_resid = residual.GetElement(i) + current_pred;
229 residual.SetElement(i, new_resid);
230 }
231 }
232
233 void SubtractNewPredictionFromResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
234 Eigen::MatrixXd X = dataset.GetBasis();
235 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
236 CHECK_EQ(X.rows(), group_labels.size());
237 int n = X.rows();
238 double new_pred;
239 double new_resid;
240 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
241 std::int32_t group_ind;
242 for (int i = 0; i < n; i++) {
243 group_ind = tracker.CategoryNumber(group_labels[i]);
244 new_pred = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
245 new_resid = residual.GetElement(i) - new_pred;
246 residual.SetElement(i, new_resid);
247 tracker.SetPrediction(i, new_pred);
248 }
249 }
250
252 Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
254 Eigen::MatrixXd WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
256 Eigen::VectorXd GroupParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id);
258 Eigen::MatrixXd GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id);
260 double VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
262 double VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
263
264 private:
266 MultivariateNormalSampler normal_sampler_;
267 InverseGammaSampler ig_sampler_;
268
270 int num_components_;
271 int num_groups_;
272
276 Eigen::VectorXd working_parameter_;
277 Eigen::MatrixXd group_parameters_;
278
280 Eigen::MatrixXd group_parameter_covariance_;
281
283 Eigen::MatrixXd working_parameter_covariance_;
284
286 double variance_prior_shape_;
287 double variance_prior_scale_;
288};
289
291 public:
292 RandomEffectsContainer(int num_components, int num_groups) {
293 num_components_ = num_components;
294 num_groups_ = num_groups;
295 num_samples_ = 0;
296 }
298 num_components_ = 0;
299 num_groups_ = 0;
300 num_samples_ = 0;
301 }
303 void SaveToJsonFile(std::string filename) {
304 nlohmann::json model_json = this->to_json();
305 std::ofstream output_file(filename);
306 output_file << model_json << std::endl;
307 }
308 void LoadFromJsonFile(std::string filename) {
309 std::ifstream f(filename);
310 nlohmann::json rfx_container_json = nlohmann::json::parse(f);
311 this->Reset();
312 this->from_json(rfx_container_json);
313 }
314 std::string DumpJsonString() {
315 nlohmann::json model_json = this->to_json();
316 return model_json.dump();
317 }
318 void LoadFromJsonString(std::string& json_string) {
319 nlohmann::json rfx_container_json = nlohmann::json::parse(json_string);
320 this->Reset();
321 this->from_json(rfx_container_json);
322 }
323 void AddSample(MultivariateRegressionRandomEffectsModel& model);
324 void DeleteSample(int sample_num);
325 void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output);
326 inline int NumSamples() {return num_samples_;}
327 inline int NumComponents() {return num_components_;}
328 inline int NumGroups() {return num_groups_;}
329 inline void SetNumSamples(int num_samples) {num_samples_ = num_samples;}
330 inline void SetNumComponents(int num_components) {num_components_ = num_components;}
331 inline void SetNumGroups(int num_groups) {num_groups_ = num_groups;}
332 void Reset() {
333 num_samples_ = 0;
334 num_components_ = 0;
335 num_groups_ = 0;
336 beta_.clear();
337 alpha_.clear();
338 xi_.clear();
339 sigma_xi_.clear();
340 }
341 std::vector<double>& GetBeta() {return beta_;}
342 std::vector<double>& GetAlpha() {return alpha_;}
343 std::vector<double>& GetXi() {return xi_;}
344 std::vector<double>& GetSigma() {return sigma_xi_;}
345 nlohmann::json to_json();
346 void from_json(const nlohmann::json& rfx_container_json);
347 void append_from_json(const nlohmann::json& rfx_container_json);
348 private:
349 int num_samples_;
350 int num_components_;
351 int num_groups_;
352 std::vector<double> beta_;
353 std::vector<double> alpha_;
354 std::vector<double> xi_;
355 std::vector<double> sigma_xi_;
356 void AddAlpha(MultivariateRegressionRandomEffectsModel& model);
358 void AddSigma(MultivariateRegressionRandomEffectsModel& model);
359};
360
361} // namespace StochTree
362
363#endif // STOCHTREE_RANDOM_EFFECTS_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:194
Definition ig_sampler.h:9
Standalone container for the map from category IDs to 0-based indices.
Definition random_effects.h:80
Definition normal_sampler.h:24
Posterior computation and sampling and state storage for random effects model with a group-level mult...
Definition random_effects.h:130
Eigen::VectorXd GroupParameterMean(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t group_id)
Compute the posterior mean of a group parameter, conditional on the working parameter and the varianc...
void ResetFromSample(RandomEffectsContainer &rfx_container, int sample_num)
Reconstruction from serialized model parameter samples.
Eigen::VectorXd & GetWorkingParameter()
Getters.
Definition random_effects.h:180
double VarianceComponentScale(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t component_id)
Compute the posterior scale of the group variance component, conditional on the working and group par...
Eigen::MatrixXd GroupParameterVariance(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t group_id)
Compute the posterior covariance of a group parameter, conditional on the working parameter and the v...
void SampleRandomEffects(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &tracker, double global_variance, std::mt19937 &gen)
Samplers.
Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance)
Compute the posterior mean of the working parameter, conditional on the group parameters and the vari...
Eigen::MatrixXd WorkingParameterVariance(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance)
Compute the posterior covariance of the working parameter, conditional on the group parameters and th...
double VarianceComponentShape(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t component_id)
Compute the posterior shape of the group variance component, conditional on the working and group par...
void SetWorkingParameter(Eigen::VectorXd &working_parameter)
Setters.
Definition random_effects.h:154
Definition random_effects.h:290
API for loading and accessing data used to sample (additive) random effects.
Definition data.h:486
Wrapper around data structures for random effects sampling algorithms.
Definition random_effects.h:38
void RootReset(MultivariateRegressionRandomEffectsModel &rfx_model, RandomEffectsDataset &rfx_dataset, ColumnVector &residual)
Resets RFX tracker to initial default. Assumes tracker already exists in main memory....
void ResetFromSample(MultivariateRegressionRandomEffectsModel &rfx_model, RandomEffectsDataset &rfx_dataset, ColumnVector &residual)
Resets RFX tracker based on a specific sample. Assumes tracker already exists in main memory.
Class storing sample-node map for each tree in an ensemble. TODO: Add run-time checks for categories w...
Definition category_tracker.h:45
Definition category_tracker.h:40