StochTree 0.1.1
Loading...
Searching...
No Matches
category_tracker.h
1
25#ifndef STOCHTREE_CATEGORY_TRACKER_H_
26#define STOCHTREE_CATEGORY_TRACKER_H_
27
28#include <Eigen/Dense>
29#include <stochtree/log.h>
30#include <stochtree/meta.h>
31
32#include <map>
33#include <numeric>
34#include <vector>
35
36namespace StochTree {
37
42 public:
43 SampleCategoryMapper(std::vector<int32_t>& group_indices) {
44 num_observations_ = group_indices.size();
45 observation_indices_ = group_indices;
46 }
47
49 num_observations_ = other.NumObservations();
50 observation_indices_.resize(num_observations_);
51 for (int i = 0; i < num_observations_; i++) {
52 observation_indices_[i] = other.GetCategoryId(i);
53 }
54 }
55
56 inline data_size_t GetCategoryId(data_size_t sample_id) {
57 CHECK_LT(sample_id, num_observations_);
58 return observation_indices_[sample_id];
59 }
60
61 inline void SetCategoryId(data_size_t sample_id, int category_id) {
62 CHECK_LT(sample_id, num_observations_);
63 observation_indices_[sample_id] = sample_id;
64 }
65
66 inline int NumObservations() {return num_observations_;}
67
68 private:
69 std::vector<int> observation_indices_;
70 data_size_t num_observations_;
71};
72
76class CategorySampleTracker {
77 public:
78 CategorySampleTracker(const std::vector<int32_t>& group_indices) {
79 int n = group_indices.size();
80 indices_ = std::vector<data_size_t>(n);
81 std::iota(indices_.begin(), indices_.end(), 0);
82
83 auto comp_op = [&](size_t const &l, size_t const &r) { return std::less<data_size_t>{}(group_indices[l], group_indices[r]); };
84 std::stable_sort(indices_.begin(), indices_.end(), comp_op);
85
86 category_count_ = 0;
87 int observation_count = 0;
88 for (int i = 0; i < n; i++) {
89 bool start_cond = i == 0;
90 bool end_cond = i == n-1;
91 bool new_group_cond{false};
92 if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1]];
93 if (start_cond || new_group_cond) {
94 category_id_map_.insert({group_indices[indices_[i]], category_count_});
95 unique_category_ids_.push_back(group_indices[indices_[i]]);
96 node_index_vector_.emplace_back();
97 if (i == 0) {
98 category_begin_.push_back(i);
99 } else {
100 category_begin_.push_back(i);
101 category_length_.push_back(observation_count);
102 }
103 observation_count = 1;
104 category_count_++;
105 } else if (end_cond) {
106 category_length_.push_back(observation_count+1);
107 } else {
108 observation_count++;
109 }
110 // Add the index to the category's node index vector in either case
111 node_index_vector_[category_count_ - 1].emplace_back(indices_[i]);
112 }
113 }
114
116 inline int32_t CategoryNumber(int category_id) {
117 return category_id_map_[category_id];
118 }
119
121 inline data_size_t CategoryBegin(int category_id) {return category_begin_[category_id_map_[category_id]];}
122
124 inline data_size_t CategoryEnd(int category_id) {
125 int32_t id = category_id_map_[category_id];
126 return category_begin_[id] + category_length_[id];
127 }
128
130 inline data_size_t CategorySize(int category_id) {
131 return category_length_[category_id_map_[category_id]];
132 }
133
135 inline data_size_t NumCategories() {return category_count_;}
136
138 std::vector<data_size_t> indices_;
139
141 std::vector<data_size_t>& NodeIndices(int category_id) {
142 int32_t id = category_id_map_[category_id];
143 return node_index_vector_[id];
144 }
145
147 std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {
148 return node_index_vector_[internal_category_id];
149 }
150
152 std::map<int32_t, int32_t>& GetLabelMap() {return category_id_map_;}
153
154 std::vector<int32_t>& GetUniqueGroupIds() {return unique_category_ids_;}
155
156 private:
157 // Vectors tracking indices in each node
158 std::vector<data_size_t> category_begin_;
159 std::vector<data_size_t> category_length_;
160 std::map<int32_t, int32_t> category_id_map_;
161 std::vector<int32_t> unique_category_ids_;
162 std::vector<std::vector<data_size_t>> node_index_vector_;
163 int32_t category_count_;
164};
165
166} // namespace StochTree
167
168#endif // STOCHTREE_CATEGORY_TRACKER_H_
Class storing sample-node map for each tree in an ensemble TODO: Add run-time checks for categories w...
Definition category_tracker.h:41
Definition category_tracker.h:36