/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_fill_empty_rows_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

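// CPU implementation: counts elements per dense row, builds CSR-style row
// offsets, and either forwards the input unchanged (when no row is empty and
// rows arrive in order) or scatters the elements and default values into
// freshly allocated outputs.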
template <typename T, typename Tindex>
struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
  Status operator()(OpKernelContext* context, const Tensor& default_value_t,
                    const Tensor& indices_t, const Tensor& values_t,
                    const Tensor& dense_shape_t,
                    typename AsyncOpKernel::DoneCallback done) {
    (void)done;  // Unused (only used in GPU implementation)
    const int kOutputIndicesOutput = 0;
    const int kOutputValuesOutput = 1;
    const int kEmptyRowIndicatorOutput = 2;
    const int kReverseIndexMapOutput = 3;

    const T& default_value = default_value_t.scalar<T>()();
    const auto indices = indices_t.matrix<Tindex>();
    const auto values = values_t.vec<T>();
    const auto dense_shape = dense_shape_t.vec<Tindex>();

    const Tindex N = indices_t.shape().dim_size(0);
    const Tindex dense_rows = dense_shape(0);

    bool* empty_row_indicator = nullptr;
    if (context->output_required(kEmptyRowIndicatorOutput)) {
      Tensor* empty_row_indicator_t = nullptr;
      TensorShape output_shape;
      TF_RETURN_IF_ERROR(
          TensorShape::BuildTensorShape({dense_rows}, &output_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kEmptyRowIndicatorOutput, output_shape, &empty_row_indicator_t));
      empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
    }
    Tindex* reverse_index_map = nullptr;
    if (context->output_required(kReverseIndexMapOutput)) {
      Tensor* reverse_index_map_t = nullptr;
      TensorShape output_shape;
      TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape({N}, &output_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kReverseIndexMapOutput, output_shape, &reverse_index_map_t));
      reverse_index_map = reverse_index_map_t->vec<Tindex>().data();
    }

    int rank = indices_t.shape().dim_size(1);

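    // Degenerate case: the SparseTensor has no dense rows, so there is
    // nothing to fill; produce empty outputs and return early.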
    if (dense_rows == 0) {
      if (N != 0) {
        return errors::InvalidArgument(
            "Received SparseTensor with dense_shape[0] = 0 but "
            "indices.shape[0] = ",
            N);
      }
      Tensor* output_indices_t;
      TensorShape output_indices_shape;
      TF_RETURN_IF_ERROR(
          TensorShape::BuildTensorShape({0, rank}, &output_indices_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputIndicesOutput, output_indices_shape, &output_indices_t));
      Tensor* output_values_t;
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputValuesOutput, TensorShape({0}), &output_values_t));

      // Exit early, nothing more to do.
      return OkStatus();
    }

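    // First pass: count the number of elements in each dense row and check
    // whether the input rows are already in non-decreasing order.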
    bool rows_are_ordered = true;
    Tindex last_indices_row = 0;
    std::vector<Tindex> csr_offset(dense_rows, 0);
    for (int i = 0; i < N; ++i) {
      const Tindex row = indices(i, 0);
      if (row < 0 || row >= dense_rows) {
        return errors::InvalidArgument("indices(", i, ", 0) is invalid: ", row,
                                       " is not in [0, ", dense_rows, ")");
      }
      ++csr_offset[row];
      rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
      last_indices_row = row;
    }
    bool all_rows_full = true;
    for (int row = 0; row < dense_rows; ++row) {
      // csr_offset here describes the number of elements in this dense row
      bool row_empty = (csr_offset[row] == 0);
      if (empty_row_indicator) {
        empty_row_indicator[row] = row_empty;
      }
      all_rows_full = all_rows_full & !row_empty;
      // In filled version, each row has at least one element.
      csr_offset[row] = std::max(csr_offset[row], Tindex{1});
      // Update csr_offset to represent the number of elements up to and
      // including dense_row + 1:
      //  csr_offset(0) == #{elements of row 0}
      //  csr_offset(1) == #{elements of row 1} + #{elements of row 0}
      //  ..
      //  csr_offset(i) == starting index for elements in row i + 1.
      if (row > 0) {
        csr_offset[row] += csr_offset[row - 1];
      }
    }

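    // Fast path: if every row already has at least one element and the rows
    // arrive in order, forward the input tensors to the outputs unchanged.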
    if (all_rows_full && rows_are_ordered) {
      context->set_output(kOutputIndicesOutput, indices_t);
      context->set_output(kOutputValuesOutput, values_t);
      if (reverse_index_map) {
        for (Tindex i = 0; i < N; ++i) {
          reverse_index_map[i] = i;
        }
      }
    } else {
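      // Slow path: scatter the existing elements into row-sorted positions
      // and insert one default-valued element into every empty row.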
      Tensor* output_indices_t;
      const Tindex N_full = csr_offset[dense_rows - 1];
      TensorShape output_indices_shape;
      TF_RETURN_IF_ERROR(
          TensorShape::BuildTensorShape({N_full, rank}, &output_indices_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputIndicesOutput, output_indices_shape, &output_indices_t));
      auto output_indices = output_indices_t->matrix<Tindex>();

      Tensor* output_values_t;
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputValuesOutput, TensorShape({N_full}), &output_values_t));
      auto output_values = output_values_t->vec<T>();

      std::vector<Tindex> filled_count(dense_rows, 0);

      // Fill in values for rows that are not missing
      for (Tindex i = 0; i < N; ++i) {
        const Tindex row = indices(i, 0);
        Tindex& offset = filled_count[row];
        const Tindex output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset;
        offset++;  // Increment the filled count for this row.
        std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
        output_values(output_i) = values(i);
        // We'll need this reverse index map to backprop correctly.
        if (reverse_index_map) {
          reverse_index_map[i] = output_i;
        }
      }

      // Fill in values for rows that are missing
      for (Tindex row = 0; row < dense_rows; ++row) {
        const Tindex row_count = filled_count[row];
        if (row_count == 0) {  // We haven't filled this row
          const Tindex starting_index = (row == 0) ? 0 : csr_offset[row - 1];
          // Set the row index in the right location and explicitly zero out
          // the remaining index columns.
          output_indices(starting_index, 0) = row;
          for (Tindex col = 1; col < rank; ++col) {
            output_indices(starting_index, col) = 0;
          }
          output_values(starting_index) = default_value;
        }
      }
    }

    return OkStatus();
  }
};

}  // namespace functor

namespace {

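// Shared implementation for the synchronous (CPU) and asynchronous (GPU)
// kernels: validates the SparseTensor inputs, then dispatches to the
// device-specific functor.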
template <typename Device, typename T, typename Tindex>
void SparseFillEmptyRowsOpImpl(OpKernelContext* context,
                               AsyncOpKernel::DoneCallback done = nullptr) {
  // Note that setting this empty lambda as the default parameter value directly
  // can cause strange compiler/linker errors, so we do it like this instead.
  if (!done) {
    done = [] {};
  }

  const int kIndicesInput = 0;
  const int kValuesInput = 1;
  const int kDenseShapeInput = 2;
  const int kDefaultValueInput = 3;

  const Tensor& indices_t = context->input(kIndicesInput);
  const Tensor& values_t = context->input(kValuesInput);
  const Tensor& dense_shape_t = context->input(kDenseShapeInput);
  const Tensor& default_value_t = context->input(kDefaultValueInput);

  OP_REQUIRES_ASYNC(
      context, TensorShapeUtils::IsVector(dense_shape_t.shape()),
      errors::InvalidArgument("dense_shape must be a vector, saw: ",
                              dense_shape_t.shape().DebugString()),
      done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(indices_t.shape()),
                    errors::InvalidArgument("indices must be a matrix, saw: ",
                                            indices_t.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(values_t.shape()),
                    errors::InvalidArgument("values must be a vector, saw: ",
                                            values_t.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(
      context, indices_t.dim_size(0) == values_t.dim_size(0),
      errors::InvalidArgument("The length of `values` (", values_t.dim_size(0),
                              ") must match the first dimension of `indices` (",
                              indices_t.dim_size(0), ")."),
      done);
  OP_REQUIRES_ASYNC(
      context, TensorShapeUtils::IsScalar(default_value_t.shape()),
      errors::InvalidArgument("default_value must be a scalar, saw: ",
                              default_value_t.shape().DebugString()),
      done);
  // TODO(ebrevdo): add shape checks between values, indices,
  // Also add check that dense rank > 0.
  OP_REQUIRES_ASYNC(context, dense_shape_t.NumElements() != 0,
                    errors::InvalidArgument("Dense shape cannot be empty."),
                    done);

  using FunctorType = functor::SparseFillEmptyRows<Device, T, Tindex>;
  OP_REQUIRES_OK_ASYNC(context,
                       FunctorType()(context, default_value_t, indices_t,
                                     values_t, dense_shape_t, done),
                       done);
}

}  // namespace

template <typename Device, typename T, typename Tindex>
class SparseFillEmptyRowsOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    SparseFillEmptyRowsOpImpl<Device, T, Tindex>(context);
  }
};

#define REGISTER_KERNELS(D, T, Tindex)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")    \
                              .Device(DEVICE_##D)        \
                              .HostMemory("dense_shape") \
                              .TypeConstraint<T>("T"),   \
                          SparseFillEmptyRowsOp<D##Device, T, Tindex>)

#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
TF_CALL_ALL_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The GPU implementation is async because it requires waiting for a
// host->device memcpy before the output is allocated (similar to
// SegmentSumGPUOp).
template <typename T, typename Tindex>
class SparseFillEmptyRowsGPUOp : public AsyncOpKernel {
 public:
  explicit SparseFillEmptyRowsGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    SparseFillEmptyRowsOpImpl<GPUDevice, T, Tindex>(context, done);
  }
};

#define REGISTER_KERNELS(T, Tindex)                      \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")    \
                              .Device(DEVICE_GPU)        \
                              .HostMemory("dense_shape") \
                              .TypeConstraint<T>("T"),   \
                          SparseFillEmptyRowsGPUOp<T, Tindex>)

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex)                                            \
  template <>                                                                  \
  Status SparseFillEmptyRows<GPUDevice, T, Tindex>::operator()(                \
      OpKernelContext* context, const Tensor& default_value_t,                 \
      const Tensor& indices_t, const Tensor& values_t,                         \
      const Tensor& dense_shape_t, typename AsyncOpKernel::DoneCallback done); \
  extern template struct SparseFillEmptyRows<GPUDevice, T, Tindex>;
#define DECLARE_GPU_SPEC_INT64(T) DECLARE_GPU_SPEC(T, int64_t)
TF_CALL_POD_TYPES(DECLARE_GPU_SPEC_INT64)
#undef DECLARE_GPU_SPEC_INT64
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_KERNELS_TINDEX(T) REGISTER_KERNELS(T, int64)
TF_CALL_POD_TYPES(REGISTER_KERNELS_TINDEX)
#undef REGISTER_KERNELS_TINDEX

#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

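// CPU implementation of the SparseFillEmptyRows gradient: routes each output
// gradient back to the input element that produced it (via reverse_index_map)
// and accumulates the remaining gradients into the default value's gradient.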
template <typename T, typename Tindex>
struct SparseFillEmptyRowsGrad<CPUDevice, T, Tindex> {
  Status operator()(OpKernelContext* context,
                    typename TTypes<Tindex>::ConstVec reverse_index_map,
                    typename TTypes<T>::ConstVec grad_values,
                    typename TTypes<T>::Vec d_values,
                    typename TTypes<T>::Scalar d_default_value) {
    const CPUDevice& device = context->eigen_device<CPUDevice>();
    const Tindex N = reverse_index_map.dimension(0);
    const Tindex N_full = grad_values.dimension(0);

    T& d_default_value_scalar = d_default_value();
    d_default_value_scalar = T();

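    // Track which forward outputs correspond to an original input element;
    // the unvisited positions were filled with the default value.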
    Tensor visited_t;
    TF_RETURN_IF_ERROR(
        context->allocate_temp(DT_BOOL, TensorShape({N_full}), &visited_t));
    auto visited = visited_t.vec<bool>();
    visited.device(device) = visited.constant(false);

    for (int i = 0; i < N; ++i) {
      // Locate the index of the output of the forward prop associated
      // with this location in the input of the forward prop.  Copy
      // the gradient into it.  Mark it as visited.
      int64_t reverse_index = reverse_index_map(i);
      if (reverse_index < 0 || reverse_index >= N_full) {
        return errors::InvalidArgument(
            "Elements in reverse index must be in [0, ", N_full, ") but got ",
            reverse_index);
      }
      d_values(i) = grad_values(reverse_index);
      visited(reverse_index) = true;
    }
    for (int j = 0; j < N_full; ++j) {
      // The default value gradient gets the accumulated remainder of
      // the backprop values (since the default value was used to fill
      // in these slots in the forward calculation).
      if (!visited(j)) {
        d_default_value_scalar += grad_values(j);
      }
    }
    return OkStatus();
  }
};

}  // namespace functor

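// Op kernel wrapping the gradient functor: consumes reverse_index_map and
// grad_values and produces d_values and d_default_value.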
template <typename Device, typename T, typename Tindex>
class SparseFillEmptyRowsGradOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* reverse_index_map_t;
    const Tensor* grad_values_t;
    OP_REQUIRES_OK(context,
                   context->input("reverse_index_map", &reverse_index_map_t));
    OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t));

    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
        errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
                                reverse_index_map_t->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()),
                errors::InvalidArgument("grad_values must be a vector, saw: ",
                                        grad_values_t->shape().DebugString()));

    const auto reverse_index_map = reverse_index_map_t->vec<Tindex>();
    const auto grad_values = grad_values_t->vec<T>();

    const Tindex N = reverse_index_map_t->shape().dim_size(0);

    Tensor* d_values_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "d_values", TensorShape({N}), &d_values_t));
    auto d_values = d_values_t->vec<T>();
    Tensor* d_default_value_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("d_default_value", TensorShape({}),
                                            &d_default_value_t));
    auto d_default_value = d_default_value_t->scalar<T>();

    OP_REQUIRES_OK(context,
                   functor::SparseFillEmptyRowsGrad<Device, T, Tindex>()(
                       context, reverse_index_map, grad_values, d_values,
                       d_default_value));
  }
};

#define REGISTER_KERNELS(D, T, Tindex)                    \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRowsGrad") \
                              .Device(DEVICE_##D)         \
                              .TypeConstraint<T>("T"),    \
                          SparseFillEmptyRowsGradOp<D##Device, T, Tindex>)

#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex)                                 \
  template <>                                                       \
  Status SparseFillEmptyRowsGrad<GPUDevice, T, Tindex>::operator()( \
      OpKernelContext* context,                                     \
      typename TTypes<Tindex>::ConstVec reverse_index_map,          \
      typename TTypes<T>::ConstVec grad_values,                     \
      typename TTypes<T>::Vec d_values,                             \
      typename TTypes<T>::Scalar d_default_value);                  \
  extern template struct SparseFillEmptyRowsGrad<GPUDevice, T, Tindex>;
#define DECLARE_GPU_SPEC_INT64(T) DECLARE_GPU_SPEC(T, int64_t)
TF_CALL_REAL_NUMBER_TYPES(DECLARE_GPU_SPEC_INT64);
#undef DECLARE_GPU_SPEC_INT64
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T, int64)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_KERNELS
}  // namespace tensorflow