StochTree 0.0.1
Loading...
Searching...
No Matches
category_tracker.h
1
25#ifndef STOCHTREE_CATEGORY_TRACKER_H_
26#define STOCHTREE_CATEGORY_TRACKER_H_
27
#include <Eigen/Dense>
#include <stochtree/log.h>
#include <stochtree/meta.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <map>
#include <numeric>
#include <random>
#include <set>
#include <string>
#include <vector>
39
40namespace StochTree {
41
46 public:
47 SampleCategoryMapper(std::vector<int32_t>& group_indices) {
48 num_observations_ = group_indices.size();
49 observation_indices_ = group_indices;
50 }
51
53 num_observations_ = other.NumObservations();
54 observation_indices_.resize(num_observations_);
55 for (int i = 0; i < num_observations_; i++) {
56 observation_indices_[i] = other.GetCategoryId(i);
57 }
58 }
59
60 inline data_size_t GetCategoryId(data_size_t sample_id) {
61 CHECK_LT(sample_id, num_observations_);
62 return observation_indices_[sample_id];
63 }
64
65 inline void SetCategoryId(data_size_t sample_id, int category_id) {
66 CHECK_LT(sample_id, num_observations_);
67 observation_indices_[sample_id] = sample_id;
68 }
69
70 inline int NumObservations() {return num_observations_;}
71
72 private:
73 std::vector<int> observation_indices_;
74 data_size_t num_observations_;
75};
76
80class CategorySampleTracker {
81 public:
82 CategorySampleTracker(const std::vector<int32_t>& group_indices) {
83 int n = group_indices.size();
84 indices_ = std::vector<data_size_t>(n);
85 std::iota(indices_.begin(), indices_.end(), 0);
86
87 auto comp_op = [&](size_t const &l, size_t const &r) { return std::less<data_size_t>{}(group_indices[l], group_indices[r]); };
88 std::stable_sort(indices_.begin(), indices_.end(), comp_op);
89
90 category_count_ = 0;
91 int observation_count = 0;
92 for (int i = 0; i < n; i++) {
93 bool start_cond = i == 0;
94 bool end_cond = i == n-1;
95 bool new_group_cond{false};
96 if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1]];
97 if (start_cond || new_group_cond) {
98 category_id_map_.insert({group_indices[indices_[i]], category_count_});
99 unique_category_ids_.push_back(group_indices[indices_[i]]);
100 node_index_vector_.emplace_back();
101 if (i == 0) {
102 category_begin_.push_back(i);
103 } else {
104 category_begin_.push_back(i);
105 category_length_.push_back(observation_count);
106 }
107 observation_count = 1;
108 category_count_++;
109 } else if (end_cond) {
110 category_length_.push_back(observation_count+1);
111 } else {
112 observation_count++;
113 }
114 // Add the index to the category's node index vector in either case
115 node_index_vector_[category_count_ - 1].emplace_back(indices_[i]);
116 }
117 }
118
120 inline int32_t CategoryNumber(int category_id) {
121 return category_id_map_[category_id];
122 }
123
125 inline data_size_t CategoryBegin(int category_id) {return category_begin_[category_id_map_[category_id]];}
126
128 inline data_size_t CategoryEnd(int category_id) {
129 int32_t id = category_id_map_[category_id];
130 return category_begin_[id] + category_length_[id];
131 }
132
134 inline data_size_t CategorySize(int category_id) {
135 return category_length_[category_id_map_[category_id]];
136 }
137
139 inline data_size_t NumCategories() {return category_count_;}
140
142 std::vector<data_size_t> indices_;
143
145 std::vector<data_size_t>& NodeIndices(int category_id) {
146 int32_t id = category_id_map_[category_id];
147 return node_index_vector_[id];
148 }
149
151 std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {
152 return node_index_vector_[internal_category_id];
153 }
154
156 std::map<int32_t, int32_t>& GetLabelMap() {return category_id_map_;}
157
158 std::vector<int32_t>& GetUniqueGroupIds() {return unique_category_ids_;}
159
160 private:
161 // Vectors tracking indices in each node
162 std::vector<data_size_t> category_begin_;
163 std::vector<data_size_t> category_length_;
164 std::map<int32_t, int32_t> category_id_map_;
165 std::vector<int32_t> unique_category_ids_;
166 std::vector<std::vector<data_size_t>> node_index_vector_;
167 int32_t category_count_;
168};
169
170} // namespace StochTree
171
172#endif // STOCHTREE_CATEGORY_TRACKER_H_
Class storing the sample-to-category map (the category id of each observation). TODO: Add run-time checks for categories w...
Definition category_tracker.h:45
Definition category_tracker.h:40