48 num_observations_ = group_indices.size();
49 observation_indices_ = group_indices;
53 num_observations_ = other.NumObservations();
54 observation_indices_.resize(num_observations_);
55 for (
int i = 0; i < num_observations_; i++) {
56 observation_indices_[i] = other.GetCategoryId(i);
60 inline data_size_t GetCategoryId(data_size_t sample_id) {
61 CHECK_LT(sample_id, num_observations_);
62 return observation_indices_[sample_id];
65 inline void SetCategoryId(data_size_t sample_id,
int category_id) {
66 CHECK_LT(sample_id, num_observations_);
67 observation_indices_[sample_id] = sample_id;
70 inline int NumObservations() {
return num_observations_;}
// Category stored for each observation: slot i is returned by GetCategoryId(i)
// and written by SetCategoryId(i, ...).
73 std::vector<int> observation_indices_;
// Number of observations tracked; used to size observation_indices_ and as the
// upper bound in the CHECK_LT guards of the accessors.
74 data_size_t num_observations_;
80class CategorySampleTracker {
82 CategorySampleTracker(
const std::vector<int32_t>& group_indices) {
83 int n = group_indices.size();
84 indices_ = std::vector<data_size_t>(n);
85 std::iota(indices_.begin(), indices_.end(), 0);
87 auto comp_op = [&](
size_t const &l,
size_t const &r) {
return std::less<data_size_t>{}(group_indices[l], group_indices[r]); };
88 std::stable_sort(indices_.begin(), indices_.end(), comp_op);
91 int observation_count = 0;
92 for (
int i = 0; i < n; i++) {
93 bool start_cond = i == 0;
94 bool end_cond = i == n-1;
95 bool new_group_cond{
false};
96 if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1]];
97 if (start_cond || new_group_cond) {
98 category_id_map_.insert({group_indices[indices_[i]], category_count_});
99 unique_category_ids_.push_back(group_indices[indices_[i]]);
100 node_index_vector_.emplace_back();
102 category_begin_.push_back(i);
104 category_begin_.push_back(i);
105 category_length_.push_back(observation_count);
107 observation_count = 1;
109 }
else if (end_cond) {
110 category_length_.push_back(observation_count+1);
115 node_index_vector_[category_count_ - 1].emplace_back(indices_[i]);
120 inline int32_t CategoryNumber(
int category_id) {
121 return category_id_map_[category_id];
125 inline data_size_t CategoryBegin(
int category_id) {
return category_begin_[category_id_map_[category_id]];}
128 inline data_size_t CategoryEnd(
int category_id) {
129 int32_t
id = category_id_map_[category_id];
130 return category_begin_[id] + category_length_[id];
134 inline data_size_t CategorySize(
int category_id) {
135 return category_length_[category_id_map_[category_id]];
139 inline data_size_t NumCategories() {
return category_count_;}
// Observation indices 0..n-1, stably sorted by group label in the constructor.
142 std::vector<data_size_t> indices_;
145 std::vector<data_size_t>& NodeIndices(
int category_id) {
146 int32_t
id = category_id_map_[category_id];
147 return node_index_vector_[id];
151 std::vector<data_size_t>& NodeIndicesInternalIndex(
int internal_category_id) {
152 return node_index_vector_[internal_category_id];
156 std::map<int32_t, int32_t>& GetLabelMap() {
return category_id_map_;}
158 std::vector<int32_t>& GetUniqueGroupIds() {
return unique_category_ids_;}
// Position in indices_ at which each category's run of observations starts,
// indexed by internal category index.
162 std::vector<data_size_t> category_begin_;
// Number of observations recorded for each category, indexed by internal
// category index.
163 std::vector<data_size_t> category_length_;
// Maps a group label to its internal category index.
164 std::map<int32_t, int32_t> category_id_map_;
// Distinct group labels, appended in the order the constructor registers them.
165 std::vector<int32_t> unique_category_ids_;
// For each internal category index, the original observation indices that
// belong to that category.
166 std::vector<std::vector<data_size_t>> node_index_vector_;
// Category counter: returned by NumCategories() and stored in category_id_map_
// as the internal index of a newly registered label.
167 int32_t category_count_;