1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
17 #define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
18
#include <utility>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
25
26 namespace tensorflow {
27 namespace sparse {
28
29 class GroupIterable; // Predeclare GroupIterable for Group.
30
31 // This class is returned when dereferencing a GroupIterable iterator.
32 // It provides the methods group(), indices(), and values(), which
33 // provide access into the underlying SparseTensor.
34 class Group {
35 public:
Group(GroupIterable * iter,int64_t loc,int64_t next_loc)36 Group(GroupIterable* iter, int64_t loc, int64_t next_loc)
37 : iter_(iter), loc_(loc), next_loc_(next_loc) {}
38
39 std::vector<int64_t> group() const;
40 int64_t group_at(size_t index) const;
41 TTypes<int64_t>::UnalignedConstMatrix indices() const;
42 template <typename T>
43 typename TTypes<T>::UnalignedVec values() const;
44
45 private:
46 GroupIterable* iter_;
47 int64_t loc_;
48 int64_t next_loc_;
49 };
50
51 /////////////////
52 // GroupIterable
53 /////////////////
54 //
55 // Returned when calling sparse_tensor.group({dim0, dim1, ...}).
56 //
57 // Please note: the sparse_tensor should already be ordered according
58 // to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups.
59 //
60 // Allows grouping and iteration of the SparseTensor according to the
61 // subset of dimensions provided to the group call.
62 //
// The actual grouping dimensions are stored in the
// internal vector group_dims_. Dereferencing an iterator of the iterable
// yields a Group, which provides the three methods:
//
// * group(): returns a vector with the current group dimension values.
// * indices(): a matrix map providing the indices in
//   this group.
// * values(): a vector map providing the values in
//   this group.
72 //
73 // To iterate across GroupIterable, see examples in README.md.
74 //
75
// Iterable over the groups of a SparseTensor's (indices, values) pair.
77 class GroupIterable {
78 public:
79 typedef gtl::ArraySlice<int64_t> VarDimArray;
80
GroupIterable(Tensor ix,Tensor vals,int dims,const VarDimArray & group_dims)81 GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
82 : ix_(ix),
83 ix_matrix_(ix_.matrix<int64_t>()),
84 vals_(vals),
85 dims_(dims),
86 group_dims_(group_dims.begin(), group_dims.end()) {}
87
88 class IteratorStep;
89
begin()90 IteratorStep begin() { return IteratorStep(this, 0); }
at(int64_t loc)91 IteratorStep at(int64_t loc) {
92 CHECK(loc >= 0 && loc <= ix_.dim_size(0))
93 << "loc provided must lie between 0 and " << ix_.dim_size(0);
94 return IteratorStep(this, loc);
95 }
end()96 IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
97
98 template <typename TIX>
GroupMatches(const TIX & ix,int64_t loc_a,int64_t loc_b)99 inline bool GroupMatches(const TIX& ix, int64_t loc_a, int64_t loc_b) const {
100 for (int d : group_dims_) {
101 if (ix(loc_a, d) != ix(loc_b, d)) {
102 return false;
103 }
104 }
105 return true;
106 }
107
108 class IteratorStep {
109 public:
IteratorStep(GroupIterable * iter,int64_t loc)110 IteratorStep(GroupIterable* iter, int64_t loc)
111 : iter_(iter), loc_(loc), next_loc_(loc_) {
112 UpdateEndOfGroup();
113 }
114
115 void UpdateEndOfGroup();
116 bool operator!=(const IteratorStep& rhs) const;
117 bool operator==(const IteratorStep& rhs) const;
118 IteratorStep& operator++(); // prefix ++
119 IteratorStep operator++(int); // postfix ++
120 Group operator*() const { return Group(iter_, loc_, next_loc_); }
loc()121 int64_t loc() const { return loc_; }
122
123 private:
124 GroupIterable* iter_;
125 int64_t loc_;
126 int64_t next_loc_;
127 };
128
129 private:
130 friend class Group;
131 const Tensor ix_;
132 const TTypes<int64_t>::ConstMatrix ix_matrix_;
133 Tensor vals_;
134 const int dims_;
135 const gtl::InlinedVector<int64_t, 8> group_dims_;
136 };
137
group_at(size_t index)138 inline int64_t Group::group_at(size_t index) const {
139 const auto& ix_t = iter_->ix_matrix_;
140 return ix_t(loc_, index);
141 }
142
143 // Implementation of Group::values<T>()
144 template <typename T>
values()145 typename TTypes<T>::UnalignedVec Group::values() const {
146 return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
147 next_loc_ - loc_);
148 }
149
150 } // namespace sparse
151 } // namespace tensorflow
152
153 #endif // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
154