StochTree 0.0.1
Loading...
Searching...
No Matches
random_effects.h
1
5#ifndef STOCHTREE_RANDOM_EFFECTS_H_
6#define STOCHTREE_RANDOM_EFFECTS_H_
7
8#include <stochtree/category_tracker.h>
9#include <stochtree/cutpoint_candidates.h>
10#include <stochtree/data.h>
11#include <stochtree/ensemble.h>
12#include <stochtree/ig_sampler.h>
13#include <stochtree/log.h>
14#include <stochtree/normal_sampler.h>
15#include <stochtree/partition_tracker.h>
16#include <stochtree/prior.h>
17#include <nlohmann/json.hpp>
18#include <Eigen/Dense>
19
20#include <cmath>
21#include <map>
22#include <memory>
23#include <random>
24#include <set>
25#include <string>
26#include <type_traits>
27#include <vector>
28
29namespace StochTree {
30
32class LabelMapper;
33class MultivariateRegressionRandomEffectsModel;
34class RandomEffectsContainer;
35
38 public:
39 RandomEffectsTracker(std::vector<int32_t>& group_indices);
41 inline data_size_t GetCategoryId(int observation_num) {return sample_category_mapper_->GetCategoryId(observation_num);}
42 inline data_size_t CategoryBegin(int category_id) {return category_sample_tracker_->CategoryBegin(category_id);}
43 inline data_size_t CategoryEnd(int category_id) {return category_sample_tracker_->CategoryEnd(category_id);}
44 inline data_size_t CategorySize(int category_id) {return category_sample_tracker_->CategorySize(category_id);}
45 inline int32_t NumCategories() {return num_categories_;}
46 inline int32_t CategoryNumber(int32_t category_id) {return category_sample_tracker_->CategoryNumber(category_id);}
47 SampleCategoryMapper* GetSampleCategoryMapper() {return sample_category_mapper_.get();}
48 CategorySampleTracker* GetCategorySampleTracker() {return category_sample_tracker_.get();}
49 std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(int category_id);
50 std::vector<data_size_t>::iterator UnsortedNodeEndIterator(int category_id);
51 std::map<int32_t, int32_t>& GetLabelMap() {return category_sample_tracker_->GetLabelMap();}
52 std::vector<int32_t>& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();}
53 std::vector<data_size_t>& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);}
54 std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
55 double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);}
56 void SetPrediction(data_size_t observation_num, double pred) {rfx_predictions_.at(observation_num) = pred;}
59 RandomEffectsDataset& rfx_dataset, ColumnVector& residual);
64 RandomEffectsDataset& rfx_dataset, ColumnVector& residual);
65
66 private:
68 std::unique_ptr<SampleCategoryMapper> sample_category_mapper_;
70 std::unique_ptr<CategorySampleTracker> category_sample_tracker_;
72 std::vector<double> rfx_predictions_;
74 int num_categories_;
75 int num_observations_;
76};
77
80 public:
81 LabelMapper() {}
82 LabelMapper(std::map<int32_t, int32_t> label_map) {
83 label_map_ = label_map;
84 for (const auto& [key, value] : label_map) keys_.push_back(key);
85 }
86 ~LabelMapper() {}
87 bool ContainsLabel(int32_t category_id) {
88 auto pos = label_map_.find(category_id);
89 return pos != label_map_.end();
90 }
91 int32_t CategoryNumber(int32_t category_id) {
92 return label_map_[category_id];
93 }
94 std::vector<int32_t>& Keys() {return keys_;}
95 std::map<int32_t, int32_t>& Map() {return label_map_;}
96 void Reset() {label_map_.clear(); keys_.clear();}
97 nlohmann::json to_json();
98 void from_json(const nlohmann::json& rfx_label_mapper_json);
99 private:
100 std::map<int32_t, int32_t> label_map_;
101 std::vector<int32_t> keys_;
102};
103
106 public:
107 MultivariateRegressionRandomEffectsModel(int num_components, int num_groups) {
108 normal_sampler_ = MultivariateNormalSampler();
109 ig_sampler_ = InverseGammaSampler();
110 num_components_ = num_components;
111 num_groups_ = num_groups;
112 working_parameter_ = Eigen::VectorXd(num_components_);
113 group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
114 group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
115 working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
116 }
118
120 void ResetFromSample(RandomEffectsContainer& rfx_container, int sample_num);
121
123 void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
124 void SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
125 void SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
126 void SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
127
129 void SetWorkingParameter(Eigen::VectorXd& working_parameter) {
130 working_parameter_ = working_parameter;
131 }
132 void SetGroupParameters(Eigen::MatrixXd& group_parameters) {
133 group_parameters_ = group_parameters;
134 }
135 void SetGroupParameter(Eigen::VectorXd& group_parameter, int32_t group_id) {
136 group_parameters_(Eigen::all, group_id) = group_parameter;
137 }
138 void SetWorkingParameterCovariance(Eigen::MatrixXd& working_parameter_covariance) {
139 working_parameter_covariance_ = working_parameter_covariance;
140 }
141 void SetGroupParameterCovariance(Eigen::MatrixXd& group_parameter_covariance) {
142 group_parameter_covariance_ = group_parameter_covariance;
143 }
144 void SetGroupParameterVarianceComponent(double value, int32_t component_id) {
145 group_parameter_covariance_(component_id, component_id) = value;
146 }
147 void SetVariancePriorShape(double value) {
148 variance_prior_shape_ = value;
149 }
150 void SetVariancePriorScale(double value) {
151 variance_prior_scale_ = value;
152 }
153
155 Eigen::VectorXd& GetWorkingParameter() {
156 return working_parameter_;
157 }
158 Eigen::MatrixXd& GetGroupParameters() {
159 return group_parameters_;
160 }
161 Eigen::MatrixXd& GetWorkingParameterCovariance() {
162 return working_parameter_covariance_;
163 }
164 Eigen::MatrixXd& GetGroupParameterCovariance() {
165 return group_parameter_covariance_;
166 }
167 double GetVariancePriorShape() {
168 return variance_prior_shape_;
169 }
170 double GetVariancePriorScale() {
171 return variance_prior_scale_;
172 }
173 inline int32_t NumComponents() {return num_components_;}
174 inline int32_t NumGroups() {return num_groups_;}
175
176 std::vector<double> Predict(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker) {
177 std::vector<double> output(dataset.NumObservations());
178 PredictInplace(dataset, tracker, output);
179 return output;
180 }
181
182 void PredictInplace(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, std::vector<double>& output) {
183 Eigen::MatrixXd X = dataset.GetBasis();
184 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
185 CHECK_EQ(X.rows(), group_labels.size());
186 int n = X.rows();
187 CHECK_EQ(n, output.size());
188 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
189 std::int32_t group_ind;
190 for (int i = 0; i < n; i++) {
191 group_ind = tracker.CategoryNumber(group_labels[i]);
192 output[i] = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
193 }
194 }
195
196 void AddCurrentPredictionToResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
197 data_size_t n = dataset.NumObservations();
198 CHECK_EQ(n, residual.NumRows());
199 double current_pred;
200 double new_resid;
201 for (data_size_t i = 0; i < n; i++) {
202 current_pred = tracker.GetPrediction(i);
203 new_resid = residual.GetElement(i) + current_pred;
204 residual.SetElement(i, new_resid);
205 }
206 }
207
208 void SubtractNewPredictionFromResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
209 Eigen::MatrixXd X = dataset.GetBasis();
210 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
211 CHECK_EQ(X.rows(), group_labels.size());
212 int n = X.rows();
213 double new_pred;
214 double new_resid;
215 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
216 std::int32_t group_ind;
217 for (int i = 0; i < n; i++) {
218 group_ind = tracker.CategoryNumber(group_labels[i]);
219 new_pred = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
220 new_resid = residual.GetElement(i) - new_pred;
221 residual.SetElement(i, new_resid);
222 tracker.SetPrediction(i, new_pred);
223 }
224 }
225
227 Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
229 Eigen::MatrixXd WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
231 Eigen::VectorXd GroupParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id);
233 Eigen::MatrixXd GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id);
235 double VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
237 double VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
238
239 private:
241 MultivariateNormalSampler normal_sampler_;
242 InverseGammaSampler ig_sampler_;
243
245 int num_components_;
246 int num_groups_;
247
251 Eigen::VectorXd working_parameter_;
252 Eigen::MatrixXd group_parameters_;
253
255 Eigen::MatrixXd group_parameter_covariance_;
256
258 Eigen::MatrixXd working_parameter_covariance_;
259
261 double variance_prior_shape_;
262 double variance_prior_scale_;
263};
264
266 public:
267 RandomEffectsContainer(int num_components, int num_groups) {
268 num_components_ = num_components;
269 num_groups_ = num_groups;
270 num_samples_ = 0;
271 }
273 num_components_ = 0;
274 num_groups_ = 0;
275 num_samples_ = 0;
276 }
278 void AddSample(MultivariateRegressionRandomEffectsModel& model);
279 void DeleteSample(int sample_num);
280 void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output);
281 int NumSamples() {return num_samples_;}
282 int NumComponents() {return num_components_;}
283 int NumGroups() {return num_groups_;}
284 void Reset() {
285 num_samples_ = 0;
286 num_components_ = 0;
287 num_groups_ = 0;
288 beta_.clear();
289 alpha_.clear();
290 xi_.clear();
291 sigma_xi_.clear();
292 }
293 std::vector<double>& GetBeta() {return beta_;}
294 std::vector<double>& GetAlpha() {return alpha_;}
295 std::vector<double>& GetXi() {return xi_;}
296 std::vector<double>& GetSigma() {return sigma_xi_;}
297 nlohmann::json to_json();
298 void from_json(const nlohmann::json& rfx_container_json);
299 void append_from_json(const nlohmann::json& rfx_container_json);
300 private:
301 int num_samples_;
302 int num_components_;
303 int num_groups_;
304 std::vector<double> beta_;
305 std::vector<double> alpha_;
306 std::vector<double> xi_;
307 std::vector<double> sigma_xi_;
308 void AddAlpha(MultivariateRegressionRandomEffectsModel& model);
310 void AddSigma(MultivariateRegressionRandomEffectsModel& model);
311};
312
313} // namespace StochTree
314
315#endif // STOCHTREE_RANDOM_EFFECTS_H_
Internal wrapper around Eigen::VectorXd interface for univariate floating point data....
Definition data.h:194
Definition ig_sampler.h:9
Standalone container for the map from category IDs to 0-based indices.
Definition random_effects.h:79
Definition normal_sampler.h:24
Posterior computation and sampling and state storage for random effects model with a group-level mult...
Definition random_effects.h:105
Eigen::VectorXd GroupParameterMean(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t group_id)
Compute the posterior mean of a group parameter, conditional on the working parameter and the varianc...
void ResetFromSample(RandomEffectsContainer &rfx_container, int sample_num)
Reconstruction from serialized model parameter samples.
Eigen::VectorXd & GetWorkingParameter()
Getters.
Definition random_effects.h:155
double VarianceComponentScale(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t component_id)
Compute the posterior scale of the group variance component, conditional on the working and group par...
Eigen::MatrixXd GroupParameterVariance(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t group_id)
Compute the posterior covariance of a group parameter, conditional on the working parameter and the v...
void SampleRandomEffects(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &tracker, double global_variance, std::mt19937 &gen)
Samplers.
Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance)
Compute the posterior mean of the working parameter, conditional on the group parameters and the vari...
Eigen::MatrixXd WorkingParameterVariance(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance)
Compute the posterior covariance of the working parameter, conditional on the group parameters and th...
double VarianceComponentShape(RandomEffectsDataset &dataset, ColumnVector &residual, RandomEffectsTracker &rfx_tracker, double global_variance, int32_t component_id)
Compute the posterior shape of the group variance component, conditional on the working and group par...
void SetWorkingParameter(Eigen::VectorXd &working_parameter)
Setters.
Definition random_effects.h:129
Definition random_effects.h:265
API for loading and accessing data used to sample (additive) random effects.
Definition data.h:486
Wrapper around data structures for random effects sampling algorithms.
Definition random_effects.h:37
void RootReset(MultivariateRegressionRandomEffectsModel &rfx_model, RandomEffectsDataset &rfx_dataset, ColumnVector &residual)
Resets RFX tracker to initial default. Assumes tracker already exists in main memory....
void ResetFromSample(MultivariateRegressionRandomEffectsModel &rfx_model, RandomEffectsDataset &rfx_dataset, ColumnVector &residual)
Resets RFX tracker based on a specific sample. Assumes tracker already exists in main memory.
Class storing sample-node map for each tree in an ensemble TODO: Add run-time checks for categories w...
Definition category_tracker.h:45
Definition category_tracker.h:40