xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/agg_vector_iterator.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_AGGREGATION_CORE_AGG_VECTOR_ITERATOR_H_
18 #define FCP_AGGREGATION_CORE_AGG_VECTOR_ITERATOR_H_
19 
#include <cstddef>
#include <iterator>

#include "fcp/aggregation/core/tensor_data.h"
#include "fcp/base/monitoring.h"
22 
23 namespace fcp {
24 namespace aggregation {
25 
26 // Iterator for AggVector which allows to iterate over sparse values
27 // as a collection of {index, value} pairs.
28 //
29 // This allows a simple iteration loops like the following:
30 // for (auto [index, value] : agg_vector) {
31 //    ... aggregate the value at the given dense index
32 // }
33 template <typename T>
34 struct AggVectorIterator {
35   struct IndexValuePair {
36     size_t index;
37     T value;
38 
39     friend bool operator==(const IndexValuePair& a, const IndexValuePair& b) {
40       return a.index == b.index && a.value == b.value;
41     }
42 
43     friend bool operator!=(const IndexValuePair& a, const IndexValuePair& b) {
44       return a.index != b.index || a.value != b.value;
45     }
46   };
47 
48   using value_type = IndexValuePair;
49   using pointer = value_type*;
50   using reference = value_type&;
51 
AggVectorIteratorAggVectorIterator52   explicit AggVectorIterator(const TensorData* data)
53       : AggVectorIterator(get_start_ptr(data), get_end_ptr(data), 0) {}
54 
55   // Current dense index corresponding to the current value.
indexAggVectorIterator56   size_t index() const { return dense_index; }
57   // Current value.
valueAggVectorIterator58   T value() const { return *ptr; }
59   // The current interator {index, value} pair value. This is used by
60   // for loop iterators.
61   IndexValuePair operator*() const { return {dense_index, *ptr}; }
62 
63   AggVectorIterator& operator++() {
64     FCP_CHECK(ptr != end_ptr);
65     if (++ptr == end_ptr) {
66       *this = end();
67     } else {
68       dense_index++;
69     }
70     return *this;
71   }
72 
73   AggVectorIterator operator++(int) {
74     AggVectorIterator tmp = *this;
75     ++(*this);
76     return tmp;
77   }
78 
79   friend bool operator==(const AggVectorIterator& a,
80                          const AggVectorIterator& b) {
81     return a.ptr == b.ptr;
82   }
83 
84   friend bool operator!=(const AggVectorIterator& a,
85                          const AggVectorIterator& b) {
86     return a.ptr != b.ptr;
87   }
88 
endAggVectorIterator89   static AggVectorIterator end() {
90     return AggVectorIterator(nullptr, nullptr, 0);
91   }
92 
93  private:
AggVectorIteratorAggVectorIterator94   AggVectorIterator(const T* ptr, const T* end_ptr, size_t dense_index)
95       : ptr(ptr), end_ptr(end_ptr), dense_index(dense_index) {}
96 
get_start_ptrAggVectorIterator97   static const T* get_start_ptr(const TensorData* data) {
98     return static_cast<const T*>(data->data());
99   }
100 
get_end_ptrAggVectorIterator101   static const T* get_end_ptr(const TensorData* data) {
102     return get_start_ptr(data) + data->byte_size() / sizeof(T);
103   }
104 
105   const T* ptr;
106   const T* end_ptr;
107   size_t dense_index;
108 };
109 
110 }  // namespace aggregation
111 }  // namespace fcp
112 
113 #endif  // FCP_AGGREGATION_CORE_AGG_VECTOR_ITERATOR_H_
114