44 num_observations_ = group_indices.size();
45 observation_indices_ = group_indices;
// NOTE(review): the signature line of this constructor is missing from this
// listing. From the body it initializes this mapper from another mapper
// `other` by copying every observation's category id — confirm against the
// full source.
49 num_observations_ = other.NumObservations();
50 observation_indices_.resize(num_observations_);
// Copy the category assignment of each observation, one by one.
51 for (
int i = 0; i < num_observations_; i++) {
52 observation_indices_[i] = other.GetCategoryId(i);
56 inline data_size_t GetCategoryId(data_size_t sample_id) {
57 CHECK_LT(sample_id, num_observations_);
58 return observation_indices_[sample_id];
61 inline void SetCategoryId(data_size_t sample_id,
int category_id) {
62 CHECK_LT(sample_id, num_observations_);
63 observation_indices_[sample_id] = sample_id;
66 inline int NumObservations() {
return num_observations_;}
// Category id of each observation, indexed by observation (sample) id.
69 std::vector<int> observation_indices_;
// Number of observations. Both constructors in view set it together with
// observation_indices_, so it matches that vector's size.
70 data_size_t num_observations_;
// Groups observations by their group (category) id. The constructor stably
// sorts observation indices by group id (stored in `indices_`). For each
// distinct group it records:
//   - its internal category number (category_id_map_),
//   - its id (unique_category_ids_),
//   - its start position and length in the sorted order
//     (category_begin_ / category_length_),
//   - the observation indices that belong to it (node_index_vector_).
76class CategorySampleTracker {
// Build the tracker from one group id per observation.
// \param group_indices group id of each observation, indexed by observation.
78 CategorySampleTracker(
const std::vector<int32_t>& group_indices) {
79 int n = group_indices.size();
// Identity permutation 0..n-1 over observations.
80 indices_ = std::vector<data_size_t>(n);
81 std::iota(indices_.begin(), indices_.end(), 0);
// Stable sort by group id: observations of the same group become adjacent
// while keeping their original relative order.
83 auto comp_op = [&](
size_t const &l,
size_t const &r) {
return std::less<data_size_t>{}(group_indices[l], group_indices[r]); };
84 std::stable_sort(indices_.begin(), indices_.end(), comp_op);
// Walk the sorted order, detecting where one group ends and the next begins.
87 int observation_count = 0;
88 for (
int i = 0; i < n; i++) {
89 bool start_cond = i == 0;
90 bool end_cond = i == n-1;
91 bool new_group_cond{
false};
// A new group starts wherever the group id differs from the previous
// observation in sorted order.
92 if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1]];
93 if (start_cond || new_group_cond) {
// Register the new group: map its id to the current category number and
// allocate an (initially empty) observation list for it.
94 category_id_map_.insert({group_indices[indices_[i]], category_count_});
95 unique_category_ids_.push_back(group_indices[indices_[i]]);
96 node_index_vector_.emplace_back();
// NOTE(review): as listed, both category_begin_.push_back(i) calls below run
// for every new group, which would store two begin positions per category.
// They look like the two arms of a branch (first group vs. later groups)
// whose if/else lines are missing from this listing. Confirm against the
// full source.
98 category_begin_.push_back(i);
100 category_begin_.push_back(i);
101 category_length_.push_back(observation_count);
103 observation_count = 1;
105 }
else if (end_cond) {
// Last observation, continuing the current group: record that group's length.
106 category_length_.push_back(observation_count+1);
// Append this observation to the list at index category_count_ - 1
// (presumably the most recently registered category).
// NOTE(review): no line in this listing initializes or increments
// category_count_, although it is read here and when registering a group.
// Confirm the full source sets it to 0 before the loop and increments it
// once per new group.
111 node_index_vector_[category_count_ - 1].emplace_back(indices_[i]);
116 inline int32_t CategoryNumber(
int category_id) {
117 return category_id_map_[category_id];
121 inline data_size_t CategoryBegin(
int category_id) {
return category_begin_[category_id_map_[category_id]];}
124 inline data_size_t CategoryEnd(
int category_id) {
125 int32_t
id = category_id_map_[category_id];
126 return category_begin_[id] + category_length_[id];
130 inline data_size_t CategorySize(
int category_id) {
131 return category_length_[category_id_map_[category_id]];
135 inline data_size_t NumCategories() {
return category_count_;}
// Observation indices 0..n-1, stably sorted by group id in the constructor.
// Category ranges [CategoryBegin, CategoryEnd) are positions into this vector.
138 std::vector<data_size_t> indices_;
141 std::vector<data_size_t>& NodeIndices(
int category_id) {
142 int32_t
id = category_id_map_[category_id];
143 return node_index_vector_[id];
147 std::vector<data_size_t>& NodeIndicesInternalIndex(
int internal_category_id) {
148 return node_index_vector_[internal_category_id];
152 std::map<int32_t, int32_t>& GetLabelMap() {
return category_id_map_;}
154 std::vector<int32_t>& GetUniqueGroupIds() {
return unique_category_ids_;}
158 std::vector<data_size_t> category_begin_;
159 std::vector<data_size_t> category_length_;
160 std::map<int32_t, int32_t> category_id_map_;
161 std::vector<int32_t> unique_category_ids_;
162 std::vector<std::vector<data_size_t>> node_index_vector_;
163 int32_t category_count_;