// Returns the category id that sample_category_mapper_ reports for the given
// observation index; the lookup itself is delegated to the mapper.
41 inline data_size_t GetCategoryId(int observation_num) {
  const data_size_t category_of_obs =
      sample_category_mapper_->GetCategoryId(observation_num);
  return category_of_obs;
}
// Forwards to category_sample_tracker_->CategoryBegin() for category_id and
// returns that result unchanged.
42 inline data_size_t CategoryBegin(int category_id) {
  const data_size_t begin_value =
      category_sample_tracker_->CategoryBegin(category_id);
  return begin_value;
}
// Forwards to category_sample_tracker_->CategoryEnd() for category_id and
// returns that result unchanged.
43 inline data_size_t CategoryEnd(int category_id) {
  const data_size_t end_value =
      category_sample_tracker_->CategoryEnd(category_id);
  return end_value;
}
// Forwards to category_sample_tracker_->CategorySize() for category_id and
// returns that result unchanged.
44 inline data_size_t CategorySize(int category_id) {
  const data_size_t size_value =
      category_sample_tracker_->CategorySize(category_id);
  return size_value;
}
// Returns the value held in num_categories_.
45 inline int32_t NumCategories() {
  const int32_t category_count = num_categories_;
  return category_count;
}
// Forwards category_id to category_sample_tracker_->CategoryNumber() and
// returns the number the tracker reports for it.
46 inline int32_t CategoryNumber(int32_t category_id) {
  const int32_t tracked_number =
      category_sample_tracker_->CategoryNumber(category_id);
  return tracked_number;
}
// Exposes the category sample tracker as a raw, non-owning pointer.
// Ownership stays with the category_sample_tracker_ unique_ptr member.
48 CategorySampleTracker* GetCategorySampleTracker() {
  CategorySampleTracker* borrowed_tracker = category_sample_tracker_.get();
  return borrowed_tracker;
}
// Declaration only; the definition is not in this part of the file.
// Returns a mutable iterator into a std::vector<data_size_t> for category_id.
// Presumably the start of that category's observation indices — confirm
// against the out-of-line definition.
49 std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(
int category_id);
// Declaration only; the definition is not in this part of the file.
// Returns a mutable iterator into a std::vector<data_size_t> for category_id.
// Presumably one past the last of that category's observation indices —
// confirm against the out-of-line definition.
50 std::vector<data_size_t>::iterator UnsortedNodeEndIterator(
int category_id);
// Returns a mutable reference to the label map owned by the category sample
// tracker. Edits made through the reference change the tracker's own map.
51 std::map<int32_t, int32_t>& GetLabelMap() {
  std::map<int32_t, int32_t>& tracker_label_map =
      category_sample_tracker_->GetLabelMap();
  return tracker_label_map;
}
// Returns a mutable reference to the unique group id list owned by the
// category sample tracker.
52 std::vector<int32_t>& GetUniqueGroupIds() {
  std::vector<int32_t>& tracker_group_ids =
      category_sample_tracker_->GetUniqueGroupIds();
  return tracker_group_ids;
}
// Returns a mutable reference to the index vector the category sample tracker
// keeps for category_id.
53 std::vector<data_size_t>& NodeIndices(int category_id) {
  std::vector<data_size_t>& indices_for_category =
      category_sample_tracker_->NodeIndices(category_id);
  return indices_for_category;
}
// Same kind of lookup as NodeIndices(), but the argument is passed to the
// tracker's NodeIndicesInternalIndex(), which takes an internal category id.
54 std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {
  std::vector<data_size_t>& indices_for_internal_id =
      category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);
  return indices_for_internal_id;
}
// Returns the prediction stored for observation_num.
// Access goes through std::vector::at(), so an out-of-range index throws
// std::out_of_range.
55 double GetPrediction(data_size_t observation_num) {
  const double stored_prediction = rfx_predictions_.at(observation_num);
  return stored_prediction;
}
// Overwrites the prediction stored for observation_num with pred.
// Access goes through std::vector::at(), so an out-of-range index throws
// std::out_of_range and nothing is written.
56 void SetPrediction(data_size_t observation_num, double pred) {
  double& prediction_slot = rfx_predictions_.at(observation_num);
  prediction_slot = pred;
}
// Owned mapper that GetCategoryId() queries with an observation index.
68 std::unique_ptr<SampleCategoryMapper> sample_category_mapper_;
// Owned tracker behind the Category*/NodeIndices*/GetLabelMap accessors;
// GetCategorySampleTracker() hands out a non-owning pointer to it.
70 std::unique_ptr<CategorySampleTracker> category_sample_tracker_;
// One stored prediction per index, read and written by
// GetPrediction()/SetPrediction().
72 std::vector<double> rfx_predictions_;
// Presumably the number of observations being tracked — confirm where it is set.
75 int num_observations_;
83 label_map_ = label_map;
84 for (
const auto& [key, value] : label_map) keys_.push_back(key);
// Reports whether category_id is present as a key of label_map_.
// The lookup uses find(), so the map is never modified by this check.
// (The closing brace of this definition is outside the visible listing.)
87 bool ContainsLabel(int32_t category_id) {
88 auto pos = label_map_.find(category_id);
89 return pos != label_map_.end();
// Returns the value stored in label_map_ under category_id.
// NOTE(review): std::map::operator[] inserts a value-initialized entry (0)
// when the key is absent, so an unknown id silently adds to label_map_
// (keys_ is not updated here) and yields 0. Callers presumably check
// ContainsLabel() first — confirm, or consider find()/at() instead.
// (The closing brace of this definition is outside the visible listing.)
91 int32_t CategoryNumber(int32_t category_id) {
92 return label_map_[category_id];
// Returns a mutable reference to the cached key list (keys_).
94 std::vector<int32_t>& Keys() {
  std::vector<int32_t>& cached_keys = keys_;
  return cached_keys;
}
// Returns a mutable reference to the underlying label map (label_map_).
95 std::map<int32_t, int32_t>& Map() {
  std::map<int32_t, int32_t>& underlying_map = label_map_;
  return underlying_map;
}
// Empties both the label map and the cached key list.
96 void Reset() {
  label_map_.clear();
  keys_.clear();
}
// Declaration only; the definition is not in this part of the file.
// Returns an nlohmann::json value; presumably a serialization of this
// object's label map state — confirm against the definition.
97 nlohmann::json to_json();
// Declaration only; the definition is not in this part of the file.
// Takes a read-only JSON value; presumably restores the state that to_json()
// produces — confirm against the definition.
98 void from_json(
const nlohmann::json& rfx_label_mapper_json);
// Maps a category id (key) to the number CategoryNumber() returns for it.
100 std::map<int32_t, int32_t> label_map_;
// Keys of the label map, cached as a vector; exposed through Keys() and
// cleared together with label_map_ in Reset().
101 std::vector<int32_t> keys_;
110 num_components_ = num_components;
111 num_groups_ = num_groups;
112 working_parameter_ = Eigen::VectorXd(num_components_);
113 group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
114 group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
115 working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
130 working_parameter_ = working_parameter;
// Replaces the whole group parameter matrix by assigning group_parameters to
// group_parameters_. No dimension check is visible in this body.
// (The closing brace of this definition is outside the visible listing.)
132 void SetGroupParameters(Eigen::MatrixXd& group_parameters) {
133 group_parameters_ = group_parameters;
// Overwrites column group_id of group_parameters_ with group_parameter.
// group_id is used directly as a column index; no range check is visible here.
// (The closing brace of this definition is outside the visible listing.)
135 void SetGroupParameter(Eigen::VectorXd& group_parameter, int32_t group_id) {
136 group_parameters_(Eigen::all, group_id) = group_parameter;
// Replaces working_parameter_covariance_ by assigning the argument to it.
// (The closing brace of this definition is outside the visible listing.)
138 void SetWorkingParameterCovariance(Eigen::MatrixXd& working_parameter_covariance) {
139 working_parameter_covariance_ = working_parameter_covariance;
// Replaces group_parameter_covariance_ by assigning the argument to it.
// (The closing brace of this definition is outside the visible listing.)
141 void SetGroupParameterCovariance(Eigen::MatrixXd& group_parameter_covariance) {
142 group_parameter_covariance_ = group_parameter_covariance;
// Writes value to the diagonal entry (component_id, component_id) of
// group_parameter_covariance_; off-diagonal entries are left untouched.
// (The closing brace of this definition is outside the visible listing.)
144 void SetGroupParameterVarianceComponent(
double value, int32_t component_id) {
145 group_parameter_covariance_(component_id, component_id) = value;
// Stores value in variance_prior_shape_. No validation is visible here.
// (The closing brace of this definition is outside the visible listing.)
147 void SetVariancePriorShape(
double value) {
148 variance_prior_shape_ = value;
// Stores value in variance_prior_scale_. No validation is visible here.
// (The closing brace of this definition is outside the visible listing.)
150 void SetVariancePriorScale(
double value) {
151 variance_prior_scale_ = value;
156 return working_parameter_;
// Returns a mutable reference to group_parameters_; writes through the
// reference modify the model's stored matrix directly.
// (The closing brace of this definition is outside the visible listing.)
158 Eigen::MatrixXd& GetGroupParameters() {
159 return group_parameters_;
// Returns a mutable reference to working_parameter_covariance_.
// (The closing brace of this definition is outside the visible listing.)
161 Eigen::MatrixXd& GetWorkingParameterCovariance() {
162 return working_parameter_covariance_;
// Returns a mutable reference to group_parameter_covariance_.
// (The closing brace of this definition is outside the visible listing.)
164 Eigen::MatrixXd& GetGroupParameterCovariance() {
165 return group_parameter_covariance_;
// Returns the current value of variance_prior_shape_.
// (The closing brace of this definition is outside the visible listing.)
167 double GetVariancePriorShape() {
168 return variance_prior_shape_;
// Returns the current value of variance_prior_scale_.
// (The closing brace of this definition is outside the visible listing.)
170 double GetVariancePriorScale() {
171 return variance_prior_scale_;
// Returns the value held in num_components_.
173 inline int32_t NumComponents() {
  const int32_t component_count = num_components_;
  return component_count;
}
// Returns the value held in num_groups_.
174 inline int32_t NumGroups() {
  const int32_t group_count = num_groups_;
  return group_count;
}
// Allocating wrapper around PredictInplace(): creates a vector with one
// zero-initialized slot per observation in the dataset and lets
// PredictInplace() fill it.
// (The remaining lines of this definition, including the return statement,
// are outside the visible listing.)
176 std::vector<double> Predict(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker) {
177 std::vector<double> output(dataset.NumObservations());
178 PredictInplace(dataset, tracker, output);
// Writes one prediction per observation into the caller-provided `output`.
// For observation i the value is
//   X.row(i) * diag(working_parameter_) * group_parameters_.col(g_i)
// where g_i is the tracker's category number for that observation's group label.
// (`n` is declared on a line that is missing from the visible listing, as are
// the closing braces; presumably n is the number of rows of X — confirm.)
182 void PredictInplace(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, std::vector<double>& output) {
// Local copies of the basis matrix and the per-observation group labels.
183 Eigen::MatrixXd X = dataset.GetBasis();
184 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
// The basis and the labels must describe the same number of observations.
185 CHECK_EQ(X.rows(), group_labels.size());
// The output buffer must already be sized to n.
187 CHECK_EQ(n, output.size());
// Dense diagonal matrix built once from the working parameter, reused per row.
188 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
189 std::int32_t group_ind;
190 for (
int i = 0; i < n; i++) {
// Translate the observation's group label into a column index of group_parameters_.
191 group_ind = tracker.CategoryNumber(group_labels[i]);
192 output[i] = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
// Adds the prediction currently stored in the tracker for each observation
// back onto the matching element of `residual`, in place.
// (`current_pred` and `new_resid` are declared on lines missing from the
// visible listing, as are the closing braces.)
196 void AddCurrentPredictionToResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
197 data_size_t n = dataset.NumObservations();
// The residual must have one row per observation.
198 CHECK_EQ(n, residual.NumRows());
201 for (data_size_t i = 0; i < n; i++) {
202 current_pred = tracker.GetPrediction(i);
203 new_resid = residual.GetElement(i) + current_pred;
204 residual.SetElement(i, new_resid);
// Recomputes each observation's prediction from the current parameters,
// subtracts it from the matching element of `residual` in place, and stores
// it in the tracker via SetPrediction().
// The prediction formula is the same expression used in PredictInplace():
//   X.row(i) * diag(working_parameter_) * group_parameters_.col(g_i)
// (`n`, `new_pred` and `new_resid` are declared on lines missing from the
// visible listing, as are the closing braces.)
208 void SubtractNewPredictionFromResidual(RandomEffectsDataset& dataset, RandomEffectsTracker& tracker, ColumnVector& residual) {
// Local copies of the basis matrix and the per-observation group labels.
209 Eigen::MatrixXd X = dataset.GetBasis();
210 std::vector<int32_t> group_labels = dataset.GetGroupLabels();
// The basis and the labels must describe the same number of observations.
211 CHECK_EQ(X.rows(), group_labels.size());
// Dense diagonal matrix built once from the working parameter, reused per row.
215 Eigen::MatrixXd alpha_diag = working_parameter_.asDiagonal().toDenseMatrix();
216 std::int32_t group_ind;
217 for (
int i = 0; i < n; i++) {
// Translate the observation's group label into a column index of group_parameters_.
218 group_ind = tracker.CategoryNumber(group_labels[i]);
219 new_pred = X(i, Eigen::all) * alpha_diag * group_parameters_(Eigen::all, group_ind);
220 new_resid = residual.GetElement(i) - new_pred;
221 residual.SetElement(i, new_resid);
// Record the prediction in the tracker for this observation.
222 tracker.SetPrediction(i, new_pred);
// Vector of length num_components_; the prediction code applies it as a
// diagonal scaling between the basis row and the group parameter column.
251 Eigen::VectorXd working_parameter_;
// num_components_ x num_groups_ matrix; each column is selected by a group's
// category number in the prediction code.
252 Eigen::MatrixXd group_parameters_;
// num_components_ x num_components_ matrix;
// SetGroupParameterVarianceComponent() writes its diagonal entries.
255 Eigen::MatrixXd group_parameter_covariance_;
// num_components_ x num_components_ matrix associated with working_parameter_.
258 Eigen::MatrixXd working_parameter_covariance_;
// Scalar values set and read through the VariancePriorShape/Scale accessors.
// Presumably the shape and scale of a prior on the variance components —
// confirm against the sampler that consumes them.
261 double variance_prior_shape_;
262 double variance_prior_scale_;