xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/sparse_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/sparse_utils.h"
17 
18 #include <cstddef>
19 #include <cstdint>
20 
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/macros.h"
24 #include "tensorflow/core/platform/status.h"
25 
26 namespace tensorflow {
27 namespace sparse_utils {
28 
29 template <typename Tindices>
FindNextDenseRowStartIndex(const Tindices sparse_index_begin,const typename TTypes<Tindices>::ConstMatrix & indices_mat)30 Tindices FindNextDenseRowStartIndex(
31     const Tindices sparse_index_begin,
32     const typename TTypes<Tindices>::ConstMatrix& indices_mat) {
33   // Search in the index range [begin, end) of indices_mat.
34   Tindices begin = sparse_index_begin;
35   Tindices end = indices_mat.dimension(0);
36   const Tindices orig_sparse_index_end = end;
37 
38   // The first dense row we search.
39   const Tindices orig_dense_index_begin = indices_mat(begin, 0);
40   // Early exit if no next dense row index.
41   if (orig_dense_index_begin == static_cast<int64_t>(indices_mat(end - 1, 0))) {
42     return orig_sparse_index_end;
43   }
44 
45   Tindices increment = 1;
46   while (begin + increment < end &&
47          indices_mat(begin + increment, 0) == orig_dense_index_begin) {
48     increment *= 2;
49   }
50   // Narrow the search space as an optimization.
51   if (begin + increment < end) {
52     end = begin + increment;
53   }
54   begin += increment / 2;
55 
56   // Perform a binary search on the interval [begin, end) for
57   // dense_row_index_to_find.
58   const Tindices dense_row_index_to_find = orig_dense_index_begin;
59   while (begin < end) {
60     const Tindices m = begin + (end - begin) / 2;
61     const Tindices m_dense_row_index = static_cast<Tindices>(indices_mat(m, 0));
62     if (m_dense_row_index == dense_row_index_to_find &&
63         (m + 1 == orig_sparse_index_end ||
64          static_cast<Tindices>(indices_mat(m + 1, 0)) !=
65              dense_row_index_to_find)) {
66       return m + 1;
67     } else if (m_dense_row_index <= dense_row_index_to_find) {
68       begin = m + 1;
69     } else {
70       end = m;
71     }
72   }
73 
74   // No next dense row index.
75   return orig_sparse_index_end;
76 }
77 
78 template <typename Tindices>
GetStartIndicesOfEachDenseRow(const typename TTypes<Tindices>::ConstMatrix & indices_mat,bool * contains_empty_rows)79 std::vector<Tindices> GetStartIndicesOfEachDenseRow(
80     const typename TTypes<Tindices>::ConstMatrix& indices_mat,
81     bool* contains_empty_rows) {
82   int64_t start_sparse_index_of_cur_dense_row = 0;
83   std::vector<Tindices> segment_indices;
84   const Tindices num_entries_in_sparse_tensor = indices_mat.dimension(0);
85   const Tindices num_dense_rows_in_sparse_tensor =
86       1 + indices_mat(num_entries_in_sparse_tensor - 1, 0);
87   // Reserve an extra slot for the 0 we store in the first entry by convention.
88   segment_indices.reserve(1 + num_dense_rows_in_sparse_tensor);
89   segment_indices.push_back(0);
90   for (Tindices i = 0; i < indices_mat(0, 0); ++i) {
91     segment_indices.push_back(0);
92   }
93   *contains_empty_rows = indices_mat(0, 0) > 0;
94   while (true) {
95     const Tindices start_sparse_index_of_next_dense_row =
96         FindNextDenseRowStartIndex<Tindices>(
97             start_sparse_index_of_cur_dense_row, indices_mat);
98     if (start_sparse_index_of_next_dense_row == num_entries_in_sparse_tensor) {
99       segment_indices.push_back(start_sparse_index_of_next_dense_row);
100       break;
101     }
102     // Encode the length of the current dense row as well as the lengths of all
103     // the empty rows until the next dense row,
104     for (Tindices i = 0;
105          i < indices_mat(start_sparse_index_of_next_dense_row, 0) -
106                  indices_mat(start_sparse_index_of_cur_dense_row, 0);
107          ++i) {
108       segment_indices.push_back(start_sparse_index_of_next_dense_row);
109     }
110     // If there is more than one row between the current and next non-empty
111     // rows then those rows are empty.
112     *contains_empty_rows |=
113         indices_mat(start_sparse_index_of_next_dense_row, 0) -
114             indices_mat(start_sparse_index_of_cur_dense_row, 0) >
115         1;
116     start_sparse_index_of_cur_dense_row = start_sparse_index_of_next_dense_row;
117   }
118   return segment_indices;
119 }
120 
121 template <typename Tindices>
ParseRowStartIndices(const tensorflow::Tensor & tensor,const Tindices num_nonzero_entries_in_sparse_mat)122 std::vector<Tindices> ParseRowStartIndices(
123     const tensorflow::Tensor& tensor,
124     const Tindices num_nonzero_entries_in_sparse_mat) {
125   std::vector<Tindices> out;
126   auto vec = tensor.vec<Tindices>();
127   out.reserve(vec.size() + 1);
128   for (size_t i = 0; i < vec.dimension(0); ++i) {
129     out.push_back(vec(i));
130   }
131   out.push_back(num_nonzero_entries_in_sparse_mat);
132   return out;
133 }
134 
// Returns true if any dense row other than the last one has no entries.
//
// `row_start_indices` holds row boundaries: dense row r spans
// [row_start_indices[r], row_start_indices[r + 1]). A row is empty when two
// consecutive boundaries are equal. The last row (the final pair of
// boundaries) is deliberately not checked, since it is expected to be
// non-empty. Inputs with fewer than three elements therefore yield false.
template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
  // Written as `i + 1 < size()` rather than `i < size() - 1`. The latter
  // underflows to SIZE_MAX for an empty vector and reads past the end.
  for (size_t i = 1; i + 1 < row_start_indices.size(); ++i) {
    if (row_start_indices[i] == row_start_indices[i - 1]) {
      return true;
    }
  }
  return false;
}
146 
147 namespace {
148 
149 // Ensures indices, values, shape are all of the proper ranks and are
150 // compatible.
ValidateSparseTensorShape(const Tensor & indices,const Tensor & values,const Tensor & shape)151 Status ValidateSparseTensorShape(const Tensor& indices, const Tensor& values,
152                                  const Tensor& shape) {
153   // Indices must be a matrix, and values/shape must be a vector.
154   if (!TensorShapeUtils::IsMatrix(indices.shape())) {
155     return errors::InvalidArgument("Sparse indices must be rank 2 but is rank ",
156                                    indices.shape().dim_sizes().size());
157   }
158   if (!TensorShapeUtils::IsVector(values.shape())) {
159     return errors::InvalidArgument("Sparse values must be rank 1 but is rank ",
160                                    values.shape().dims());
161   }
162   if (!TensorShapeUtils::IsVector(shape.shape())) {
163     return errors::InvalidArgument("Sparse shape must be rank 1 but is rank ",
164                                    shape.shape().dims());
165   }
166   // Indices shape must be compatible with the values vector and dense shape.
167   int64_t nnz = indices.dim_size(0);
168   int64_t ndims = indices.dim_size(1);
169   if (values.dim_size(0) != nnz) {
170     return errors::InvalidArgument("Number of elements in indices (", nnz,
171                                    ") and values (", values.dim_size(0),
172                                    ") do not match");
173   }
174   if (shape.NumElements() != ndims) {
175     return errors::InvalidArgument("Index rank (", ndims, ") and shape rank (",
176                                    shape.NumElements(), ") do not match");
177   }
178 
179   return Status::OK();
180 }
181 
182 // Creates a debug string for the index tuple in indices(row, :).
183 template <typename IndexTensor>
CreateIndexString(const IndexTensor & indices,int64_t row)184 string CreateIndexString(const IndexTensor& indices, int64_t row) {
185   const int64_t ndims = indices.dimension(1);
186   string index_str = strings::StrCat("indices[", row, ", :] = [");
187   for (int64_t dim = 0; dim < ndims; ++dim) {
188     strings::StrAppend(&index_str, indices(row, dim),
189                        dim < ndims - 1 ? ", " : "]");
190   }
191   if (ndims == 0) {
192     strings::StrAppend(&index_str, "]");
193   }
194   return index_str;
195 }
196 
197 // Ensures all sparse indices are within correct bounds.
198 template <typename Tindices>
ValidateSparseTensorIndicesUnordered(const Tensor & indices,const Tensor & shape)199 Status ValidateSparseTensorIndicesUnordered(const Tensor& indices,
200                                             const Tensor& shape) {
201   // Ensure no index is out-of-bounds.
202   const auto indices_mat = indices.flat_inner_dims<Tindices>();
203   const auto shape_vec = shape.flat<Tindices>();
204   int64_t nnz = indices.dim_size(0);
205   int64_t ndims = indices.dim_size(1);
206 
207   for (int64_t i = 0; i < nnz; ++i) {
208     for (int64_t dim = 0; dim < ndims; ++dim) {
209       const Tindices idx = indices_mat(i, dim);
210       if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
211         string index_str = CreateIndexString(indices_mat, i);
212         return errors::InvalidArgument("Sparse index tuple ", index_str,
213                                        " is out of bounds");
214       }
215     }
216   }
217 
218   return Status::OK();
219 }
220 
221 // Ensures all sparse indices are within correct bounds and are
222 // lexicographically ordered.
223 template <typename Tindices>
ValidateSparseTensorIndicesOrdered(const Tensor & indices,const Tensor & shape)224 Status ValidateSparseTensorIndicesOrdered(const Tensor& indices,
225                                           const Tensor& shape) {
226   const auto indices_mat = indices.flat_inner_dims<Tindices>();
227   const auto shape_vec = shape.flat<Tindices>();
228   int64_t nnz = indices.dim_size(0);
229   int64_t ndims = indices.dim_size(1);
230 
231   if (nnz == 0) {
232     return Status::OK();
233   }
234 
235   // First set of indices must be within range.
236   for (int64_t dim = 0; dim < ndims; ++dim) {
237     const Tindices idx = indices_mat(0, dim);
238     if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
239       string index_str = CreateIndexString(indices_mat, 0);
240       return errors::InvalidArgument("Sparse index tuple ", index_str,
241                                      " is out of bounds");
242     }
243   }
244 
245   // Remaining set of indices must be within range and lexicographically
246   // larger than the previous.
247   for (int64_t i = 1; i < nnz; ++i) {
248     bool different = false;
249     for (int64_t dim = 0; dim < ndims; ++dim) {
250       const Tindices idx = indices_mat(i, dim);
251       const Tindices prev_idx = indices_mat(i - 1, dim);
252       // If indices are already different from previous i, the new index can
253       // be anything within the valid range.
254       if (TF_PREDICT_TRUE(different)) {
255         if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
256           string index_str = CreateIndexString(indices_mat, i);
257           return errors::InvalidArgument("Sparse index tuple ", index_str,
258                                          " is out of bounds");
259         }
260       } else {
261         // Otherwise, the new index must be >= previous and <= shape(dim).
262         if (TF_PREDICT_FALSE(idx < prev_idx || idx >= shape_vec(dim))) {
263           string index_str = CreateIndexString(indices_mat, i);
264           // Check if index is actually out of bounds.
265           if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
266             return errors::InvalidArgument("Sparse index tuple ", index_str,
267                                            " is out of bounds");
268           } else {
269             return errors::InvalidArgument("Sparse index tuple ", index_str,
270                                            " is out of order");
271           }
272         } else if (TF_PREDICT_TRUE(idx > prev_idx)) {
273           different = true;
274         }
275       }  // if (different)
276     }    // for dim in [0, ndims)
277 
278     if (TF_PREDICT_FALSE(!different)) {
279       string index_str = CreateIndexString(indices_mat, i);
280       return errors::InvalidArgument("Sparse index tuple ", index_str,
281                                      " is repeated");
282     }
283   }  // for i in [1, nnz)
284 
285   return Status::OK();
286 }
287 
288 }  // namespace
289 
290 template <typename Tindices>
ValidateSparseTensor(const Tensor & indices,const Tensor & values,const Tensor & shape,IndexValidation index_validation)291 Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
292                             const Tensor& shape,
293                             IndexValidation index_validation) {
294   TF_RETURN_IF_ERROR(ValidateSparseTensorShape(indices, values, shape));
295   switch (index_validation) {
296     case IndexValidation::kOrdered:
297       return ValidateSparseTensorIndicesOrdered<Tindices>(indices, shape);
298     case IndexValidation::kUnordered:
299       return ValidateSparseTensorIndicesUnordered<Tindices>(indices, shape);
300     case IndexValidation::kNone: {
301     }
302   }
303   return Status::OK();
304 }
305 
306 #define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex)                           \
307   template TypeIndex FindNextDenseRowStartIndex<TypeIndex>(                 \
308       const TypeIndex sparse_index_begin,                                   \
309       const TTypes<TypeIndex>::ConstMatrix& indices_mat);                   \
310   template std::vector<TypeIndex> GetStartIndicesOfEachDenseRow<TypeIndex>( \
311       const TTypes<TypeIndex>::ConstMatrix& indices_mat,                    \
312       bool* contains_empty_rows);                                           \
313   template bool ContainsEmptyRows<TypeIndex>(                               \
314       const std::vector<TypeIndex>& row_start_indices);                     \
315   template std::vector<TypeIndex> ParseRowStartIndices<TypeIndex>(          \
316       const tensorflow::Tensor& tensor,                                     \
317       const TypeIndex num_nonzero_entries_in_sparse_mat);                   \
318   template Status ValidateSparseTensor<TypeIndex>(                          \
319       const Tensor& indices, const Tensor& values, const Tensor& shape,     \
320       IndexValidation index_validation)
321 
322 REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
323 REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
324 REGISTER_SPARSE_UTIL_FUNCTIONS(uint8);
325 REGISTER_SPARSE_UTIL_FUNCTIONS(uint16);
326 REGISTER_SPARSE_UTIL_FUNCTIONS(uint32);
327 REGISTER_SPARSE_UTIL_FUNCTIONS(uint64);
328 
329 }  // namespace sparse_utils
330 }  // namespace tensorflow
331