1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
19 #define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
20 
21 #include <cstdint>
22 
23 #include "tensorflow/core/framework/op_requires.h"
24 #include "tensorflow/core/platform/types.h"
25 #define EIGEN_USE_THREADS
26 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
27 #define EIGEN_USE_GPU
28 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
29 
30 #include "third_party/eigen3/Eigen/Core"
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "tensorflow/core/framework/bounds_check.h"
33 #include "tensorflow/core/framework/numeric_op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/register_types.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor_types.h"
38 #include "tensorflow/core/framework/tensor_util.h"
39 #include "tensorflow/core/framework/types.h"
40 #include "tensorflow/core/kernels/segment_reduction_ops.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/util/determinism.h"
44 #include "tensorflow/core/util/util.h"
45 
46 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
47 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
48 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
49 
50 #if GOOGLE_CUDA
51 #include "tensorflow/core/util/gpu_solvers.h"
52 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
53 
54 using stream_executor::cuda::ScopedActivateExecutorContext;
55 #elif TENSORFLOW_USE_ROCM
56 #include "tensorflow/core/platform/rocm.h"
57 #include "tensorflow/core/util/gpu_solvers.h"
58 using stream_executor::rocm::ScopedActivateExecutorContext;
59 #endif  // GOOGLE_CUDA
60 
61 namespace tensorflow {
62 
63 typedef Eigen::ThreadPoolDevice CPUDevice;
64 typedef Eigen::GpuDevice GPUDevice;
65 
66 namespace internal {
67 Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input,
68                                 const Tensor& segment_ids);
69 Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
70                                         OpKernelContext* context,
71                                         const Tensor& data,
72                                         const Tensor& segment_ids,
73                                         const Tensor& num_segments);
74 Status ValidateSparseSegmentReduction(OpKernelContext* context,
75                                       const Tensor& input,
76                                       const Tensor& indices,
77                                       const Tensor& segment_ids,
78                                       bool has_num_segments);
79 }  // namespace internal
80 
81 // This operator handles reducing segments along the first dimension.
82 // See core/ops/math_ops.cc for more details.
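//
// As a brief, hand-worked illustration of the semantics implemented below
// (not code from this file): with a Sum reducer and default_value 0,
//
//   input       = [[1, 2], [3, 4], [5, 6]]
//   segment_ids = [0, 0, 2]          // sorted, may contain gaps
//
// the output has segment_ids(last) + 1 == 3 rows:
//
//   output      = [[4, 6],           // input rows 0 and 1 reduced together
//                  [0, 0],           // gap: filled with the default value
//                  [5, 6]]           // input row 2 alone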
83 template <typename Device, class T, class Index, typename Reducer,
84           int default_value>
85 class SegmentReductionOp : public OpKernel {
86  public:
87   explicit SegmentReductionOp(OpKernelConstruction* context)
88       : OpKernel(context) {}
89 
90   void Compute(OpKernelContext* context) override {
91     const Tensor& input = context->input(0);
92     const Tensor& segment_ids = context->input(1);
93 
94     OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(context, input,
95                                                                segment_ids));
96 
97     const int64_t num_indices = segment_ids.NumElements();
98     auto input_flat = input.flat_outer_dims<T>();
99     const int64_t num_col = input_flat.dimension(1);
100 
101     const auto segment_vec = segment_ids.vec<Index>();
102     // Note that the current implementation assumes that segment_vec values are
103     // sorted.
104     const Index output_rows =
105         num_indices > 0
106             ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
107             : 0;
108     OP_REQUIRES(context, output_rows >= 0,
109                 errors::InvalidArgument("segment ids must be >= 0"));
110 
111     OP_REQUIRES(context, input.dims() >= 1,
112                 errors::InvalidArgument("Shape must be at least rank 1"));
113 
114     TensorShape output_shape = input.shape();
115     // Since we're changing the first dimension of the shape, we need to make
116     // sure the new shape won't overflow.
117     OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows));
118 
119     // Note that we do not initialize the output buffer with a default value, so
120     // we need to explicitly set missing indices to the default value.
121     Tensor* output = nullptr;
122     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
123     if (num_indices == 0) return;
124     OP_REQUIRES(context, output_rows > 0,
125                 errors::InvalidArgument("segment ids must be >= 0"));
126     auto output_flat = output->flat_outer_dims<T>();
127 
128     Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
129     Index start = 0, end = 1;
130 
131     Index uninitialized_index = 0;  // Index from which the output is not set.
132     Index out_index = internal::SubtleMustCopy(segment_vec(start));
133 
134     // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
135     // across threads.
136     Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
137     while (end <= num_indices) {
138       // We initialize next_index to 0 to avoid "warning: 'next_index' may be
139       // used uninitialized in this function" in the Mac build (since the
140       // compiler isn't smart enough to realize the code is safe).
141       Index next_index = 0;
142       if (end < num_indices) {
143         next_index = internal::SubtleMustCopy(segment_vec(end));
144         if (out_index == next_index) {
145           ++end;
146           continue;
147         }
148         // We have a new segment here.  Verify that the segment ids are growing.
149         OP_REQUIRES(context, out_index < next_index,
150                     errors::InvalidArgument("segment ids are not increasing"));
151       }
152 
153       // Process segment [start, end)
154       const T* in_slice_ptr = &input_flat(start, 0);
155       typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
156                                Eigen::Unaligned>
157           OutT;
158 
159       OP_REQUIRES(
160           context, FastBoundsCheck(out_index, output_rows),
161           errors::InvalidArgument(
162               "Segment id ", out_index, " out of range [0, ", output_rows,
163               "), possibly because 'segment_ids' input is not sorted."));
164 
165       // If there is a gap between two indices, we need to set that gap to the
166       // default value.
167       if (out_index > uninitialized_index) {
168         Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
169             out_index - uninitialized_index, num_col);
170         Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
171             gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
172         gap_slice.setConstant(T(default_value));
173       }
174 
175       T* out_slice_ptr = &output_flat(out_index, 0);
176       OutT out_slice(out_slice_ptr, out_slice_shape);
177       // We don't use out_slice.device(context->eigen_device<Device>)
178       // because these pieces of work are likely to be very small and
179       // the context switching overhead dwarfs any benefit we get from
180       // using another thread to do this work.
181       if (start == end - 1) {
182         typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
183                                  Eigen::Unaligned>
184             InT;
185         InT in_slice(in_slice_ptr, out_slice_shape);
186         out_slice = in_slice;
187       } else {
188         Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
189                                                            num_col);
190         typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
191                                  Eigen::Unaligned>
192             InT;
193         InT in_slice(in_slice_ptr, in_slice_shape);
194 
195         out_slice = in_slice.reduce(dims_to_reduce, Reducer());
196       }
197       if (end >= num_indices) break;
198       start = end;
199       ++end;
200       uninitialized_index = out_index + 1;
201       out_index = next_index;
202     }
203   }
204 };
205 
206 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
207 
208 //  SegmentReductionGPUOp is a segment reduction operator implemented for GPU
209 //  only.
210 //  TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
211 //  its unsorted counterpart (mostly when the problem size is small).
212 //  This is due to the following two main reasons, and a cost-effective way
213 //  to resolve these problems is desirable.
214 //  1. Sorted segment reduction requires a memory transfer from device to host
215 //     in order to know the size of the output dimension whereas unsorted
216 //     segment reduction receives the size of the output dimension as an input
217 //     parameter.
218 //  2. Sorted segment reduction is essentially a tiled version of unsorted
219 //     segment reduction and therefore such optimization comes at an inherent
220 //     cost. However such cost may not be justified when the problem size is
221 //     small. Whether to use the tiled or the untiled version depends on
222 //     many factors, including data alignment, the ratio of computation to
223 //     memory traffic and, obviously, the problem size.
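//  Concretely, the number of output rows is segment_ids[num_indices - 1] + 1,
//  and segment_ids lives in device memory; that is why ComputeAsync below
//  copies the last segment id to host scratch space with ThenMemcpy and
//  finishes the computation in a callback scheduled on the stream's EventMgr.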
224 template <class T, class Index, class SegmentReductionFunctor, bool IsMean>
225 class SegmentReductionGPUOp : public AsyncOpKernel {
226  public:
227   explicit SegmentReductionGPUOp(OpKernelConstruction* context)
228       : AsyncOpKernel(context) {}
229 
230   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
231     const Tensor& input = context->input(0);
232     const Tensor& segment_ids = context->input(1);
233 
234     OP_REQUIRES_ASYNC(
235         context, TensorShapeUtils::IsVector(segment_ids.shape()),
236         errors::InvalidArgument("segment_ids should be a vector."), done);
237 
238     OP_REQUIRES_ASYNC(context, input.dims() >= 1,
239                       errors::InvalidArgument("Shape must be at least rank 1"),
240                       done);
241 
242     const int64_t num_indices = segment_ids.NumElements();
243     OP_REQUIRES_ASYNC(
244         context, num_indices == input.dim_size(0),
245         errors::InvalidArgument(
246             "segment_ids should be the same size as dimension 0 of"
247             " input."),
248         done);
249 
250     if (num_indices == 0) {
251       TensorShape output_shape = input.shape();
252       output_shape.set_dim(0, 0);
253 
254       Tensor* output = nullptr;
255       OP_REQUIRES_OK_ASYNC(
256           context, context->allocate_output(0, output_shape, &output), done);
257       done();
258       return;
259     }
260 
261     se::DeviceMemoryBase output_rows_device(
262         const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
263         (num_indices - 1));
264     ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
265 
266     auto stream = context->op_device_context()->stream();
267     OP_REQUIRES_ASYNC(
268         context,
269         stream
270             ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
271                          sizeof(Index))
272             .ok(),
273         errors::Internal(type_string() +
274                          ": failed to copy output_rows from device"),
275         done);
276 
277     SegmentReductionFunctor functor_;
278     auto create_and_check_output = [context, output_rows_host, &input,
279                                     &segment_ids, &functor_, done]() {
280       // Ensure that within the callback, the proper GPU settings are
281       // configured.
282       auto stream = context->op_device_context()->stream();
283       ScopedActivateExecutorContext scoped_activation{stream->parent()};
284 
285       Index output_rows = *output_rows_host.data();
286       output_rows++;
287       OP_REQUIRES_ASYNC(context, output_rows > 0,
288                         errors::InvalidArgument("segment ids must be >= 0"),
289                         done);
290 
291       TensorShape output_shape = input.shape();
292       // Since we're changing the first dimension of the shape, we need to make
293       // sure the new shape won't overflow.
294       OP_REQUIRES_OK_ASYNC(context,
295                            output_shape.SetDimWithStatus(0, output_rows), done);
296 
297       Tensor* output = nullptr;
298       OP_REQUIRES_OK_ASYNC(
299           context, context->allocate_output(0, output_shape, &output), done);
300 
301       bool use_deterministic_kernels =
302 #if defined(PLATFORM_WINDOWS)
303           // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows
304           // CI build error.
305           false;
306 #else
307           UseDeterministicSegmentReductions() ||
308           (!SegmentReductionFunctor::atomic_reduction_is_associative &&
309            OpDeterminismRequired());
310 #endif
311 
312       // The determinism check is here, rather than inside the functor (as it is
313       // for the unsorted segment reduction ops) because the done callback
314       // (required for OP_REQUIRES_ASYNC) is not available inside the functor.
315       bool determinism_requirement_met =
316           use_deterministic_kernels ||
317           SegmentReductionFunctor::atomic_reduction_is_associative ||
318           !OpDeterminismRequired() ||
319           DisableSegmentReductionOpDeterminismExceptions();
320       OP_REQUIRES_ASYNC(
321           context, determinism_requirement_met,
322           errors::Unimplemented(
323               "Deterministic GPU implementation of sorted segment reduction op"
324               " not available."),
325           done);
326 
327       auto output_flat = output->flat_outer_dims<T>();
328       auto data_ptr = input.template flat<T>().data();
329       auto segment_flat = segment_ids.flat<Index>();
330       functor_(context, context->eigen_device<GPUDevice>(), output_rows,
331                segment_ids.shape(), IsMean, segment_flat, input.NumElements(),
332                data_ptr, output_flat);
333 
334       done();
335     };
336 
337     context->device()
338         ->tensorflow_accelerator_device_info()
339         ->event_mgr->ThenExecute(stream, create_and_check_output);
340   }
341 };
342 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
343 
344 // ____________________________________________________________________________
345 // Unsorted segment reduction ops.
346 
347 namespace functor {
348 
349 // The UnsortedSegmentFunctor implementation for CPU.
350 template <typename T, typename Index, typename InitialValueF,
351           typename ReductionF>
352 struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
353   void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
354                   typename TTypes<Index>::ConstFlat segment_ids,
355                   typename TTypes<T, 2>::ConstTensor data,
356                   typename TTypes<T, 2>::Tensor output) {
357     auto cpu_device = ctx->eigen_cpu_device();
358     output.device(cpu_device) = output.constant(InitialValueF()());
359     if (data.size() == 0) {
360       return;
361     }
362 
363     // This functor reduces the `N`-row input to a `num_segments`-row output.
364     const int64_t N = segment_ids.dimension(0);
365     const int64_t num_segments = output.dimension(0);
366     const int64_t inner_dim = data.dimension(1);
367     ReductionF reduction;
368 
369     // `num_real_segment` counts the input rows that are actually reduced;
370     // rows with a negative segment index are excluded.
371     // It is used by the cost model below.
372     int64_t num_real_segment = N;
373     // `num_reductions` counts the output rows that receive a reduction;
374     // rows that only hold InitialValueF() are excluded.
375     int64_t num_reductions = 0;
376     // `row_counter` records how many input rows are reduced into each
377     // output row; rows that only hold InitialValueF() stay at 0. The number
378     // of non-zero entries equals `num_reductions`.
379     std::vector<Index> row_counter(num_segments, 0);
380 
381     for (int64_t i = 0; i < N; ++i) {
382       Index j = internal::SubtleMustCopy(segment_ids(i));
383       if (j < 0) {
384         --num_real_segment;
385         continue;
386       }
387       OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
388                   errors::InvalidArgument(
389                       "segment_ids", SliceDebugString(segment_ids_shape, i),
390                       " = ", j, " is out of range [0, ", num_segments, ")"));
391       if (row_counter[j] == 0) num_reductions++;
392       row_counter[j]++;
393     }
394 
395     // Nothing to reduce. All output values are equal to `InitialValueF()`.
396     if (num_reductions == 0) return;
397 
398     // Parallelize by `num_segments`. It's simple, efficient and safe
399     // (no data dependency):
400     //
401     //   input   segment_ids                 num_segments  operation
402     //   | a0 |  | 0 |            worker 1:  |0|           f(a0, a1)
403     //   | b0 |  | 1 |            worker 2:  |1|           f(b0, b1)
404     // N | c0 |  | 2 |       -->  worker 3:  |2|           f(c0)
405     //   | b1 |  | 1 |
406     //   | a1 |  | 0 |
407     //
408     // TODO(intel-tf): Balance workload in `row_counter` to make parallelism
409     //                 more efficient.
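    // For example, with num_segments == 4 and two workers, parallelFor might
    // give worker 1 the range [begin, end) == [0, 2) and worker 2 the range
    // [2, 4); each worker scans all N rows but only reduces the rows whose
    // segment id falls inside its own range, so no two workers write to the
    // same output row.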
410     auto reductionWorker = [&](int64_t begin, int64_t end) -> void {
411       for (int64_t i = 0; i < N; i++) {
412         Index j = internal::SubtleMustCopy(segment_ids(i));
413         // If `j` is in work scope of this worker, do the reduction.
414         if (j >= begin && j < end) {
415           reduction(data.template chip<0>(i), output.template chip<0>(j));
416         }
417       }
418     };
419 
420     // Reduction functors include Sum, Max, Min, etc. For the cost model,
421     // assume each operation costs roughly 5 cycles.
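    // A rough worked example (illustrative numbers only): with
    // num_real_segment = 1000, num_segments = 100, inner_dim = 64 and
    // T = float, kAverTaskSize = 10, compute_cycles = 5 * 64 * 10 = 3200 and
    // input_bytes = output_bytes = 4 * 64 * 10 = 2560 per output row.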
422     const int64_t kAverTaskSize = num_real_segment / num_segments;
423     const int64_t compute_cycles = 5 * inner_dim * kAverTaskSize;
424     const int64_t input_bytes = sizeof(T) * inner_dim * kAverTaskSize;
425     const int64_t output_bytes = sizeof(T) * inner_dim * kAverTaskSize;
426     const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
427     cpu_device.parallelFor(num_segments, cost, reductionWorker);
428   }
429 };
430 
431 template <typename T>
432 using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;
433 
434 template <typename T>
435 using constMatrixChip =
436     Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;
437 
438 // reduction functors
439 template <typename T>
440 struct SumOp {
441   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
442     output += data;
443   }
444 };
445 
446 template <typename T>
447 struct MaxOp {
448   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
449     output = data.cwiseMax(output);
450   }
451 };
452 
453 template <typename T>
454 struct MinOp {
455   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
456     output = data.cwiseMin(output);
457   }
458 };
459 
460 template <typename T>
461 struct ProdOp {
462   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
463     output *= data;
464   }
465 };
466 }  // namespace functor
467 
468 // The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
469 // is the device-specific implementation of the reduction. These
470 // device-specific implementations are themselves templated with the
471 // corresponding initial value functors and reduction functors.
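//
// For illustration, an unsorted-segment-sum kernel for float data and int32
// segment ids would be instantiated roughly as sketched below (functor::Zero
// is assumed here to be the zero-initializing InitialValueF used by the sum
// registrations; the exact functor names live in segment_reduction_ops.h and
// the registration files):
//
//   UnsortedSegmentReductionOp<
//       float, int32,
//       functor::UnsortedSegmentFunctor<CPUDevice, float, int32,
//                                       functor::Zero<float>,
//                                       functor::SumOp<float>>>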
472 template <typename T, typename Index, typename DeviceReductionFunctor>
473 class UnsortedSegmentReductionOp : public OpKernel {
474  public:
475   explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
476       : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}
477 
478   void Compute(OpKernelContext* context) override {
479     const Tensor& data = context->input(0);
480     const Tensor& segment_ids = context->input(1);
481     const Tensor& num_segments = context->input(2);
482     OP_REQUIRES_OK(context,
483                    internal::ValidateUnsortedSegmentReduction(
484                        this, context, data, segment_ids, num_segments));
485     const auto segment_flat = segment_ids.flat<Index>();
486     const int64_t output_rows = internal::SubtleMustCopy(static_cast<int64_t>(
487         num_segments.dtype() == DT_INT32 ? num_segments.scalar<int32>()()
488                                          : num_segments.scalar<int64_t>()()));
489     OP_REQUIRES(context, output_rows >= 0,
490                 errors::InvalidArgument("Input num_segments == ", output_rows,
491                                         " must not be negative."));
492     TensorShape output_shape;
493     output_shape.AddDim(output_rows);
494     for (int i = segment_ids.dims(); i < data.dims(); i++) {
495       output_shape.AddDim(data.dim_size(i));
496     }
497     Tensor* output = nullptr;
498     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
499     auto output_flat = output->flat_outer_dims<T>();
500     auto data_flat = data.flat_inner_outer_dims<T, 2>(segment_ids.dims() - 1);
501     reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat,
502                        output_flat);
503   }
504 
505  protected:
506   DeviceReductionFunctor reduction_functor_;
507 };
508 
509 // ____________________________________________________________________________
510 // Sparse segment reduction ops.
511 
512 // Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
513 // by two dense tensors, one containing the data, and the other containing
514 // indices into the data.
515 //
516 // The template parameters are:
517 // * Device: An Eigen device object, on which the kernel will execute.
518 // * T: The value type.
519 // * Index: The element type of the indices tensor (int32 or int64).
520 // * SegmentId: The element type of the segment_ids tensor (int32 or int64).
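//
// A hand-worked example of the semantics (sparse segment sum, default value
// 0): with
//
//   input       = [[1, 2], [3, 4], [5, 6]]
//   indices     = [0, 2, 2]   // rows of `input` to gather
//   segment_ids = [0, 0, 1]   // sorted; one id per gathered row
//
// the output is
//
//   output      = [[6, 8],    // input[0] + input[2]
//                  [5, 6]]    // input[2]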
521 template <typename Device, class T, typename Index, typename SegmentId>
522 class SparseSegmentReductionOpBase : public OpKernel {
523  public:
524   explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
525                                         bool is_mean, bool is_sqrtn,
526                                         bool has_num_segments, T default_value)
527       : OpKernel(context),
528         dtidx_(DataTypeToEnum<Index>::v()),
529         is_mean_(is_mean),
530         is_sqrtn_(is_sqrtn),
531         has_num_segments_(has_num_segments),
532         default_value_(default_value) {}
533 
534   void Compute(OpKernelContext* context) override {
535     const Tensor& input = context->input(0);
536     const Tensor& indices = context->input(1);
537     const Tensor& segment_ids = context->input(2);
538 
539     OP_REQUIRES_OK(
540         context, internal::ValidateSparseSegmentReduction(
541                      context, input, indices, segment_ids, has_num_segments_));
542 
543     Index output_rows = -1;
544     if (has_num_segments_) {
545       const Tensor& num_segments = context->input(3);
546       // Note that there is a Tnumsegments parameter on the op, but it is not
547       // plumbed through to here and so always takes its default value of int32.
548       output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
549     }
550     const int64_t num_indices = indices.NumElements();
551 
552     auto input_flat = input.flat_outer_dims<T>();
553     const int64_t num_col = input_flat.dimension(1);
554     const auto indices_vec = indices.vec<Index>();
555     const auto segment_vec = segment_ids.vec<SegmentId>();
556     // Note that the current implementation assumes that segment_vec values are
557     // sorted.
558     const SegmentId last_segment_id =
559         num_indices > 0 ? segment_vec(num_indices - 1) : 0;
560     int64_t limit = dtidx_ == DataType::DT_INT32 ? kint32max : kint64max;
561 
562     OP_REQUIRES(
563         context, last_segment_id < limit,
564         errors::InvalidArgument("Last segment id must be < kintmax, got ",
565                                 last_segment_id, " limit ", limit));
566 
567     const SegmentId last_segment_id_plus_one =
568         num_indices > 0
569             ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
570             : 0;
571 
572     if (has_num_segments_) {
573       OP_REQUIRES(
574           context, output_rows >= last_segment_id_plus_one,
575           errors::InvalidArgument("segment ids must be < num_segments"));
576     } else {
577       output_rows = last_segment_id_plus_one;
578     }
579     OP_REQUIRES(context, output_rows >= 0,
580                 errors::InvalidArgument("segment ids must be >= 0"));
581 
582     TensorShape output_shape = input.shape();
583     OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows));
584 
585     // Note that we do not initialize the output buffer with a default value, so
586     // we need to explicitly set missing indices to the default value.
587     Tensor* output = nullptr;
588     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
589     if (num_indices == 0) {
590       if (output_rows > 0) {
591         output->flat_outer_dims<T>().setConstant(default_value_);
592       }
593       return;
594     }
595     OP_REQUIRES(context, output_rows > 0,
596                 errors::InvalidArgument("segment ids must be >= 0"));
597     auto output_flat = output->flat_outer_dims<T>();
598 
599     Tensor temp;
600     if (input.dtype() == DT_BFLOAT16 || input.dtype() == DT_HALF) {
601       temp = tensorflow::Tensor(DT_FLOAT, output_shape);
602     }
603     auto temp_flat = temp.flat_outer_dims<float>();
604 
605     int64_t start = 0, end = 1;
606     // Index from which the output is not initialized.
607     SegmentId uninitialized_index = 0;
608     SegmentId out_index = internal::SubtleMustCopy(segment_vec(start));
609 
610     while (true) {
611       // We initialize next_index to 0 to avoid "warning: 'next_index' may be
612       // used uninitialized in this function" in the Mac build (since the
613       // compiler isn't smart enough to realize the code is safe).
614       SegmentId next_index = 0;
615       if (end < num_indices) {
616         next_index = internal::SubtleMustCopy(segment_vec(end));
617         if (out_index == next_index) {
618           ++end;
619           continue;
620         }
621         // We have a new segment here.  Verify that the segment ids are growing.
622         OP_REQUIRES(context, out_index < next_index,
623                     errors::InvalidArgument("segment ids are not increasing"));
624       }
625 
626       OP_REQUIRES(
627           context, FastBoundsCheck(out_index, output_rows),
628           errors::InvalidArgument(
629               "Segment id ", out_index, " out of range [0, ", output_rows,
630               "), possibly because 'segment_ids' input is not sorted."));
631 
632       // If there is a gap between two indices, we need to set that gap to the
633       // default value.
634       if (out_index > uninitialized_index) {
635         Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
636             out_index - uninitialized_index, num_col);
637         Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
638             gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
639         gap_slice.setConstant(default_value_);
640       }
641 
642       auto out = output_flat.template chip<0>(out_index);
643       auto temp = temp_flat.template chip<0>(out_index);
644       const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
645                                               end - start, out, temp);
646       OP_REQUIRES(context, bad_offset < 0,
647                   errors::InvalidArgument(
648                       "Bad: indices[", start + bad_offset,
649                       "] == ", indices_vec(start + bad_offset),
650                       " out of range [0, ", input_flat.dimension(0), ")"));
651 
652       start = end;
653       ++end;
654       uninitialized_index = out_index + 1;
655       out_index = next_index;
656       if (end > num_indices) break;
657     }
658 
659     // Fill the gap at the end with the default value.
660     if (uninitialized_index < output_rows) {
661       Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
662           output_rows - uninitialized_index, num_col);
663       Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
664           gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
665       gap_slice.setConstant(default_value_);
666     }
667   }
668 
669  private:
670   const DataType dtidx_;
671   template <typename Tin>
672   using EnableIfBfloat16OrHalf =
673       typename std::enable_if<std::is_same<Tin, bfloat16>::value ||
674                                   std::is_same<Tin, Eigen::half>::value,
675                               int>::type;
676   template <typename Tin>
677   using EnableIfNotBfloat16OrHalf =
678       typename std::enable_if<!std::is_same<Tin, bfloat16>::value &&
679                                   !std::is_same<Tin, Eigen::half>::value,
680                               int>::type;
681 
682   template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
683   EIGEN_ALWAYS_INLINE auto fetch_val(
684       const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
685     return input_flat.template chip<0>(index);
686   }
687 
688   template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
689   EIGEN_ALWAYS_INLINE auto fetch_val(
690       const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
691     return input_flat.template chip<0>(index).template cast<float>();
692   }
693 
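  // For mean/sqrtn reductions of fewer than 10 rows, the scale (1/num or
  // 1/sqrt(num)) is folded into the first partial sum through this factor;
  // for 10 or more rows ReduceImpl instead divides once at the end, and for
  // plain sums the factor stays 1.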
694   template <typename Tout>
695   EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64_t num) {
696     Tout m(1);
697     if (is_mean_ && (num < 10)) {
698       m = Tout(num);
699     }
700     if (is_sqrtn_ && (num < 10)) {
701       m = Tout(sqrt(num));
702     }
703     return Tout(1) / m;
704   }
705 
706   template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
707   int64_t Reduce(
708       const typename TTypes<Tin>::ConstMatrix& input_flat,
709       const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
710       int64_t num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
711       Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
712     return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
713                                         out, get_scaling_factor<Tin>(num));
714   }
715 
716   template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
717   int64_t Reduce(
718       const typename TTypes<Tin>::ConstMatrix& input_flat,
719       const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
720       int64_t num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
721       Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
722     int64_t res =
723         ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
724                                        temp, get_scaling_factor<float>(num));
725     out = temp.template cast<Tin>();
726     return res;
727   }
728 
729   template <typename Tin, typename Tindex, typename Tout>
730   int64_t ReduceImpl(
731       const typename TTypes<Tin>::ConstMatrix& input_flat,
732       const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
733       int64_t num,
734       Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
735       const Tout scaling_factor) {
736 #define INDEX(n, i)                               \
737   const auto index##n = indices_vec(start + (i)); \
738   if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
739 
740 #define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)
741 
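    // The reduction below is manually unrolled: the switch folds the first
    // num % 8 rows into `out` (8 or 9 rows when the remainder is 0 or 1),
    // scaled by `scaling_factor`, and the trailing loop then accumulates the
    // remaining rows eight at a time. A return value of -1 means every index
    // passed the bounds check; otherwise the offset of the offending index
    // (relative to `start`) is returned.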
742     if (num == 1) {
743       INDEX(0, 0);
744       out = L(0);
745     } else {
746       int64_t r = num & 7;
747       switch (r) {
748         case 2: {
749           INDEX(0, 0);
750           INDEX(1, 1);
751           out = (L(0) + L(1)) * scaling_factor;
752           break;
753         }
754         case 3: {
755           INDEX(0, 0);
756           INDEX(1, 1);
757           INDEX(2, 2);
758           out = (L(0) + L(1) + L(2)) * scaling_factor;
759           break;
760         }
761         case 4: {
762           INDEX(0, 0);
763           INDEX(1, 1);
764           INDEX(2, 2);
765           INDEX(3, 3);
766           out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
767           break;
768         }
769         case 5: {
770           INDEX(0, 0);
771           INDEX(1, 1);
772           INDEX(2, 2);
773           INDEX(3, 3);
774           INDEX(4, 4);
775           out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
776           break;
777         }
778         case 6: {
779           INDEX(0, 0);
780           INDEX(1, 1);
781           INDEX(2, 2);
782           INDEX(3, 3);
783           INDEX(4, 4);
784           INDEX(5, 5);
785           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
786           break;
787         }
788         case 7: {
789           INDEX(0, 0);
790           INDEX(1, 1);
791           INDEX(2, 2);
792           INDEX(3, 3);
793           INDEX(4, 4);
794           INDEX(5, 5);
795           INDEX(6, 6);
796           out =
797               (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
798           break;
799         }
800         case 0: {
801           INDEX(0, 0);
802           INDEX(1, 1);
803           INDEX(2, 2);
804           INDEX(3, 3);
805           INDEX(4, 4);
806           INDEX(5, 5);
807           INDEX(6, 6);
808           INDEX(7, 7);
809           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
810                 scaling_factor;
811           r = 8;
812           break;
813         }
814         case 1: {
815           INDEX(0, 0);
816           INDEX(1, 1);
817           INDEX(2, 2);
818           INDEX(3, 3);
819           INDEX(4, 4);
820           INDEX(5, 5);
821           INDEX(6, 6);
822           INDEX(7, 7);
823           INDEX(8, 8);
824           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
825                 scaling_factor;
826           r = 9;
827           break;
828         }
829       }
830       for (; r < num; r += 8) {
831         INDEX(0, r);
832         INDEX(1, r + 1);
833         INDEX(2, r + 2);
834         INDEX(3, r + 3);
835         INDEX(4, r + 4);
836         INDEX(5, r + 5);
837         INDEX(6, r + 6);
838         INDEX(7, r + 7);
839         out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
840       }
841       if (is_mean_ && num >= 10) {
842         out = out / static_cast<Tout>(num);
843       }
844       if (is_sqrtn_ && num >= 10) {
845         out = out / static_cast<Tout>(sqrt(num));
846       }
847     }
848 
849     return -1;
850 #undef L
851 #undef INDEX
852   }
853 
854   const bool is_mean_;
855   const bool is_sqrtn_;
856   const bool has_num_segments_;
857   const T default_value_;
858 };
859 
860 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
861 
862 // Specialization for GPU. Must be async because it may need to wait for a
863 // device-to-host memcpy before allocating the output.
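//
// When has_num_segments_ is true the output size is already known on the
// host, so the kernel computes synchronously; otherwise the last segment id
// is first copied to host scratch space and the rest of the work runs in an
// EventMgr callback, mirroring SegmentReductionGPUOp above.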
864 template <class T, typename Index, typename SegmentId>
865 class SparseSegmentReductionOpBase<GPUDevice, T, Index, SegmentId>
866     : public AsyncOpKernel {
867  public:
868   explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
869                                         bool is_mean, bool is_sqrtn,
870                                         bool has_num_segments, T default_value)
871       : AsyncOpKernel(context),
872         is_mean_(is_mean),
873         is_sqrtn_(is_sqrtn),
874         has_num_segments_(has_num_segments),
875         default_value_(default_value) {}
876 
877   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
878     const Tensor& input = context->input(0);
879     const Tensor& indices = context->input(1);
880     const Tensor& segment_ids = context->input(2);
881 
882     OP_REQUIRES_OK_ASYNC(
883         context,
884         internal::ValidateSparseSegmentReduction(
885             context, input, indices, segment_ids, has_num_segments_),
886         done);
887 
888     ScratchSpace<SegmentId> last_segment_id_host(context, 1, /*on_host=*/true);
889 
890     auto create_and_check_output = [this, context, input, indices, segment_ids,
891                                     last_segment_id_host, done]() {
892       // Ensure that within the callback, the proper GPU settings are
893       // configured.
894       auto stream = context->op_device_context()->stream();
895       ScopedActivateExecutorContext scoped_activation{stream->parent()};
896 
897       SegmentId last_segment_id = *last_segment_id_host.data();
898       SegmentId output_rows = last_segment_id + 1;
899       OP_REQUIRES_ASYNC(context, output_rows > 0,
900                         errors::InvalidArgument("segment ids must be >= 0"),
901                         done);
902 
903       TensorShape output_shape = input.shape();
904       output_shape.set_dim(0, output_rows);
905 
906       Tensor* output = nullptr;
907       OP_REQUIRES_OK_ASYNC(
908           context, context->allocate_output(0, output_shape, &output), done);
909 
910       auto input_flat = input.flat_outer_dims<T>();
911       const auto indices_vec = indices.vec<Index>();
912       const auto segment_ids_vec = segment_ids.vec<SegmentId>();
913       auto output_flat = output->flat_outer_dims<T>();
914 
915       functor::SparseSegmentReductionFunctor<T, Index, SegmentId> functor;
916       OP_REQUIRES_OK_ASYNC(
917           context,
918           functor(context, is_mean_, is_sqrtn_, default_value_, input_flat,
919                   indices_vec, segment_ids_vec, output_flat),
920           done);
921       done();
922     };
923 
924     if (has_num_segments_) {
925       // No need to do any device-to-host memcpy, just compute synchronously.
926       const Tensor& num_segments_t = context->input(3);
927       SegmentId num_segments =
928           internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32
929                                        ? num_segments_t.scalar<int32>()()
930                                        : num_segments_t.scalar<int64_t>()());
931       *last_segment_id_host.mutable_data() = num_segments - 1;
932       create_and_check_output();
933     } else {
934       const int64_t num_indices = indices.NumElements();
935       // Need to copy last element of segment_ids from device to host, and then
936       // asynchronously allocate the output and finish the computation.
937       se::DeviceMemoryBase last_segment_id_device(
938           const_cast<Tensor&>(segment_ids).template flat<SegmentId>().data() +
939           (num_indices - 1));
940       auto stream = context->op_device_context()->stream();
941       OP_REQUIRES_ASYNC(
942           context,
943           stream
944               ->ThenMemcpy(last_segment_id_host.mutable_data(),
945                            last_segment_id_device, sizeof(SegmentId))
946               .ok(),
947           errors::Internal(type_string() +
948                            ": failed to copy last_segment_id from device"),
949           done);
950       context->device()
951           ->tensorflow_accelerator_device_info()
952           ->event_mgr->ThenExecute(stream, create_and_check_output);
953     }
954   }
955 
956  private:
957   const bool is_mean_;
958   const bool is_sqrtn_;
959   const bool has_num_segments_;
960   const T default_value_;
961 };
962 
963 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
964 
965 template <typename Device, class T, typename Index, typename SegmentId>
966 class SparseSegmentReductionMeanOp
967     : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
968  public:
969   explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
970       : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
971             context, true /*is_mean*/, false /*is_sqrtn*/,
972             false /* has_num_segments */, T(0) /* default_value */) {}
973 };
974 
975 template <typename Device, class T, typename Index, typename SegmentId>
976 class SparseSegmentReductionMeanWithNumSegmentsOp
977     : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
978  public:
979   explicit SparseSegmentReductionMeanWithNumSegmentsOp(
980       OpKernelConstruction* context)
981       : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
982             context, true /*is_mean*/, false /*is_sqrtn*/,
983             true /* has_num_segments */, T(0) /* default_value */) {}
984 };
985 
986 template <typename Device, class T, typename Index, typename SegmentId>
987 class SparseSegmentReductionSqrtNOp
988     : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
989  public:
990   explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
991       : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
992             context, false /*is_mean*/, true /*is_sqrtn*/,
993             false /* has_num_segments */, T(0) /* default_value */) {}
994 };
995 
996 template <typename Device, class T, typename Index, typename SegmentId>
997 class SparseSegmentReductionSqrtNWithNumSegmentsOp
998     : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
999  public:
1000   explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
1001       OpKernelConstruction* context)
1002       : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
1003             context, false /*is_mean*/, true /*is_sqrtn*/,
1004             true /* has_num_segments */, T(0) /* default_value */) {}
1005 };
1006 
1007 template <typename Device, class T, typename Index, typename SegmentId>
1008 class SparseSegmentReductionSumOp
1009     : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
1010  public:
1011   explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
1012       : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
1013             context, false /*is_mean*/, false /*is_sqrtn*/,
1014             false /* has_num_segments */, T(0) /* default_value */) {}
1015 };
1016 
1017 template <typename Device, class T, typename Index, typename SegmentId>
1018 class SparseSegmentReductionSumWithNumSegmentsOp
1019     : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
1020  public:
1021   explicit SparseSegmentReductionSumWithNumSegmentsOp(
1022       OpKernelConstruction* context)
1023       : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
1024             context, false /*is_mean*/, false /*is_sqrtn*/,
1025             true /* has_num_segments */, T(0) /* default_value */) {}
1026 };
1027 
1028 namespace functor {
1029 
1030 template <typename T, typename Index, typename SegmentId>
1031 struct SparseSegmentGradFunctor<CPUDevice, T, Index, SegmentId> {
1032   void operator()(OpKernelContext* context,
1033                   SparseSegmentReductionOperation operation,
1034                   typename TTypes<T>::ConstMatrix input_flat,
1035                   typename TTypes<Index>::ConstVec indices_vec,
1036                   typename TTypes<SegmentId>::ConstVec segment_vec,
1037                   typename TTypes<T>::Matrix output_flat) {
1038     const int64_t N = indices_vec.size();
1039     const SegmentId M = output_flat.dimension(0);
1040 
1041     // Note that similar to SparseSegmentMean, we assume that segment_vec is
1042     // already sorted and has non-negative values.
1043     const SegmentId num_segments = input_flat.dimension(0);
1044     const SegmentId last_segment_id_plus_one =
1045         internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
1046     OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
1047                 errors::InvalidArgument("Invalid number of segments"));
1048 
1049     // Compute scaling factors for input.
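    // Worked example: for kMean with segment_vec = [0, 0, 1], segment 0 has
    // two rows and segment 1 has one, so scaling = [1/2, 1/1]; each gathered
    // gradient row is multiplied by the factor of its segment id before being
    // scattered into the output. For kSqrtN the factors would instead be
    // [1/sqrt(2), 1].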
1050     std::vector<double> scaling(
1051         (operation == SparseSegmentReductionOperation::kSum ? 0 : num_segments),
1052         0.0);
1053     if (operation != SparseSegmentReductionOperation::kSum) {
1054       for (int64_t i = 0; i < N; ++i) {
1055         const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
1056         OP_REQUIRES(
1057             context, FastBoundsCheck(idx, num_segments),
1058             errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
1059                                     num_segments, ")."));
1060         scaling[idx] += 1;
1061       }
1062       for (size_t i = 0; i < scaling.size(); ++i) {
1063         switch (operation) {
1064           case SparseSegmentReductionOperation::kSum: {
1065             OP_REQUIRES(
1066                 context, false,
1067                 errors::Internal(
1068                     "Should not happen: sum inside SparseSegmentReductionOp "
1069                     "scaling generation."));
1070           }
1071           case SparseSegmentReductionOperation::kMean: {
1072             scaling[i] = 1.0 / std::max(scaling[i], 1.0);
1073             break;
1074           }
1075           case SparseSegmentReductionOperation::kSqrtN: {
1076             scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
1077             break;
1078           }
1079             // No default to get compiler warnings for missing cases.
1080         }
1081       }
1082     }
1083 
1084     output_flat.setZero();
1085     std::vector<bool> is_modified(M, false);
1086 
1087     for (int64_t i = 0; i < N; ++i) {
1088       const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
1089       OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
1090                   errors::InvalidArgument("Index ", output_idx,
1091                                           " out of range [0, ", M, ")."));
1092 
1093       const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
1094       OP_REQUIRES(
1095           context, FastBoundsCheck(idx, num_segments),
1096           errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
1097                                   num_segments, ")."));
1098 
1099       const T scale = (operation == SparseSegmentReductionOperation::kSum
1100                            ? static_cast<T>(1)
1101                            : static_cast<T>(scaling[idx]));
1102       if (is_modified[output_idx]) {
1103         if (scale == 1.0) {
1104           output_flat.template chip<0>(output_idx) +=
1105               input_flat.template chip<0>(idx);
1106         } else {
1107           output_flat.template chip<0>(output_idx) +=
1108               input_flat.template chip<0>(idx) * scale;
1109         }
1110       } else {
1111         if (scale == 1.0) {
1112           output_flat.template chip<0>(output_idx) =
1113               input_flat.template chip<0>(idx);
1114         } else {
1115           output_flat.template chip<0>(output_idx) =
1116               input_flat.template chip<0>(idx) * scale;
1117         }
1118       }
1119       is_modified[output_idx] = true;
1120     }
1121   }
1122 };
1123 
1124 }  // namespace functor
1125 
1126 // Implements the common logic for the gradients of SparseSegmentReduction
1127 // kernels.
1128 //
1129 // The template parameters are:
1130 // * Device: An Eigen device object, on which the kernel will execute.
1131 // * T: The value type.
1132 // * Index: The element type of the indices tensor (int32 or int64).
1133 // * SegmentId: The element type of the segment_ids tensor (int32 or int64).
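//
// A hand-worked example for the Sum variant: with grad input =
// [[10, 10], [20, 20]], indices = [0, 3, 3], segment_ids = [0, 0, 1] and
// output_dim0 = 4, the functor scatters each gradient row back to the
// position it was gathered from, giving
//
//   output = [[10, 10], [0, 0], [0, 0], [30, 30]].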
1134 template <typename Device, class T, typename Index, typename SegmentId>
1135 class SparseSegmentGradOpBase : public OpKernel {
1136  public:
1137   explicit SparseSegmentGradOpBase(OpKernelConstruction* context,
1138                                    SparseSegmentReductionOperation operation)
1139       : OpKernel(context), operation_(operation) {}
1140 
1141   void Compute(OpKernelContext* context) override {
1142     const Tensor& input = context->input(0);
1143     const Tensor& indices = context->input(1);
1144     const Tensor& segment_ids = context->input(2);
1145     const Tensor& output_dim0 = context->input(3);
1146 
1147     OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
1148                 errors::InvalidArgument("indices should be a vector."));
1149     OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
1150                 errors::InvalidArgument("segment_ids should be a vector."));
1151     OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()),
1152                 errors::InvalidArgument("output_dim0 should be a scalar."));
1153 
1154     const int64_t N = indices.NumElements();
1155     OP_REQUIRES(context, N == segment_ids.NumElements(),
1156                 errors::InvalidArgument(
1157                     "segment_ids and indices should have same size."));
1158     const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()());
1159 
1160     auto input_flat = input.flat_outer_dims<T>();
1161     const auto indices_vec = indices.vec<Index>();
1162     const auto segment_vec = segment_ids.vec<SegmentId>();
1163 
1164     TensorShape output_shape = input.shape();
1165     OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, M));
1166     Tensor* output = nullptr;
1167     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
1168     if (M == 0 || N == 0) return;
1169 
1170     auto output_flat = output->flat_outer_dims<T>();
1171     functor::SparseSegmentGradFunctor<Device, T, Index, SegmentId>()(
1172         context, operation_, input_flat, indices_vec, segment_vec, output_flat);
1173   }
1174 
1175  private:
1176   const SparseSegmentReductionOperation operation_;
1177 };
1178 
1179 template <typename Device, class T, typename Index, typename SegmentId>
1180 class SparseSegmentSumGradOp
1181     : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
1182  public:
1183   explicit SparseSegmentSumGradOp(OpKernelConstruction* context)
1184       : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
1185             context, SparseSegmentReductionOperation::kSum) {}
1186 };
1187 
1188 template <typename Device, class T, typename Index, typename SegmentId>
1189 class SparseSegmentMeanGradOp
1190     : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
1191  public:
1192   explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
1193       : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
1194             context, SparseSegmentReductionOperation::kMean) {}
1195 };
1196 
1197 template <typename Device, class T, typename Index, typename SegmentId>
1198 class SparseSegmentSqrtNGradOp
1199     : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
1200  public:
1201   explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
1202       : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
1203             context, SparseSegmentReductionOperation::kSqrtN) {}
1204 };
1205 
1206 }  // namespace tensorflow
1207 
1208 #endif  // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
1209