39 inline data_size_t GetCategoryId(
int observation_num) {
return sample_category_mapper_->GetCategoryId(observation_num);}
40 inline data_size_t CategoryBegin(
int category_id) {
return category_sample_tracker_->CategoryBegin(category_id);}
41 inline data_size_t CategoryEnd(
int category_id) {
return category_sample_tracker_->CategoryEnd(category_id);}
42 inline data_size_t CategorySize(
int category_id) {
return category_sample_tracker_->CategorySize(category_id);}
43 inline int32_t NumCategories() {
return num_categories_;}
44 inline int32_t CategoryNumber(int32_t category_id) {
return category_sample_tracker_->CategoryNumber(category_id);}
46 CategorySampleTracker* GetCategorySampleTracker() {
return category_sample_tracker_.get();}
47 std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(
int category_id);
48 std::vector<data_size_t>::iterator UnsortedNodeEndIterator(
int category_id);
49 std::map<int32_t, int32_t>& GetLabelMap() {
return category_sample_tracker_->GetLabelMap();}
50 std::vector<int32_t>& GetUniqueGroupIds() {
return category_sample_tracker_->GetUniqueGroupIds();}
51 std::vector<data_size_t>& NodeIndices(
int category_id) {
return category_sample_tracker_->NodeIndices(category_id);}
52 std::vector<data_size_t>& NodeIndicesInternalIndex(
int internal_category_id) {
return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
53 double GetPrediction(data_size_t observation_num) {
return rfx_predictions_.at(observation_num);}
54 void SetPrediction(data_size_t observation_num,
double pred) {rfx_predictions_.at(observation_num) = pred;}
// Owned mapper that GetCategoryId() forwards to.
66 std::unique_ptr<SampleCategoryMapper> sample_category_mapper_;
// Owned tracker that the Category*/NodeIndices*/GetLabelMap/GetUniqueGroupIds accessors forward to.
68 std::unique_ptr<CategorySampleTracker> category_sample_tracker_;
// Per-observation values read by GetPrediction() and written by SetPrediction().
70 std::vector<double> rfx_predictions_;
// Not referenced by any code visible in this excerpt; presumably the observation count — confirm.
73 int num_observations_;
// Fragment: the enclosing signature and closing brace are not visible in this excerpt.
// Stores label_map and appends each of its keys to keys_ (in the map's iteration order).
81 label_map_ = label_map;
82 for (
const auto& [key, value] : label_map) keys_.push_back(key);
// Replaces label_map_ with the given map (taken by value, so the argument is copied)
// and appends every key of that map to keys_.
// NOTE(review): keys_ is appended to without being cleared here, so calling this on a
// mapper that already holds keys leaves the old keys in place (and duplicates repeats)
// unless Reset() was called first — confirm callers rely on that.
85 void LoadFromLabelMap(std::map<int32_t, int32_t> label_map) {
86 label_map_ = label_map;
87 for (
const auto& [key, value] : label_map) keys_.push_back(key);
// Returns true when category_id is a key of label_map_.
// Uses find(), so the map is not modified by the lookup.
89 bool ContainsLabel(int32_t category_id) {
90 auto pos = label_map_.find(category_id);
91 return pos != label_map_.end();
// Returns the value mapped to category_id in label_map_.
// NOTE(review): std::map::operator[] inserts a value-initialized entry (0) when the key
// is absent, so an unknown category_id returns 0 and grows label_map_ without adding the
// key to keys_. If that is not intended, guard with ContainsLabel() or use find()/at().
93 int32_t CategoryNumber(int32_t category_id) {
94 return label_map_[category_id];
// Serializes this object with to_json() and streams the result to the file at `filename`.
// NOTE(review): no check that the output stream opened successfully is visible here.
96 void SaveToJsonFile(std::string filename) {
97 nlohmann::json model_json = this->to_json();
98 std::ofstream output_file(filename);
// std::endl writes a newline and flushes the stream.
99 output_file << model_json << std::endl;
// Parses the file at `filename` as JSON and passes the parsed object to from_json().
// NOTE(review): no check that the input stream opened is visible here; presumably a
// missing file surfaces as a parse failure from nlohmann::json::parse — confirm.
101 void LoadFromJsonFile(std::string filename) {
102 std::ifstream f(filename);
103 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(f);
105 this->from_json(rfx_label_mapper_json);
// Returns the to_json() representation of this object as a string (via json::dump()).
107 std::string DumpJsonString() {
108 nlohmann::json model_json = this->to_json();
109 return model_json.dump();
// Parses json_string as JSON and passes the parsed object to from_json().
// json_string is taken by non-const reference; it is only read in the visible code.
111 void LoadFromJsonString(std::string& json_string) {
112 nlohmann::json rfx_label_mapper_json = nlohmann::json::parse(json_string);
114 this->from_json(rfx_label_mapper_json);
116 std::vector<int32_t>& Keys() {
return keys_;}
117 std::map<int32_t, int32_t>& Map() {
return label_map_;}
118 void Reset() {label_map_.clear(); keys_.clear();}
119 nlohmann::json to_json();
120 void from_json(
const nlohmann::json& rfx_label_mapper_json);
// Lookup table read by ContainsLabel() and CategoryNumber(); exposed mutably through Map().
122 std::map<int32_t, int32_t> label_map_;
// Keys appended from a label map when one is loaded; exposed mutably through Keys().
123 std::vector<int32_t> keys_;
// Fragment: the enclosing signature and closing brace are not visible in this excerpt.
// Records the dimensions and sizes each Eigen member accordingly.
// Eigen's sized constructors allocate storage without initializing the coefficients,
// so no values are set by the lines visible here.
132 num_components_ = num_components;
133 num_groups_ = num_groups;
134 working_parameter_ = Eigen::VectorXd(num_components_);
135 group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
136 group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
137 working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
// Fragment: body of a setter whose signature is not visible in this excerpt.
// Copies the argument into working_parameter_.
152 working_parameter_ = working_parameter;
// Copies group_parameters into group_parameters_ (whole-matrix assignment).
154 void SetGroupParameters(Eigen::MatrixXd& group_parameters) {
155 group_parameters_ = group_parameters;
// Overwrites column group_id of group_parameters_ with group_parameter.
157 void SetGroupParameter(Eigen::VectorXd& group_parameter, int32_t group_id) {
158 group_parameters_(Eigen::all, group_id) = group_parameter;
// Copies the argument into working_parameter_covariance_ (whole-matrix assignment).
160 void SetWorkingParameterCovariance(Eigen::MatrixXd& working_parameter_covariance) {
161 working_parameter_covariance_ = working_parameter_covariance;
// Copies the argument into group_parameter_covariance_ (whole-matrix assignment).
163 void SetGroupParameterCovariance(Eigen::MatrixXd& group_parameter_covariance) {
164 group_parameter_covariance_ = group_parameter_covariance;
// Writes `value` to the diagonal entry (component_id, component_id) of
// group_parameter_covariance_; off-diagonal entries are left untouched.
166 void SetGroupParameterVarianceComponent(
double value, int32_t component_id) {
167 group_parameter_covariance_(component_id, component_id) = value;
// Stores `value` in variance_prior_shape_.
169 void SetVariancePriorShape(
double value) {
170 variance_prior_shape_ = value;
// Stores `value` in variance_prior_scale_.
172 void SetVariancePriorScale(
double value) {
173 variance_prior_scale_ = value;
// Fragment: body of a getter whose signature is not visible in this excerpt.
178 return working_parameter_;
// Returns a mutable reference to group_parameters_.
180 Eigen::MatrixXd& GetGroupParameters() {
181 return group_parameters_;
// Returns a mutable reference to working_parameter_covariance_.
183 Eigen::MatrixXd& GetWorkingParameterCovariance() {
184 return working_parameter_covariance_;
// Returns a mutable reference to group_parameter_covariance_.
186 Eigen::MatrixXd& GetGroupParameterCovariance() {
187 return group_parameter_covariance_;
// Returns the stored variance_prior_shape_ value.
189 double GetVariancePriorShape() {
190 return variance_prior_shape_;
// Returns the stored variance_prior_scale_ value.
192 double GetVariancePriorScale() {
193 return variance_prior_scale_;
195 inline int32_t NumComponents() {
return num_components_;}
196 inline int32_t NumGroups() {
return num_groups_;}
// Allocates one output slot per observation (dataset.NumObservations()) and fills it
// through PredictInplace(). The return statement and closing brace are not visible in
// this excerpt; presumably `output` is returned — confirm.
198 std::vector<double> Predict(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker) {
199 std::vector<double> output(dataset.NumObservations());
200 PredictInplace(dataset, tracker, output);
// Writes one prediction per observation into `output`.
// For row i: output[i] = X(i, :) * diag(working_parameter_) * group_parameters_(:, g),
// where g = tracker.CategoryNumber(group_labels[i]).
// The declaration of `n` and the closing braces are not visible in this excerpt.
204 void PredictInplace(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, std::vector<double>& output) {
// X and group_labels are declared by value, so they are local copies of what the dataset returns.
205 Eigen::MatrixXd X = dataset.GetBasis();
206 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
// One group label is required per basis row.
207 CHECK_EQ(X.rows(), group_labels.size());
// `output` must already be sized to n; it is indexed, not resized, below.
209 CHECK_EQ(n, output.size());
// Dense diagonal matrix built once from working_parameter_, reused for every row.
210 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
211 std::int32_t group_ind;
212 for (
int i = 0; i < n; i++) {
// Translate the observation's group label into the column index used in group_parameters_.
213 group_ind = tracker.CategoryNumber(group_labels[i]);
214 output[i] = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
// For every observation i, adds the prediction currently stored in the tracker
// (tracker.GetPrediction(i)) to element i of `residual`.
// The declarations of current_pred / new_resid and the closing braces are not visible
// in this excerpt.
218 void AddCurrentPredictionToResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
219 data_size_t n = dataset.NumObservations();
// The residual must have one row per observation.
220 CHECK_EQ(n, residual.NumRows());
223 for (data_size_t i = 0; i < n; i++) {
224 current_pred = tracker.GetPrediction(i);
225 new_resid = residual.GetElement(i) + current_pred;
226 residual.SetElement(i, new_resid);
// For every observation i, computes a fresh prediction
//   X(i, :) * diag(working_parameter_) * group_parameters_(:, g),  g = tracker.CategoryNumber(group_labels[i]),
// subtracts it from element i of `residual`, and stores it in the tracker via SetPrediction().
// The declarations of n / new_pred / new_resid and the closing braces are not visible
// in this excerpt.
230 void SubtractNewPredictionFromResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
// X and group_labels are declared by value, so they are local copies of what the dataset returns.
231 Eigen::MatrixXd X = dataset.GetBasis();
232 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
// One group label is required per basis row.
233 CHECK_EQ(X.rows(), group_labels.size());
// Dense diagonal matrix built once from working_parameter_, reused for every row.
237 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
238 std::int32_t group_ind;
239 for (
int i = 0; i < n; i++) {
240 group_ind = tracker.CategoryNumber(group_labels[i]);
241 new_pred = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
242 new_resid = residual.GetElement(i) - new_pred;
243 residual.SetElement(i, new_resid);
// Store the new prediction in the tracker; AddCurrentPredictionToResidual() reads it back via GetPrediction().
244 tracker.SetPrediction(i, new_pred);
// Vector of length num_components_; used as a diagonal scaling in the prediction code.
273 Eigen::VectorXd working_parameter_;
// num_components_ x num_groups_ matrix; one column per group.
274 Eigen::MatrixXd group_parameters_;
// num_components_ x num_components_ matrix; its diagonal is written by SetGroupParameterVarianceComponent().
277 Eigen::MatrixXd group_parameter_covariance_;
// num_components_ x num_components_ matrix.
280 Eigen::MatrixXd working_parameter_covariance_;
// Set/read through SetVariancePriorShape()/GetVariancePriorShape().
283 double variance_prior_shape_;
// Set/read through SetVariancePriorScale()/GetVariancePriorScale().
284 double variance_prior_scale_;
// Fragment: the enclosing signature and closing brace are not visible in this excerpt.
// Records the component and group counts.
290 num_components_ = num_components;
291 num_groups_ = num_groups;
// Serializes this object with to_json() and streams the result to the file at `filename`.
// NOTE(review): no check that the output stream opened successfully is visible here.
300 void SaveToJsonFile(std::string filename) {
301 nlohmann::json model_json = this->to_json();
302 std::ofstream output_file(filename);
// std::endl writes a newline and flushes the stream.
303 output_file << model_json << std::endl;
// Parses the file at `filename` as JSON and passes the parsed object to from_json().
// NOTE(review): no check that the input stream opened is visible here; presumably a
// missing file surfaces as a parse failure from nlohmann::json::parse — confirm.
305 void LoadFromJsonFile(std::string filename) {
306 std::ifstream f(filename);
307 nlohmann::json rfx_container_json = nlohmann::json::parse(f);
309 this->from_json(rfx_container_json);
// Returns the to_json() representation of this object as a string (via json::dump()).
311 std::string DumpJsonString() {
312 nlohmann::json model_json = this->to_json();
313 return model_json.dump();
// Parses json_string as JSON and passes the parsed object to from_json().
// json_string is taken by non-const reference; it is only read in the visible code.
315 void LoadFromJsonString(std::string& json_string) {
316 nlohmann::json rfx_container_json = nlohmann::json::parse(json_string);
318 this->from_json(rfx_container_json);
/** Declaration only; the definition is not part of this excerpt. */
void DeleteSample(int sample_num);
323 inline int NumSamples() {
return num_samples_;}
324 inline int NumComponents() {
return num_components_;}
325 inline int NumGroups() {
return num_groups_;}
326 inline void SetNumSamples(
int num_samples) {num_samples_ = num_samples;}
327 inline void SetNumComponents(
int num_components) {num_components_ = num_components;}
328 inline void SetNumGroups(
int num_groups) {num_groups_ = num_groups;}
338 std::vector<double>& GetBeta() {
return beta_;}
339 std::vector<double>& GetAlpha() {
return alpha_;}
340 std::vector<double>& GetXi() {
return xi_;}
341 std::vector<double>& GetSigma() {
return sigma_xi_;}
342 nlohmann::json to_json();
343 void from_json(
const nlohmann::json& rfx_container_json);
344 void append_from_json(
const nlohmann::json& rfx_container_json);
// Flat storage exposed mutably through GetBeta(); layout is not shown in this excerpt.
349 std::vector<double> beta_;
// Flat storage exposed mutably through GetAlpha(); layout is not shown in this excerpt.
350 std::vector<double> alpha_;
// Flat storage exposed mutably through GetXi(); layout is not shown in this excerpt.
351 std::vector<double> xi_;
// Flat storage exposed mutably through GetSigma(); layout is not shown in this excerpt.
352 std::vector<double> sigma_xi_;