42 inline data_size_t GetCategoryId(
int observation_num) {
return sample_category_mapper_->GetCategoryId(observation_num);}
43 inline data_size_t CategoryBegin(
int category_id) {
return category_sample_tracker_->CategoryBegin(category_id);}
44 inline data_size_t CategoryEnd(
int category_id) {
return category_sample_tracker_->CategoryEnd(category_id);}
45 inline data_size_t CategorySize(
int category_id) {
return category_sample_tracker_->CategorySize(category_id);}
46 inline int32_t NumCategories() {
return num_categories_;}
47 inline int32_t CategoryNumber(int32_t category_id) {
return category_sample_tracker_->CategoryNumber(category_id);}
49 CategorySampleTracker* GetCategorySampleTracker() {
return category_sample_tracker_.get();}
50 std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(
int category_id);
51 std::vector<data_size_t>::iterator UnsortedNodeEndIterator(
int category_id);
52 std::map<int32_t, int32_t>& GetLabelMap() {
return category_sample_tracker_->GetLabelMap();}
53 std::vector<int32_t>& GetUniqueGroupIds() {
return category_sample_tracker_->GetUniqueGroupIds();}
54 std::vector<data_size_t>& NodeIndices(
int category_id) {
return category_sample_tracker_->NodeIndices(category_id);}
55 std::vector<data_size_t>& NodeIndicesInternalIndex(
int internal_category_id) {
return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
56 double GetPrediction(data_size_t observation_num) {
return rfx_predictions_.at(observation_num);}
57 void SetPrediction(data_size_t observation_num,
double pred) {rfx_predictions_.at(observation_num) = pred;}
/*! \brief Owned mapper queried by GetCategoryId. */
69 std::unique_ptr<SampleCategoryMapper> sample_category_mapper_;
/*! \brief Owned tracker that backs the Category* and NodeIndices* accessors. */
71 std::unique_ptr<CategorySampleTracker> category_sample_tracker_;
/*! \brief Per-observation predictions read by GetPrediction and written by SetPrediction. */
73 std::vector<double> rfx_predictions_;
/*! \brief Observation count; its uses are not visible in this excerpt. */
76 int num_observations_;
84 label_map_ = label_map;
85 for (
const auto& [key, value] : label_map) keys_.push_back(key);
/*!
 * \brief Replace `label_map_` with a copy of `label_map` and append each of its keys to `keys_`.
 * \param label_map Map to load (taken by value).
 *
 * NOTE(review): `keys_` is not cleared in the visible lines, so calling this on an
 * already-populated object appends to the existing keys — confirm whether that is intended.
 */
88 void LoadFromLabelMap(std::map<int32_t, int32_t> label_map) {
89 label_map_ = label_map;
90 for (
const auto& [key, value] : label_map) keys_.push_back(key);
/*!
 * \brief Report whether `category_id` is a key of `label_map_`.
 * \param category_id Key to look up.
 * \return true when the key is present; the map is not modified.
 */
92 bool ContainsLabel(int32_t category_id) {
93 auto pos = label_map_.find(category_id);
94 return pos != label_map_.end();
/*!
 * \brief Return the value mapped to `category_id` in `label_map_`.
 * \param category_id Key to look up.
 *
 * NOTE(review): `std::map::operator[]` inserts a value-initialized entry (0) when the key
 * is absent, so an unknown id silently grows the map and returns 0. Callers that need to
 * distinguish unknown ids should check ContainsLabel first.
 */
96 int32_t CategoryNumber(int32_t category_id) {
97 return label_map_[category_id];
/*!
 * \brief Serialize this object with to_json() and write the result to `filename`.
 * \param filename Path handed to `std::ofstream`.
 *
 * NOTE(review): the visible lines do not check whether the stream opened successfully.
 */
99 void SaveToJsonFile(std::string filename) {
100 nlohmann::json model_json = this->to_json();
101 std::ofstream output_file(filename);
// std::endl writes a newline and flushes the stream.
102 output_file << model_json << std::endl;
/*!
 * \brief Parse the JSON document in `filename` and hand it to from_json().
 * \param filename Path handed to `std::ifstream`.
 *
 * NOTE(review): a source line between the parse and the from_json() call is not visible
 * in this excerpt; the visible lines do not check whether the file opened.
 */
104 void LoadFromJsonFile(std::string filename) {
105 std::ifstream f(filename);
106 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(f);
108 this->from_json(rfx_label_mapper_json);
/*!
 * \brief Serialize this object with to_json() and return the JSON text.
 * \return Result of `nlohmann::json::dump()` called with default arguments.
 */
110 std::string DumpJsonString() {
111 nlohmann::json model_json = this->to_json();
112 return model_json.dump();
/*!
 * \brief Parse `json_string` as JSON and hand the result to from_json().
 * \param json_string JSON text (taken by non-const reference).
 *
 * NOTE(review): a source line between the parse and the from_json() call is not visible
 * in this excerpt.
 */
114 void LoadFromJsonString(std::string& json_string) {
115 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(json_string);
117 this->from_json(rfx_label_mapper_json);
119 std::vector<int32_t>& Keys() {
return keys_;}
120 std::map<int32_t, int32_t>& Map() {
return label_map_;}
121 void Reset() {label_map_.clear(); keys_.clear();}
122 nlohmann::json to_json();
123 void from_json(
const nlohmann::json& rfx_label_mapper_json);
/*! \brief Map searched by ContainsLabel and indexed by CategoryNumber. */
125 std::map<int32_t, int32_t> label_map_;
/*! \brief Keys appended by LoadFromLabelMap and exposed through Keys(). */
126 std::vector<int32_t> keys_;
135 num_components_ = num_components;
136 num_groups_ = num_groups;
137 working_parameter_ = Eigen::VectorXd(num_components_);
138 group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
139 group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
140 working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
155 working_parameter_ = working_parameter;
/*!
 * \brief Overwrite `group_parameters_` with a copy of the supplied matrix.
 * \param group_parameters Replacement matrix; no dimension check is visible here.
 */
157 void SetGroupParameters(Eigen::MatrixXd& group_parameters) {
158 group_parameters_ = group_parameters;
/*!
 * \brief Overwrite a single column of `group_parameters_`.
 * \param group_parameter Vector assigned to the column.
 * \param group_id Column index to overwrite; no bounds check is visible here.
 */
160 void SetGroupParameter(Eigen::VectorXd& group_parameter, int32_t group_id) {
161 group_parameters_(Eigen::all, group_id) = group_parameter;
/*!
 * \brief Overwrite `working_parameter_covariance_` with a copy of the supplied matrix.
 * \param working_parameter_covariance Replacement matrix.
 */
163 void SetWorkingParameterCovariance(Eigen::MatrixXd& working_parameter_covariance) {
164 working_parameter_covariance_ = working_parameter_covariance;
/*!
 * \brief Overwrite `group_parameter_covariance_` with a copy of the supplied matrix.
 * \param group_parameter_covariance Replacement matrix.
 */
166 void SetGroupParameterCovariance(Eigen::MatrixXd& group_parameter_covariance) {
167 group_parameter_covariance_ = group_parameter_covariance;
/*!
 * \brief Set one diagonal entry of `group_parameter_covariance_`.
 * \param value New value for the entry.
 * \param component_id Row and column index of the diagonal entry; no bounds check is visible here.
 */
169 void SetGroupParameterVarianceComponent(
double value, int32_t component_id) {
170 group_parameter_covariance_(component_id, component_id) = value;
/*!
 * \brief Store `value` in `variance_prior_shape_`.
 * \param value New shape value; no validation is visible here.
 */
172 void SetVariancePriorShape(
double value) {
173 variance_prior_shape_ = value;
/*!
 * \brief Store `value` in `variance_prior_scale_`.
 * \param value New scale value; no validation is visible here.
 */
175 void SetVariancePriorScale(
double value) {
176 variance_prior_scale_ = value;
181 return working_parameter_;
/*! \brief Mutable reference to `group_parameters_`. */
183 Eigen::MatrixXd& GetGroupParameters() {
184 return group_parameters_;
/*! \brief Mutable reference to `working_parameter_covariance_`. */
186 Eigen::MatrixXd& GetWorkingParameterCovariance() {
187 return working_parameter_covariance_;
/*! \brief Mutable reference to `group_parameter_covariance_`. */
189 Eigen::MatrixXd& GetGroupParameterCovariance() {
190 return group_parameter_covariance_;
/*! \brief Current value of `variance_prior_shape_`. */
192 double GetVariancePriorShape() {
193 return variance_prior_shape_;
/*! \brief Current value of `variance_prior_scale_`. */
195 double GetVariancePriorScale() {
196 return variance_prior_scale_;
198 inline int32_t NumComponents() {
return num_components_;}
199 inline int32_t NumGroups() {
return num_groups_;}
/*!
 * \brief Allocate a vector with one slot per observation and fill it via PredictInplace.
 * \param dataset Supplies `NumObservations()` and is forwarded to PredictInplace.
 * \param tracker Forwarded to PredictInplace.
 *
 * NOTE(review): the remainder of the body (including the return statement) is not
 * visible in this excerpt.
 */
201 std::vector<double> Predict(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker) {
202 std::vector<double> output(dataset.NumObservations());
203 PredictInplace(dataset, tracker, output);
/*!
 * \brief Write one prediction per observation into `output`.
 *
 * Each prediction is row i of the basis, multiplied by a diagonal matrix built from
 * `working_parameter_`, multiplied by the column of `group_parameters_` selected for
 * the observation's group label.
 *
 * \param dataset Supplies the basis matrix and the group labels.
 * \param tracker Translates a group label into a column index via CategoryNumber.
 * \param output Destination vector; checked to have `n` entries.
 */
207 void PredictInplace(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, std::vector<double>& output) {
// X and group_labels are local objects, not references; if the getters return
// references, both are copied on every call.
208 Eigen::MatrixXd X = dataset.GetBasis();
209 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
210 CHECK_EQ(X.rows(), group_labels.size());
// NOTE(review): `n` is used below but its declaration is not visible in this excerpt.
212 CHECK_EQ(n, output.size());
// Dense square matrix carrying working_parameter_ on its diagonal.
213 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
214 std::int32_t group_ind;
215 for (
int i = 0; i < n; i++) {
// Translate the observation's group label into a column of group_parameters_.
216 group_ind = tracker.CategoryNumber(group_labels[i]);
217 output[i] = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
/*!
 * \brief Add the prediction currently stored in `tracker` to every element of `residual`.
 * \param dataset Supplies the observation count `n`.
 * \param tracker Source of the stored per-observation predictions (GetPrediction).
 * \param residual Updated in place; checked to have `n` rows.
 *
 * NOTE(review): the declarations of `current_pred` and `new_resid` are not visible in
 * this excerpt.
 */
221 void AddCurrentPredictionToResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
222 data_size_t n = dataset.NumObservations();
223 CHECK_EQ(n, residual.NumRows());
226 for (data_size_t i = 0; i < n; i++) {
227 current_pred = tracker.GetPrediction(i);
228 new_resid = residual.GetElement(i) + current_pred;
229 residual.SetElement(i, new_resid);
/*!
 * \brief Recompute each observation's prediction, subtract it from `residual`, and store
 *        it in `tracker`.
 *
 * The prediction uses the same expression as PredictInplace: basis row i, times a
 * diagonal matrix built from `working_parameter_`, times the selected column of
 * `group_parameters_`.
 *
 * \param dataset Supplies the basis matrix and the group labels.
 * \param tracker Translates labels via CategoryNumber and receives the new predictions
 *        via SetPrediction.
 * \param residual Updated in place.
 *
 * NOTE(review): the declarations of `n`, `new_pred` and `new_resid` are not visible in
 * this excerpt.
 */
233 void SubtractNewPredictionFromResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
// X and group_labels are local objects, not references.
234 Eigen::MatrixXd X = dataset.GetBasis();
235 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
236 CHECK_EQ(X.rows(), group_labels.size());
// Dense square matrix carrying working_parameter_ on its diagonal.
240 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
241 std::int32_t group_ind;
242 for (
int i = 0; i < n; i++) {
243 group_ind = tracker.CategoryNumber(group_labels[i]);
244 new_pred = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
245 new_resid = residual.GetElement(i) - new_pred;
246 residual.SetElement(i, new_resid);
// Remember the prediction so it can be read back later through tracker.GetPrediction.
247 tracker.SetPrediction(i, new_pred);
/*! \brief Vector placed on the diagonal of the scaling matrix used when predicting. */
276 Eigen::VectorXd working_parameter_;
/*! \brief Matrix whose columns are selected per group when predicting. */
277 Eigen::MatrixXd group_parameters_;
/*! \brief Matrix accessed through Get/SetGroupParameterCovariance. */
280 Eigen::MatrixXd group_parameter_covariance_;
/*! \brief Matrix accessed through Get/SetWorkingParameterCovariance. */
283 Eigen::MatrixXd working_parameter_covariance_;
/*! \brief Value accessed through Get/SetVariancePriorShape. */
286 double variance_prior_shape_;
/*! \brief Value accessed through Get/SetVariancePriorScale. */
287 double variance_prior_scale_;
293 num_components_ = num_components;
294 num_groups_ = num_groups;
/*!
 * \brief Serialize this object with to_json() and write the result to `filename`.
 * \param filename Path handed to `std::ofstream`.
 *
 * NOTE(review): the visible lines do not check whether the stream opened successfully.
 */
303 void SaveToJsonFile(std::string filename) {
304 nlohmann::json model_json = this->to_json();
305 std::ofstream output_file(filename);
// std::endl writes a newline and flushes the stream.
306 output_file << model_json << std::endl;
/*!
 * \brief Parse the JSON document in `filename` and hand it to from_json().
 * \param filename Path handed to `std::ifstream`.
 *
 * NOTE(review): a source line between the parse and the from_json() call is not visible
 * in this excerpt; the visible lines do not check whether the file opened.
 */
308 void LoadFromJsonFile(std::string filename) {
309 std::ifstream f(filename);
310 nlohmann::json rfx_container_json = nlohmann::json::parse(f);
312 this->from_json(rfx_container_json);
/*!
 * \brief Serialize this object with to_json() and return the JSON text.
 * \return Result of `nlohmann::json::dump()` called with default arguments.
 */
314 std::string DumpJsonString() {
315 nlohmann::json model_json = this->to_json();
316 return model_json.dump();
/*!
 * \brief Parse `json_string` as JSON and hand the result to from_json().
 * \param json_string JSON text (taken by non-const reference).
 *
 * NOTE(review): a source line between the parse and the from_json() call is not visible
 * in this excerpt.
 */
318 void LoadFromJsonString(std::string& json_string) {
319 nlohmann::json rfx_container_json = nlohmann::json::parse(json_string);
321 this->from_json(rfx_container_json);
/*!
 * \brief Declaration only; the definition is not part of this excerpt.
 * \param sample_num Identifies the sample the call refers to.
 */
void DeleteSample(int sample_num);
326 inline int NumSamples() {
return num_samples_;}
327 inline int NumComponents() {
return num_components_;}
328 inline int NumGroups() {
return num_groups_;}
329 inline void SetNumSamples(
int num_samples) {num_samples_ = num_samples;}
330 inline void SetNumComponents(
int num_components) {num_components_ = num_components;}
331 inline void SetNumGroups(
int num_groups) {num_groups_ = num_groups;}
341 std::vector<double>& GetBeta() {
return beta_;}
342 std::vector<double>& GetAlpha() {
return alpha_;}
343 std::vector<double>& GetXi() {
return xi_;}
344 std::vector<double>& GetSigma() {
return sigma_xi_;}
345 nlohmann::json to_json();
346 void from_json(
const nlohmann::json& rfx_container_json);
347 void append_from_json(
const nlohmann::json& rfx_container_json);
/*! \brief Storage exposed through GetBeta(). */
352 std::vector<double> beta_;
/*! \brief Storage exposed through GetAlpha(). */
353 std::vector<double> alpha_;
/*! \brief Storage exposed through GetXi(). */
354 std::vector<double> xi_;
/*! \brief Storage exposed through GetSigma(). */
355 std::vector<double> sigma_xi_;