xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/sparse/group_iterator.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
17 #define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
18 
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 namespace sparse {
28 
29 class GroupIterable;  // Predeclare GroupIterable for Group.
30 
31 // This class is returned when dereferencing a GroupIterable iterator.
32 // It provides the methods group(), indices(), and values(), which
33 // provide access into the underlying SparseTensor.
34 class Group {
35  public:
Group(GroupIterable * iter,int64_t loc,int64_t next_loc)36   Group(GroupIterable* iter, int64_t loc, int64_t next_loc)
37       : iter_(iter), loc_(loc), next_loc_(next_loc) {}
38 
39   std::vector<int64_t> group() const;
40   int64_t group_at(size_t index) const;
41   TTypes<int64_t>::UnalignedConstMatrix indices() const;
42   template <typename T>
43   typename TTypes<T>::UnalignedVec values() const;
44 
45  private:
46   GroupIterable* iter_;
47   int64_t loc_;
48   int64_t next_loc_;
49 };
50 
51 /////////////////
52 // GroupIterable
53 /////////////////
54 //
55 // Returned when calling sparse_tensor.group({dim0, dim1, ...}).
56 //
57 // Please note: the sparse_tensor should already be ordered according
58 // to {dim0, dim1, ...}.  Otherwise this iteration will return invalid groups.
59 //
60 // Allows grouping and iteration of the SparseTensor according to the
61 // subset of dimensions provided to the group call.
62 //
63 // The actual grouping dimensions are stored in the
64 // internal vector group_dims_.  Iterators inside the iterable provide
65 // the three methods:
66 //
67 // *  group(): returns a vector with the current group dimension values.
68 // *  indices(): a map of index, providing the indices in
69 //    this group.
70 // *  values(): a map of values, providing the values in
71 //    this group.
72 //
73 // To iterate across GroupIterable, see examples in README.md.
74 //
75 
76 // Forward declaration of SparseTensor
77 class GroupIterable {
78  public:
79   typedef gtl::ArraySlice<int64_t> VarDimArray;
80 
GroupIterable(Tensor ix,Tensor vals,int dims,const VarDimArray & group_dims)81   GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
82       : ix_(ix),
83         ix_matrix_(ix_.matrix<int64_t>()),
84         vals_(vals),
85         dims_(dims),
86         group_dims_(group_dims.begin(), group_dims.end()) {}
87 
88   class IteratorStep;
89 
begin()90   IteratorStep begin() { return IteratorStep(this, 0); }
at(int64_t loc)91   IteratorStep at(int64_t loc) {
92     CHECK(loc >= 0 && loc <= ix_.dim_size(0))
93         << "loc provided must lie between 0 and " << ix_.dim_size(0);
94     return IteratorStep(this, loc);
95   }
end()96   IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
97 
98   template <typename TIX>
GroupMatches(const TIX & ix,int64_t loc_a,int64_t loc_b)99   inline bool GroupMatches(const TIX& ix, int64_t loc_a, int64_t loc_b) const {
100     for (int d : group_dims_) {
101       if (ix(loc_a, d) != ix(loc_b, d)) {
102         return false;
103       }
104     }
105     return true;
106   }
107 
108   class IteratorStep {
109    public:
IteratorStep(GroupIterable * iter,int64_t loc)110     IteratorStep(GroupIterable* iter, int64_t loc)
111         : iter_(iter), loc_(loc), next_loc_(loc_) {
112       UpdateEndOfGroup();
113     }
114 
115     void UpdateEndOfGroup();
116     bool operator!=(const IteratorStep& rhs) const;
117     bool operator==(const IteratorStep& rhs) const;
118     IteratorStep& operator++();    // prefix ++
119     IteratorStep operator++(int);  // postfix ++
120     Group operator*() const { return Group(iter_, loc_, next_loc_); }
loc()121     int64_t loc() const { return loc_; }
122 
123    private:
124     GroupIterable* iter_;
125     int64_t loc_;
126     int64_t next_loc_;
127   };
128 
129  private:
130   friend class Group;
131   const Tensor ix_;
132   const TTypes<int64_t>::ConstMatrix ix_matrix_;
133   Tensor vals_;
134   const int dims_;
135   const gtl::InlinedVector<int64_t, 8> group_dims_;
136 };
137 
group_at(size_t index)138 inline int64_t Group::group_at(size_t index) const {
139   const auto& ix_t = iter_->ix_matrix_;
140   return ix_t(loc_, index);
141 }
142 
143 // Implementation of Group::values<T>()
144 template <typename T>
values()145 typename TTypes<T>::UnalignedVec Group::values() const {
146   return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
147                                           next_loc_ - loc_);
148 }
149 
150 }  // namespace sparse
151 }  // namespace tensorflow
152 
153 #endif  // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
154