1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17 
18 #define EIGEN_USE_GPU
19 
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/kernels/gpu_prim.h"
22 #include "tensorflow/core/kernels/gpu_prim_helpers.h"
23 #include "tensorflow/core/kernels/segment_reduction_ops.h"
24 #include "tensorflow/core/lib/core/bits.h"
25 #include "tensorflow/core/util/determinism.h"
26 #include "tensorflow/core/util/env_var.h"
27 #include "tensorflow/core/util/gpu_device_functions.h"
28 #include "tensorflow/core/util/gpu_kernel_helper.h"
29 #include "tensorflow/core/util/permutation_input_iterator.h"
30 
31 namespace tensorflow {
32 
33 using GPUDevice = Eigen::GpuDevice;
34 
35 // Atomic and non-atomic reduction functors for the GPU.
36 #define DEFINE_REDUCE_UPDATE_OP_GPU(name, func)                             \
37   struct name##OpGpu {                                                      \
38     template <typename T>                                                   \
39     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,          \
40                                                           const T& value) { \
41       func;                                                                 \
42     }                                                                       \
43   };
44 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicSum, GpuAtomicAdd(dest, value))
45 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicProd, GpuAtomicMul(dest, value))
46 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMax, GpuAtomicMax(dest, value))
47 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMin, GpuAtomicMin(dest, value))
48 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicSum, *dest += value)
49 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicProd, *dest *= value)
50 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMax, *dest = max(*dest, value))
51 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMin, *dest = min(*dest, value))
52 #undef DEFINE_REDUCE_UPDATE_OP_GPU
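// For illustration, expanding the macro above for AtomicSum yields roughly:
//   struct AtomicSumOpGpu {
//     template <typename T>
//     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
//                                                           const T& value) {
//       GpuAtomicAdd(dest, value);
//     }
//   };
// i.e., each functor applies a single update to *dest, either atomically (safe
// when several threads target the same output element) or non-atomically.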
53 
54 template <typename ReduceOp>
55 struct ReduceUpdateOpFor {};
56 
57 #define DEFINE_REDUCE_UPDATE_OP_FOR(reduce_op, atomic, nonatomic) \
58   template <>                                                     \
59   struct ReduceUpdateOpFor<reduce_op> {                           \
60     using atomic_op = atomic;                                     \
61     using nonatomic_op = nonatomic;                               \
62   };
63 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Sum, AtomicSumOpGpu, NonAtomicSumOpGpu)
64 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Prod, AtomicProdOpGpu, NonAtomicProdOpGpu)
65 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Max, AtomicMaxOpGpu, NonAtomicMaxOpGpu)
66 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Min, AtomicMinOpGpu, NonAtomicMinOpGpu)
67 #undef DEFINE_REDUCE_UPDATE_OP_FOR
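// Usage sketch (illustrative): given ReductionF = functor::Sum, the kernels
// below pick the matching update ops via
//   using atomic = ReduceUpdateOpFor<functor::Sum>::atomic_op;        // AtomicSumOpGpu
//   using nonatomic = ReduceUpdateOpFor<functor::Sum>::nonatomic_op;  // NonAtomicSumOpGpu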
68 
69 // The SortedSegmentReductionCustomKernel reduces input data just as
70 // UnsortedSegmentCustomKernel does, except that the input data
71 // is partitioned along the outer reduction dimension. This is
72 // because consecutive rows (elements in a row share the same
73 // outer dimension index) in the flattened 2D input data likely
74 // belong to the same segment in a sorted segment reduction.
75 // Such a partitioning strategy therefore has two advantages over
76 // the unsorted kernel:
77 // 1. Each thread reduces across multiple rows before writing
78 // answers to global memory, so reduction results are written
79 // to global memory less often.
80 // 2. We may know that the current thread is the only contributor
81 // to an output element because of the increasing nature of segment
82 // ids. In such cases, we do not need to use atomic operations
83 // to write results to global memory.
84 // In the flattened view of input data (with only outer and inner
85 // dimension), every thread processes a strip of input data of
86 // size OuterDimTileSize x 1. This strip runs across multiple
87 // rows of input data and all reduction elements share one inner
88 // dimension index.
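// Worked example (illustrative): with OuterDimTileSize = 8, inner_dim_size = 4,
// and input_outer_dim_size = 20 rows, each stripe covers up to 8 consecutive
// rows of one inner column, so total_stripe_count = ceil(20 / 8) * 4 = 12.
// Within a stripe, rows that share a segment id are accumulated in a register
// and flushed once per segment boundary; only the first and last flush of a
// stripe use atomics, because neighbouring stripes may contribute to the same
// output rows.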
89 template <typename T, typename Index, int OuterDimTileSize, typename ReductionF,
90           typename AtomicReductionF>
91 __global__ void SortedSegmentReductionCustomKernel(
92     const Index input_outer_dim_size, const Index inner_dim_size,
93     const Index output_outer_dim_size, const Index* __restrict__ segment_ids,
94     const T* __restrict__ input, T* __restrict__ output,
95     const Index total_stripe_count, const T initial_value) {
96   for (int stripe_index : GpuGridRangeX(total_stripe_count)) {
97     const Index segment_offset = stripe_index % inner_dim_size;
98     const Index input_outer_dim_index_base =
99         stripe_index / inner_dim_size * Index(OuterDimTileSize);
100 
101     T reduce_res = initial_value;
102     Index first_segment_id = segment_ids[input_outer_dim_index_base];
103     Index last_output_segment_id = output_outer_dim_size;
104 
105     const Index actual_stripe_height =
106         min(Index(OuterDimTileSize),
107             input_outer_dim_size - input_outer_dim_index_base);
108     for (Index j = 0; j < actual_stripe_height; j++) {
109       Index current_output_segment_id =
110           segment_ids[input_outer_dim_index_base + j];
111       // Decide whether to write result to global memory. Result is only written
112       // to global memory if we move to another segment. Otherwise we can keep
113       // accumulating locally.
114       if (current_output_segment_id > last_output_segment_id) {
115         const Index output_index =
116             last_output_segment_id * inner_dim_size + segment_offset;
117         // Decide whether to write result to global memory using atomic
118         // operations.
119         if (last_output_segment_id == first_segment_id) {
120           AtomicReductionF()(output + output_index, reduce_res);
121         } else {
122           ReductionF()(output + output_index, reduce_res);
123         }
124         reduce_res = initial_value;
125       }
126       ReductionF()(
127           &reduce_res,
128           ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
129               segment_offset));
130       last_output_segment_id = current_output_segment_id;
131     }
132     // For the last result in a strip, always write using atomic operations
133     // due to possible race conditions with threads computing
134     // the following strip.
135     const Index output_index =
136         last_output_segment_id * inner_dim_size + segment_offset;
137     AtomicReductionF()(output + output_index, reduce_res);
138   }
139 }
140 
141 template <typename SegmentId, typename Index, typename T>
142 __global__ void SegmentMeanNormalizeKernel(
143     SegmentId nsegments, Index ninner,
144     const Index* __restrict__ segment_offsets,  // [nsegments + 1]
145     T* __restrict__ output) {                   // [nsegments, ninner]
146   for (SegmentId seg : GpuGridRangeY(nsegments)) {
147     SegmentId segment_size = segment_offsets[seg + 1] - segment_offsets[seg];
148     segment_size = max(segment_size, Index(1));  // Avoid division by zero
149     T inv_norm = T(1) / static_cast<T>(segment_size);
150     for (Index i : GpuGridRangeX(ninner)) {
151       output[seg * ninner + i] *= inv_norm;
152     }
153   }
154 }
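// Example (illustrative): segment_offsets = [0, 3, 3, 5] describes segment
// sizes {3, 0, 2}, so the kernel above scales the three output rows by
// {1/3, 1 (empty segments are clamped to size 1), 1/2}.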
155 
156 template <typename SegmentId, typename Index, typename T>
157 Status LaunchSegmentMeanNormalizeKernel(
158     const GPUDevice& d, SegmentId nsegments, Index ninner,
159     const Index* __restrict__ segment_offsets,  // [nsegments + 1]
160     T* __restrict__ output) {                   // [nsegments, ninner]
161   Gpu2DLaunchConfig config = GetGpu2DLaunchConfig(
162       ninner, nsegments, d, SegmentMeanNormalizeKernel<SegmentId, Index, T>,
163       /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
164   return GpuLaunchKernel(SegmentMeanNormalizeKernel<SegmentId, Index, T>,
165                          config.block_count, config.thread_per_block, 0,
166                          d.stream(), nsegments, ninner, segment_offsets,
167                          output);
168 }
169 
170 // UnsortedSegmentCustomKernel processes 'input_total_size' elements.
171 // Each element is mapped from input to output by a combination of its
172 // 'segment_ids' mapping and 'inner_dim_size'.
173 template <typename T, typename Index, typename KernelReductionFunctor>
174 __global__ void UnsortedSegmentCustomKernel(
175     const int64_t input_outer_dim_size, const int64_t inner_dim_size,
176     const int64_t output_outer_dim_size, const Index* __restrict__ segment_ids,
177     const T* __restrict__ input, T* __restrict__ output) {
178   const int64_t input_total_size = input_outer_dim_size * inner_dim_size;
179   for (int64_t input_index : GpuGridRangeX(input_total_size)) {
180     const int64_t input_segment_index = input_index / inner_dim_size;
181     const int64_t segment_offset = input_index % inner_dim_size;
182     const Index output_segment_index = segment_ids[input_segment_index];
183     if (output_segment_index < 0 ||
184         output_segment_index >= output_outer_dim_size) {
185       continue;
186     }
187     const int64_t output_index =
188         output_segment_index * inner_dim_size + segment_offset;
189     KernelReductionFunctor()(output + output_index, ldg(input + input_index));
190   }
191 }
192 
193 template <typename Tindex, typename Tsegmentids>
194 __global__ void SegmentOffsetsKernel(
195     Tindex size, Tsegmentids nsegments,
196     const Tsegmentids* __restrict__ segment_ids,  // [size]
197     Tindex* __restrict__ segment_offsets) {       // [nsegments + 1]
198   GPU_1D_KERNEL_LOOP(i, size + 1) {
199     // IDs are clipped to [-1, nsegments] so that out-of-bounds IDs are ignored.
200     // Note that we can't report invalid IDs from the GPU without incurring
201     // additional overhead.
202     auto clip = [&](Tsegmentids id) {
203       return min(max(Tsegmentids(-1), id), nsegments);
204     };
205     const Tsegmentids cur_id = (i < size) ? clip(segment_ids[i]) : nsegments;
206     const Tsegmentids prev_id =
207         (i == 0) ? Tsegmentids(-1) : clip(segment_ids[i - 1]);
208     // At segment boundaries, write the offset for this ID and any missing IDs
209     // since the previous one.
210     for (Tsegmentids id = prev_id + 1; id <= cur_id; ++id) {
211       segment_offsets[id] = i;
212     }
213   }
214 }
215 
216 // Finds the start offset of each segment in the given sorted segment_ids
217 // vector. Missing IDs are given the same offset as the next ID so that they
218 // represent empty ranges. Invalid IDs (those that are outside the range
219 // [0, nsegments)) are ignored. The value at segment_offsets[0] is set to the
220 // start index of the first valid ID (e.g., 0 if all IDs are valid), and the
221 // value at segment_offsets[nsegments] is set to the end index of the last valid
222 // ID (e.g., size if all IDs are valid).
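// Worked example (illustrative): with size = 6, nsegments = 4, and sorted
// segment_ids = [0, 0, 1, 1, 1, 3], the kernel writes
// segment_offsets = [0, 2, 5, 5, 6]; segment 2 is empty and receives the same
// offset as segment 3.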
223 template <typename Tindex, typename Tsegmentids>
224 Status LaunchSegmentOffsetsKernel(const GPUDevice& d, Tindex size,
225                                   Tsegmentids nsegments,
226                                   const Tsegmentids* segment_ids,  // [size]
227                                   Tindex* segment_offsets) {  // [nsegments + 1]
228   GpuLaunchConfig config = GetGpuLaunchConfig(
229       size + 1, d, &SegmentOffsetsKernel<Tindex, Tsegmentids>,
230       /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
231   return GpuLaunchKernel(SegmentOffsetsKernel<Tindex, Tsegmentids>,
232                          config.block_count, config.thread_per_block, 0,
233                          d.stream(), size, nsegments, segment_ids,
234                          segment_offsets);
235 }
236 
237 template <typename T>
238 struct RealTypeIfComplex {
239   using type = T;
240 };
241 
242 template <typename Real>
243 struct RealTypeIfComplex<std::complex<Real>> {
244   using type = Real;
245 };
246 
247 // Reduces along columns of the thread block, returning the result in the first
248 // row of threads.
249 template <typename T, typename ReduceOp>
250 __device__ T ReduceBlockAlongCols(ReduceOp reduce_op, const T& value,
251                                   bool is_valid) {
252   GPU_DYNAMIC_SHARED_MEM_DECL(/*ALIGN=*/16, char, shared_memory_raw);
253   T* const shared_partial_reduction =
254       reinterpret_cast<T*>(shared_memory_raw);  // [blockDim.y, blockDim.x]
255   const int x = threadIdx.x;
256   const int y = threadIdx.y;
257   T reduced = value;
258   // Reduce over the y dimension of the block.
259   for (unsigned k = blockDim.y / 2; k > 0; k /= 2) {
260     if (is_valid && y < 2 * k) {
261       shared_partial_reduction[y * blockDim.x + x] = reduced;
262     }
263     __syncthreads();
264     if (is_valid && y < k) {
265       reduced = reduce_op(reduced,
266                           shared_partial_reduction[(y + k) * blockDim.x + x]);
267     }
268     __syncthreads();
269   }
270   return reduced;
271 }
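// Example (illustrative): with blockDim = (32, 8), the loop above runs for
// k = 4, 2, 1; after the last round each valid thread with threadIdx.y == 0
// holds the reduction of the 8 values originally held by its column of the
// block.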
272 
273 // This kernel uses a 2D thread decomposition. The x dimension maps to the inner
274 // dimension of the input/output. The y grid dimension maps to segments, and y
275 // threads within a block cooperate to reduce over the block's segment.
276 // Note that Tinit is needed because Tvec and Treducevec may be vector types,
277 // but Tinit is always a scalar type.
278 // Note that the first dimension of input_vec is nouter if indices is not
279 // provided; otherwise it is indexed indirectly via indices and can have any
280 // size (as long as it spans at least the maximum value in indices). This also
281 // applies to the weights vector.
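// Example (illustrative): for ninner_vec = 128 and a block of (64, 4) threads,
// two x-blocks cover the inner dimension, and within each block the 4 rows of
// threads walk a segment 4 input rows at a time, combining partial results
// with ReduceBlockAlongCols before the y == 0 row writes the segment's output.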
282 template <typename Treducevec, typename Tvec, typename Tindex,
283           typename Tsegmentids, typename ReduceOp, typename Tinit>
284 __global__ void SegmentReduceVectorKernel(
285     Tindex nouter, Tindex ninner_vec, Tsegmentids nsegments, ReduceOp reduce_op,
286     Tinit initial_value, Tinit empty_segment_value, bool is_mean, bool is_sqrtn,
287     const Tvec* __restrict__ input_vec,          // [nouter or any, ninner_vec]
288     const Tindex* __restrict__ segment_offsets,  // [nsegments + 1]
289     const Tindex* __restrict__ indices,          // [nouter] (optional)
290     const Tinit* __restrict__ weights,           // [nouter or any] (optional)
291     Tvec* __restrict__ output_vec) {             // [nsegments, ninner_vec]
292   const int num_blocks_x = (ninner_vec - 1) / blockDim.x + 1;
293   // Grid-stride loop over inner dimension blocks.
294   for (Tindex blk_x = blockIdx.x; blk_x < num_blocks_x; blk_x += gridDim.x) {
295     const Tindex x = threadIdx.x + blk_x * blockDim.x;
296     const Tindex y = threadIdx.y;
297     const bool x_ok = x < ninner_vec;
298     // Grid-stride loop over segment blocks, each processing one segment.
299     for (Tsegmentids seg = blockIdx.y; seg < nsegments; seg += gridDim.y) {
300       // Load segment range.
301       const Tindex begin = segment_offsets[seg];
302       const Tindex end = segment_offsets[seg + 1];
303       // Reduce over the segment.
304       Treducevec result = Treducevec(initial_value);
305       // Loop over the segment, reducing blockDim.y elements at a time.
306       for (Tindex y_offset = begin; y_offset < end; y_offset += blockDim.y) {
307         const bool y_ok = (y_offset + y) < end;
308         // Perform indirect lookup if required.
309         const Tindex y_idx =
310             indices && y_ok ? indices[y_offset + y] : y_offset + y;
311         const int64_t input_idx = static_cast<int64_t>(y_idx) * ninner_vec + x;
312         // Load the input row from global mem.
313         Treducevec block_result =
314             x_ok && y_ok ? input_vec[input_idx] : Tvec(initial_value);
315         // Apply weights if provided.
316         if (weights && y_ok) block_result *= Tvec(weights[y_idx]);
317         // Reduce along the columns of the block, returning result in first row.
318         block_result = ReduceBlockAlongCols(reduce_op, block_result, x_ok);
319         if (y == 0 && x_ok) {
320           result = reduce_op(result, block_result);
321         }
322       }
323       // First row of the block stores the result to global memory.
324       if (y == 0 && x_ok) {
325         if (begin == end) {
326           // Empty segment.
327           result = Treducevec(empty_segment_value);
328         } else {
329           typename RealTypeIfComplex<Tinit>::type total_weight(end - begin);
330           // Normalize the results if necessary.
331           if (is_mean) {
332             result /= Treducevec(total_weight);
333           } else if (is_sqrtn) {
334             result /= Treducevec(sqrt(total_weight));
335           }
336         }
337         // Cast from Treducevec to Tvec.
338         const int64_t output_idx = static_cast<int64_t>(seg) * ninner_vec + x;
339         output_vec[output_idx] = static_cast<Tvec>(result);
340       }
341     }
342   }
343 }
344 
345 // Reduces input matrix within segments over the outer dimension. Empty segments
346 // always output empty_segment_value.
347 // If is_mean or is_sqrtn is true, the results are normalized using the
348 // corresponding function.
349 // If indices is not nullptr, input rows are accessed indirectly as
350 // input[indices[i]], instead of input[i].
351 // Note: Treducevec is to allow reducing in higher precision than Tvec.
352 template <typename Treducevec, typename Tvec, typename Tindex,
353           typename Tsegmentids, typename ReduceOp, typename Tinit>
354 Status LaunchSegmentReduceVectorKernel(
355     const GPUDevice& d, Tindex nouter, Tindex ninner_vec, Tsegmentids nsegments,
356     ReduceOp reduce_op, Tinit initial_value, Tinit empty_segment_value,
357     bool is_mean, bool is_sqrtn,
358     const Tvec* input_vec,          // [nouter or any, ninner_vec]
359     const Tindex* segment_offsets,  // [nsegments + 1]
360     const Tindex* indices,          // [nouter] (optional)
361     const Tinit* weights,           // [nouter or any] (optional)
362     Tvec* output_vec) {             // [nsegments, ninner_vec]
363   static constexpr const int kMaxGridX = (1u << 31) - 1;
364   static constexpr const int kMaxGridY = (1u << 16) - 1;
365   const int max_block_size = 1024;  // Can be tuned for perf (<= 1024)
366   const int min_block_size = 64;    // Can be tuned for perf
367   const Tindex ninner_pow2 = Tindex(1) << Log2Ceiling64(ninner_vec);
368   // This is a heuristic that first allocates threads in the block to the inner
369   // (x) dimension (which is most efficient) and then allocates the rest to the
370   // reduction (y) dimension (which is less efficient but increases
371   // parallelism).
372   int block_x = std::min(ninner_pow2, static_cast<Tindex>(max_block_size));
373   const Tindex avg_reduce_size =
374       Eigen::divup(nouter, static_cast<Tindex>(nsegments));
375   const Tindex avg_reduce_size_pow2 = Tindex(1)
376                                       << Log2Ceiling64(avg_reduce_size);
377   dim3 block(
378       block_x,
379       std::min(static_cast<Tindex>(Eigen::divup(min_block_size, block_x)),
380                avg_reduce_size_pow2));
381   dim3 grid(std::min(Eigen::divup(ninner_vec, static_cast<Tindex>(block.x)),
382                      static_cast<Tindex>(kMaxGridX)),
383             std::min(nsegments, static_cast<Tsegmentids>(kMaxGridY)));
384   unsigned shared_memory_bytes = block.x * block.y * sizeof(Treducevec);
385   return GpuLaunchKernel(
386       SegmentReduceVectorKernel<Treducevec, Tvec, Tindex, Tsegmentids, ReduceOp,
387                                 Tinit>,
388       grid, block, shared_memory_bytes, d.stream(), nouter, ninner_vec,
389       nsegments, reduce_op, initial_value, empty_segment_value, is_mean,
390       is_sqrtn, input_vec, segment_offsets, indices, weights, output_vec);
391 }
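// Example of the launch heuristic above (illustrative): ninner_vec = 20 with
// nouter / nsegments ~= 16 gives block.x = 32 (next power of two >= 20) and
// block.y = min(divup(64, 32), 16) = 2, i.e. threads go to the inner (x)
// dimension first and only the remainder of min_block_size to the reduction
// (y) dimension.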
392 
393 template <typename Tvec, typename Treducevec, typename Tindex,
394           typename Tsegmentids, typename Tinit>
395 __global__ void SegmentReduceEpilogueKernel(
396     Tsegmentids nsegments, Tinit empty_segment_value, bool is_mean,
397     bool is_sqrtn,
398     const Treducevec* __restrict__ output_raw,   // [nsegments]
399     const Tindex* __restrict__ segment_offsets,  // [nsegments + 1]
400     Tvec* __restrict__ output) {                 // [nsegments]
401   GPU_1D_KERNEL_LOOP(seg, nsegments) {
402     Tindex segment_size = segment_offsets[seg + 1] - segment_offsets[seg];
403     Treducevec val = output_raw[seg];
404     if (segment_size == 0) {
405       // Empty segment.
406       val = Treducevec(empty_segment_value);
407     } else if (is_mean) {
408       val /= Treducevec(segment_size);
409     } else if (is_sqrtn) {
410       val /= Treducevec(
411           sqrt(typename RealTypeIfComplex<Tinit>::type(segment_size)));
412     }
413     // Cast from Treducevec to Tvec.
414     output[seg] = static_cast<Tvec>(val);
415   }
416 }
417 
418 // Normalizes output_raw based on segment size and casts from Treducevec to
419 // Tvec. If Tvec == Treducevec, this is safe to call with output_raw == output.
420 // Note that Treducevec is the type that was used for the reduction, which may
421 // be a higher-precision type than the output type Tvec (e.g., float vs. half).
422 template <typename Tvec, typename Treducevec, typename Tindex,
423           typename Tsegmentids, typename Tinit>
424 Status LaunchSegmentReduceEpilogueKernel(
425     const GPUDevice& d, Tsegmentids nsegments, Tinit empty_segment_value,
426     bool is_mean, bool is_sqrtn,
427     const Treducevec* output_raw,   // [nsegments]
428     const Tindex* segment_offsets,  // [nsegments + 1]
429     Tvec* output) {                 // [nsegments]
430   GpuLaunchConfig config = GetGpuLaunchConfig(
431       nsegments, d,
432       &SegmentReduceEpilogueKernel<Tvec, Treducevec, Tindex, Tsegmentids,
433                                    Tinit>,
434       /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
435   return GpuLaunchKernel(
436       SegmentReduceEpilogueKernel<Tvec, Treducevec, Tindex, Tsegmentids, Tinit>,
437       config.block_count, config.thread_per_block, 0, d.stream(), nsegments,
438       empty_segment_value, is_mean, is_sqrtn, output_raw, segment_offsets,
439       output);
440 }
441 
442 template <typename Tto>
443 struct CastFunctor {
444   template <typename T>
445   __device__ Tto operator()(const T& val) const {
446     return static_cast<Tto>(val);
447   }
448 };
449 
450 template <typename Treducevec, typename Tvec, typename Tindex, typename Tinit>
451 struct LookupAndScaleAndCastInputsFunctor {
452   LookupAndScaleAndCastInputsFunctor(const Tvec* input_vec,
453                                      const Tindex* indices,
454                                      const Tinit* weights)
455       : input_vec_(input_vec), indices_(indices), weights_(weights) {}
456 
457   __device__ Treducevec operator()(Tindex idx) const {
458     if (indices_) idx = indices_[idx];
459     Treducevec result = static_cast<Treducevec>(input_vec_[idx]);
460     if (weights_) result *= Tvec(weights_[idx]);
461     return result;
462   }
463 
464  private:
465   const Tvec* __restrict__ input_vec_;
466   const Tindex* __restrict__ indices_;
467   const Tinit* __restrict__ weights_;
468 };
469 
470 template <typename Treducevec, typename Tvec, typename Tindex, typename Tinit>
471 struct CastIterator {
472   using FunctorTy =
473       LookupAndScaleAndCastInputsFunctor<Treducevec, Tvec, Tindex, Tinit>;
474   using InputIteratorTy = gpuprim::CountingInputIterator<Tindex>;
475   using IteratorTy =
476       gpuprim::TransformInputIterator<Treducevec, FunctorTy, InputIteratorTy>;
477 };
478 
479 template <typename Treducevec, typename Tvec, typename Tindex, typename Tinit>
480 typename CastIterator<Treducevec, Tvec, Tindex, Tinit>::IteratorTy
481 MakeLookupAndScaleAndCastInputsIterator(const Tvec* input_vec,
482                                         const Tindex* indices,
483                                         const Tinit* weights) {
484   using CastIteratorTy = CastIterator<Treducevec, Tvec, Tindex, Tinit>;
485   typename CastIteratorTy::FunctorTy functor(input_vec, indices, weights);
486   return typename CastIteratorTy::IteratorTy(
487       typename CastIteratorTy::InputIteratorTy(Tindex(0)), functor);
488 }
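// Example (illustrative): with indices = [5, 2], dereferencing the iterator at
// position 0 yields static_cast<Treducevec>(input_vec[5]), scaled by
// Tvec(weights[5]) when weights are provided; the lookup, scaling, and cast
// are thus fused into the loads performed by the segmented reduction below.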
489 
490 template <typename Treducevec, typename Tvec, typename Tindex,
491           typename Tsegmentids, typename ReduceOp, typename Tinit>
492 Status SegmentReduceGPUImplNoInnerDim(
493     OpKernelContext* ctx, Tindex nouter, Tsegmentids nsegments,
494     ReduceOp reduce_op, Tinit initial_value, Tinit empty_segment_value,
495     bool is_mean, bool is_sqrtn,
496     const Tvec* input_vec,          // [nouter or any]
497     const Tindex* segment_offsets,  // [nsegments + 1]
498     const Tindex* indices,          // [nouter] (optional)
499     const Tinit* weights,           // [nouter or any] (optional)
500     Tvec* output_vec) {             // [nsegments]
501   // Here we use gpuprim::DeviceSegmentedReduce (which is optimized for this
502   // shape) and add the additional required functionality using fancy input
503   // iterators and an epilogue kernel.
504 
505 // Note: This reinterpret cast is only needed to avoid a compilation error
506   // when Tvec != Treducevec; the result is only used if Tvec == Treducevec.
507   Treducevec* output_raw_ptr = reinterpret_cast<Treducevec*>(output_vec);
508   Tensor output_raw;
509   bool need_temp_output = !std::is_same<Tvec, Treducevec>::value;
510   if (need_temp_output) {
511     // Note: We must allocate and reinterpret as bytes because Treducevec may
512     // be a vector type and they are not supported as Tensor dtypes.
513     TF_RETURN_IF_ERROR(ctx->allocate_temp(
514         DT_INT8,
515         TensorShape({static_cast<int64_t>(nsegments * sizeof(Treducevec))}),
516         &output_raw));
517     output_raw_ptr =
518         reinterpret_cast<Treducevec*>(output_raw.flat<int8>().data());
519   }
520   auto input_iter = MakeLookupAndScaleAndCastInputsIterator<Treducevec>(
521       input_vec, indices, weights);
522   TF_RETURN_IF_ERROR(GpuSegmentedReduce(ctx, nsegments, reduce_op,
523                                         Treducevec(initial_value), input_iter,
524                                         segment_offsets, output_raw_ptr));
525   bool need_epilogue = !std::is_same<Tvec, Treducevec>::value ||
526                        initial_value != empty_segment_value || is_mean ||
527                        is_sqrtn;
528   if (need_epilogue) {
529     const GPUDevice& device = ctx->eigen_gpu_device();
530     // Normalize based on the segment size and cast results back to T.
531     TF_RETURN_IF_ERROR(LaunchSegmentReduceEpilogueKernel(
532         device, nsegments, empty_segment_value, is_mean, is_sqrtn,
533         output_raw_ptr, segment_offsets, output_vec));
534   }
535   return OkStatus();
536 }
537 
538 template <typename Treducevec, typename Tvec, typename Tindex,
539           typename Tsegmentids, typename ReduceOp, typename Tinit>
540 Status SegmentReduceGPUImpl(
541     OpKernelContext* ctx, Tindex nouter, Tindex ninner_vec,
542     Tsegmentids nsegments, ReduceOp reduce_op, Tinit initial_value,
543     Tinit empty_segment_value, bool is_mean, bool is_sqrtn,
544     const Tvec* input_vec,           // [nouter or any, ninner_vec]
545     const Tsegmentids* segment_ids,  // [nouter]
546     const Tindex* indices,           // [nouter] (optional)
547     const Tinit* weights,            // [nouter or any] (optional)
548     Tvec* output_vec) {              // [nsegments, ninner_vec]
549   const GPUDevice& device = ctx->eigen_gpu_device();
550 
551   if (nouter == 0) {
552     // Just set output to empty_segment_value.
553     GPUDevice d = ctx->template eigen_device<GPUDevice>();
554     int64_t output_size = static_cast<int64_t>(nsegments) * ninner_vec;
555     GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d);
556     return GpuLaunchKernel(SetToValue<Tvec, Tinit>, config.block_count,
557                            config.thread_per_block, 0, d.stream(), output_size,
558                            output_vec, empty_segment_value);
559   }
560 
561   // Allocate and compute segment_offsets.
562   Tensor segment_offsets;
563   TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<Tindex>::value,
564                                         TensorShape({nsegments + 1}),
565                                         &segment_offsets));
566   Tindex* segment_offsets_ptr = segment_offsets.flat<Tindex>().data();
567   TF_RETURN_IF_ERROR(LaunchSegmentOffsetsKernel(
568       device, nouter, nsegments, segment_ids, segment_offsets_ptr));
569 
570   const Tindex avg_reduce_size =
571       Eigen::divup(nouter, static_cast<Tindex>(nsegments));
572   // This avg_reduce_size threshold is a performance heuristic.
573   if (ninner_vec == 1 && avg_reduce_size >= 512) {
574     // Here we use a gpuprim-based implementation that doesn't support an
575     // inner dimension but can be significantly faster for large reductions.
576     return SegmentReduceGPUImplNoInnerDim<Treducevec>(
577         ctx, nouter, nsegments, reduce_op, initial_value, empty_segment_value,
578         is_mean, is_sqrtn, input_vec, segment_offsets_ptr, indices, weights,
579         output_vec);
580   }
581   // Here we use a custom kernel that is optimized for ninner_vec >= ~64 and
582   // gives decent performance for smaller cases. It also handles indices,
583   // casting to/from Treducevec, and normalizing the output.
584   return LaunchSegmentReduceVectorKernel<Treducevec>(
585       device, nouter, ninner_vec, nsegments, reduce_op, initial_value,
586       empty_segment_value, is_mean, is_sqrtn, input_vec, segment_offsets_ptr,
587       indices, weights, output_vec);
588 }
589 
590 template <typename Treduce>
591 struct SegmentReduceGPUVectorized {
592   template <int vec_size>
593   struct Impl {
594     template <typename T, typename Tindex, typename Tsegmentids,
595               typename ReduceOp>
596     Status operator()(OpKernelContext* ctx, Tindex nouter, Tindex ninner,
597                       Tsegmentids nsegments, ReduceOp reduce_op,
598                       T initial_value, T empty_segment_value, bool is_mean,
599                       bool is_sqrtn, const T* input,
600                       const Tsegmentids* segment_ids, const Tindex* indices,
601                       const T* weights, T* output) {
602       DCHECK_EQ(ninner % vec_size, 0);
603       DCHECK_EQ(reinterpret_cast<std::uintptr_t>(input) % vec_size, 0);
604       DCHECK_EQ(reinterpret_cast<std::uintptr_t>(output) % vec_size, 0);
605       Tindex ninner_vec = ninner / vec_size;
606       using Tvec = AlignedVector<T, vec_size>;
607       using Treducevec = AlignedVector<Treduce, vec_size>;
608       const Tvec* input_vec = reinterpret_cast<const Tvec*>(input);
609       Tvec* output_vec = reinterpret_cast<Tvec*>(output);
610 
611       return SegmentReduceGPUImpl<Treducevec>(
612           ctx, nouter, ninner_vec, nsegments, reduce_op, initial_value,
613           empty_segment_value, is_mean, is_sqrtn, input_vec, segment_ids,
614           indices, weights, output_vec);
615     }
616   };
617 };
618 
619 // Reduces input matrix within segments over the outer dimension. Empty segments
620 // always output empty_segment_value.
621 // The segment_ids vector must be sorted.
622 // If is_mean or is_sqrtn is true, the results are normalized using the
623 // corresponding function.
624 // If indices is not nullptr, input rows are accessed indirectly as
625 // input[indices[i]], instead of input[i].
626 // The implementation is deterministic.
627 // Note: Treduce is to allow reducing in higher precision than T.
628 template <typename Treduce, typename T, typename Tindex, typename Tsegmentids,
629           typename ReduceOp>
630 Status SegmentReduceGPU(OpKernelContext* ctx, Tindex nouter, Tindex ninner,
631                         Tsegmentids nsegments, ReduceOp reduce_op,
632                         T initial_value, T empty_segment_value, bool is_mean,
633                         bool is_sqrtn,
634                         const T* input,  // [nouter or any, ninner]
635                         const Tsegmentids* segment_ids,  // [nouter]
636                         const Tindex* indices,           // [nouter] (optional)
637                         const T* weights,  // [nouter or any] (optional)
638                         T* output) {       // [nsegments, ninner]
639   if (ninner == 0 || nsegments == 0) return OkStatus();
640   return DispatchToVectorized<
641       T, SegmentReduceGPUVectorized<Treduce>::template Impl>(
642       MinAlignmentOf(input, output, ninner), ctx, nouter, ninner, nsegments,
643       reduce_op, initial_value, empty_segment_value, is_mean, is_sqrtn, input,
644       segment_ids, indices, weights, output);
645 }
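// Example (illustrative): for T = float with ninner divisible by 4 and
// 16-byte-aligned input/output pointers, the dispatch above would select
// vec_size = 4, so the kernels operate on AlignedVector<float, 4> values with
// ninner_vec = ninner / 4; otherwise it falls back to smaller vector widths,
// down to scalar loads.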
646 
647 template <typename SegmentId, typename Index, typename T>
648 __global__ void SegmentWeightsKernel(
649     SegmentId nsegments, SparseSegmentReductionOperation operation,
650     const Index* __restrict__ segment_offsets,  // [nsegments + 1]
651     T* __restrict__ weights) {                  // [nsegments]
652   GPU_1D_KERNEL_LOOP(i, nsegments) {
653     Index segment_size = segment_offsets[i + 1] - segment_offsets[i];
654     segment_size = max(segment_size, Index(1));  // Avoid division by zero
655     if (operation == SparseSegmentReductionOperation::kMean) {
656       weights[i] = T(1) / static_cast<T>(segment_size);
657     } else if (operation == SparseSegmentReductionOperation::kSqrtN) {
658       weights[i] = T(1) / sqrt(static_cast<T>(segment_size));
659     }
660   }
661 }
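// Example (illustrative): a segment of size 4 gets weight 0.25 for kMean and
// 0.5 for kSqrtN; empty segments are clamped to size 1, so their weight is 1.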
662 
663 template <typename SegmentId, typename Index, typename T>
664 Status LaunchSegmentWeightsKernel(
665     const GPUDevice& d, SegmentId nsegments,
666     SparseSegmentReductionOperation operation,
667     const Index* segment_offsets,  // [nsegments + 1]
668     T* weights) {                  // [nsegments]
669   GpuLaunchConfig config = GetGpuLaunchConfig(
670       nsegments, d, &SegmentWeightsKernel<SegmentId, Index, T>,
671       /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
672   return GpuLaunchKernel(SegmentWeightsKernel<SegmentId, Index, T>,
673                          config.block_count, config.thread_per_block, 0,
674                          d.stream(), nsegments, operation, segment_offsets,
675                          weights);
676 }
677 
678 template <typename ReduceOp, typename T>
679 struct ReduceType {
680   using type = T;
681 };
682 
683 // Sum fp16 values using an fp32 accumulator to avoid numerical issues.
684 template <>
685 struct ReduceType<functor::Sum, Eigen::half> {
686   using type = float;
687 };
688 
689 namespace functor {
690 
691 template <typename T, typename Index, typename InitialValueF,
692           typename EmptySegmentValueF, typename ReductionF>
693 void SegmentReductionFunctor<
694     T, Index, InitialValueF, EmptySegmentValueF,
695     ReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
696                             const Index output_rows,
697                             const TensorShape& segment_ids_shape, bool is_mean,
698                             typename TTypes<Index>::ConstFlat segment_ids,
699                             const Index data_size, const T* data,
700                             typename TTypes<T, 2>::Tensor output) {
701   if (output.size() == 0) {
702     return;
703   }
704 
705   // Launch kernel(s) to compute sorted segment reduction.
706   // Notes:
707   // *) 'input_total_size' is the total number of elements to process.
708   // *) 'segment_ids.shape' is a prefix of data's shape.
709   // *) 'input_outer_dim_size' is the total number of segments to process.
710   const Index input_total_size = data_size;
711   const Index input_outer_dim_size = segment_ids.dimension(0);
712   const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
713   const Index num_segments = output.size() / input_inner_dim_size;
714 
715   bool use_deterministic_kernels =
716 #if defined(PLATFORM_WINDOWS)
717       // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
718       // build error.
719       false;
720 #else
721       UseDeterministicSegmentReductions() ||
722       (OpDeterminismRequired() && !ReduceOpIsAssociative<ReductionF, T>::value);
723 #endif
724 
725   // TODO(benbarsdell): If there are no performance concerns with the new
726   // deterministic kernels, remove this runtime check and only compile the old
727   // non-deterministic kernels on Windows (as a workaround for the build failure
728   // issue).
729   if (!use_deterministic_kernels) {
730     // Set 'output' to initial value.
731     GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
732     const T InitialValue = InitialValueF()();
733     TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
734                                 config.thread_per_block, 0, d.stream(),
735                                 output.size(), output.data(), InitialValue));
736     if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
737       return;
738     }
739 
740     const int OuterDimTileSize = 8;
741 
742     const Index input_outer_dim_num_stripe =
743         Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
744 
745     const Index total_stripe_count =
746         input_inner_dim_size * input_outer_dim_num_stripe;
747 
748     config = GetGpuLaunchConfig(total_stripe_count, d);
749     TF_CHECK_OK(GpuLaunchKernel(
750         SortedSegmentReductionCustomKernel<
751             T, Index, OuterDimTileSize,
752             typename ReduceUpdateOpFor<ReductionF>::nonatomic_op,
753             typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
754         config.block_count, config.thread_per_block, 0, d.stream(),
755         input_outer_dim_size, input_inner_dim_size, output_rows,
756         segment_ids.data(), data, output.data(), total_stripe_count,
757         InitialValue));
758 
759     if (is_mean) {
760       Tensor segment_offsets;
761       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<Index>::value,
762                                              TensorShape({num_segments + 1}),
763                                              &segment_offsets));
764       Index* segment_offsets_ptr = segment_offsets.flat<Index>().data();
765       OP_REQUIRES_OK(ctx, LaunchSegmentOffsetsKernel(
766                               d, input_outer_dim_size, num_segments,
767                               segment_ids.data(), segment_offsets_ptr));
768 
769       OP_REQUIRES_OK(ctx, LaunchSegmentMeanNormalizeKernel(
770                               d, num_segments, input_inner_dim_size,
771                               segment_offsets_ptr, output.data()));
772     }
773   } else {
774     // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
775     // build error.
776 #if !defined(PLATFORM_WINDOWS)
777     using Treduce = typename ReduceType<ReductionF, T>::type;
778     OP_REQUIRES_OK(
779         ctx,
780         SegmentReduceGPU<Treduce>(
781             ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
782             ReductionF(), InitialValueF()(), EmptySegmentValueF()(),
783             /*is_mean=*/is_mean, /*is_sqrtn=*/false, data, segment_ids.data(),
784             /*indices=*/static_cast<const Index*>(nullptr),
785             /*weights=*/static_cast<T*>(nullptr), output.data()));
786 #else
787     // Note: Shouldn't reach here because use_deterministic_kernels is always
788     // false on Windows.
789     OP_REQUIRES(ctx, false,
790                 errors::Unimplemented("Deterministic segment reductions are "
791                                       "not implemented on Windows."));
792 #endif
793   }
794 }
795 
796 template <typename T, typename Index, typename InitialValueF,
797           typename ReductionF>
798 struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
799   void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
800                   typename TTypes<Index>::ConstFlat unsorted_segment_ids,
801                   typename TTypes<T, 2>::ConstTensor data,
802                   typename TTypes<T, 2>::Tensor output) {
803     if (output.size() == 0) {
804       return;
805     }
806 
807     bool use_deterministic_kernels =
808 #if defined(PLATFORM_WINDOWS)
809         // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
810         // build error.
811         false;
812 #else
813         UseDeterministicSegmentReductions() ||
814         (!ReduceOpIsAssociative<ReductionF, T>::value &&
815          OpDeterminismRequired());
816 #endif
817 
818     bool determinism_requirement_met =
819         use_deterministic_kernels ||
820         ReduceOpIsAssociative<ReductionF, T>::value ||
821         !OpDeterminismRequired() ||
822         DisableSegmentReductionOpDeterminismExceptions();
823     OP_REQUIRES(
824         ctx, determinism_requirement_met,
825         errors::Unimplemented(
826             "Deterministic GPU implementation of unsorted segment reduction op"
827             " not available."));
828 
829     // Launch kernel(s) to compute unsorted segment reduction.
830     // Notes:
831     // *) 'data_size' is the total number of elements to process.
832     // *) 'segment_ids.shape' is a prefix of data's shape.
833     // *) 'input_outer_dim_size' is the total number of segments to process.
834     const Index input_outer_dim_size = unsorted_segment_ids.dimension(0);
835     const Index input_inner_dim_size = data.dimension(1);
836     const Index output_outer_dim_size = output.dimension(0);
837     const Index num_segments = output.size() / input_inner_dim_size;
838 
839     // TODO(benbarsdell): If there are no performance concerns with the new
840     // deterministic kernels, remove this runtime check and only compile the old
841     // non-deterministic kernels on Windows (as a workaround for the build
842     // failure issue).
843     if (!use_deterministic_kernels) {
844       // Set 'output' to initial value.
845       GPUDevice d = ctx->template eigen_device<GPUDevice>();
846       GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
847       TF_CHECK_OK(GpuLaunchKernel(
848           SetToValue<T>, config.block_count, config.thread_per_block, 0,
849           d.stream(), output.size(), output.data(), InitialValueF()()));
850       const int64_t data_size = data.size();
851       if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
852         return;
853       }
854       config = GetGpuLaunchConfig(data_size, d);
855       TF_CHECK_OK(GpuLaunchKernel(
856           UnsortedSegmentCustomKernel<
857               T, Index, typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
858           config.block_count, config.thread_per_block, 0, d.stream(),
859           input_outer_dim_size, input_inner_dim_size, output_outer_dim_size,
860           unsorted_segment_ids.data(), data.data(), output.data()));
861     } else {
862       // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
863       // build error.
864 #if !defined(PLATFORM_WINDOWS)
865       // Allocate temporary space and sort segment_ids, then call the sorted
866 // implementation.
867       Tensor segment_ids;
868       OP_REQUIRES_OK(
869           ctx, ctx->allocate_temp(
870                    DataTypeToEnum<Index>::value,
871                    TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
872                    &segment_ids));
873       Index* segment_ids_ptr = segment_ids.flat<Index>().data();
874       Tensor sorted_indices;
875       OP_REQUIRES_OK(
876           ctx, ctx->allocate_temp(
877                    DataTypeToEnum<Index>::value,
878                    TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
879                    &sorted_indices));
880       Index* sorted_indices_ptr = sorted_indices.flat<Index>().data();
881       // Note: We must sort using all bits here because unsorted_segment_ids
882       // may contain negative values.
883       OP_REQUIRES_OK(
884           ctx, GpuRadixSort(ctx, input_outer_dim_size,
885                             /*keys_in=*/unsorted_segment_ids.data(),
886                             /*keys_out=*/segment_ids_ptr,
887                             /*indices_in=*/static_cast<const Index*>(nullptr),
888                             /*indices_out=*/sorted_indices_ptr));
889       using Treduce = typename ReduceType<ReductionF, T>::type;
890       OP_REQUIRES_OK(
891           ctx,
892           SegmentReduceGPU<Treduce>(
893               ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
894               ReductionF(), /*initial_value=*/InitialValueF()(),
895               /*empty_segment_value=*/InitialValueF()(), /*is_mean=*/false,
896               /*is_sqrtn=*/false, /*input=*/data.data(),
897               /*segment_ids=*/segment_ids_ptr, /*indices=*/sorted_indices_ptr,
898               /*weights=*/static_cast<T*>(nullptr), output.data()));
899 #else
900       // Note: Shouldn't reach here because use_deterministic_kernels is always
901       // false on Windows.
902       OP_REQUIRES(
903           ctx, false,
904           errors::Unimplemented("Deterministic unsorted segment reductions are "
905                                 "not implemented on Windows."));
906 #endif
907     }
908   }
909 };
910 
911 template <typename T, typename Index, typename SegmentId>
912 Status SparseSegmentReductionFunctor<T, Index, SegmentId>::operator()(
913     OpKernelContext* context, bool is_mean, bool is_sqrtn, T default_value,
914     typename TTypes<T, 2>::ConstTensor input,
915     typename TTypes<Index>::ConstVec indices,
916     typename TTypes<SegmentId>::ConstVec segment_ids,
917     typename TTypes<T, 2>::Tensor output) {
918   using ReduceOp = functor::Sum;
919   using Treduce = typename ReduceType<ReduceOp, T>::type;
920   Index nouter = segment_ids.size();
921   Index ninner = input.dimension(1);
922   SegmentId nsegments = output.dimension(0);
923   return SegmentReduceGPU<Treduce>(
924       context, /*nouter=*/nouter, /*ninner=*/ninner,
925       /*nsegments=*/nsegments, /*reduce_op=*/ReduceOp(),
926       /*initial_value=*/T(0),
927       /*empty_segment_value=*/default_value,
928       /*is_mean=*/is_mean, /*is_sqrtn=*/is_sqrtn,
929       /*input=*/input.data(), /*segment_ids=*/segment_ids.data(),
930       /*indices=*/indices.data(), /*weights=*/static_cast<T*>(nullptr),
931       /*output=*/output.data());
932 }
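// Example (illustrative): a SparseSegmentMean-style call with
// indices = [0, 3, 3] and segment_ids = [0, 0, 2] averages input rows 0 and 3
// into output row 0, leaves output row 1 at default_value (an empty segment),
// and copies input row 3 into output row 2.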
933 
934 template <typename T, typename Index, typename SegmentId>
935 struct SparseSegmentGradFunctor<GPUDevice, T, Index, SegmentId> {
936   void operator()(OpKernelContext* context,
937                   SparseSegmentReductionOperation operation,
938                   typename TTypes<T>::ConstMatrix input_flat,
939                   typename TTypes<Index>::ConstVec indices_vec,
940                   typename TTypes<SegmentId>::ConstVec segment_vec,
941                   typename TTypes<T>::Matrix output_flat) {
942     const GPUDevice& device = context->eigen_gpu_device();
943 
944     const SegmentId nsegments = input_flat.dimension(0);
945     const Index ninner = input_flat.dimension(1);
946     const Index nouter = indices_vec.dimension(0);
947     const Index noutput = output_flat.dimension(0);
948 
949     // Allocate and compute segment weights (for Mean/SqrtN operations only).
950     Tensor weights;
951     T* weights_ptr = nullptr;
952     if (operation != SparseSegmentReductionOperation::kSum) {
953       OP_REQUIRES_OK(
954           context, context->allocate_temp(DataTypeToEnum<T>::value,
955                                           TensorShape({nsegments}), &weights));
956       weights_ptr = weights.flat<T>().data();
957       // Allocate and compute segment_offsets.
958       Tensor segment_offsets;
959       OP_REQUIRES_OK(context,
960                      context->allocate_temp(DataTypeToEnum<Index>::value,
961                                             TensorShape({nsegments + 1}),
962                                             &segment_offsets));
963       Index* segment_offsets_ptr = segment_offsets.flat<Index>().data();
964       OP_REQUIRES_OK(context, LaunchSegmentOffsetsKernel(
965                                   device, nouter, nsegments, segment_vec.data(),
966                                   segment_offsets_ptr));
967       // Compute the weights based on the segment sizes using segment_offsets.
968       OP_REQUIRES_OK(context, LaunchSegmentWeightsKernel(
969                                   device, nsegments, operation,
970                                   segment_offsets_ptr, weights_ptr));
971     }
972 
973     const Index* sorted_indices_ptr = indices_vec.data();
974     const SegmentId* sorted_segment_ptr = segment_vec.data();
975     Tensor tmp_sorted_indices;
976     Tensor tmp_sorted_segment;
977     if (noutput > 1) {
978       // Sort indices and permute segments.
979       OP_REQUIRES_OK(context, context->allocate_temp(
980                                   DataTypeToEnum<Index>::value,
981                                   TensorShape({nouter}), &tmp_sorted_indices));
982       Index* tmp_sorted_indices_ptr = tmp_sorted_indices.flat<Index>().data();
983       OP_REQUIRES_OK(context, context->allocate_temp(
984                                   DataTypeToEnum<SegmentId>::value,
985                                   TensorShape({nouter}), &tmp_sorted_segment));
986       SegmentId* tmp_sorted_segment_ptr =
987           tmp_sorted_segment.flat<SegmentId>().data();
988       OP_REQUIRES_OK(context,
989                      GpuRadixSort(context, nouter,
990                                   /*keys_in=*/indices_vec.data(),
991                                   /*keys_out=*/tmp_sorted_indices_ptr,
992                                   /*indices_in=*/segment_vec.data(),
993                                   /*indices_out=*/tmp_sorted_segment_ptr,
994                                   /*num_bits=*/Log2Ceiling64(noutput)));
995       sorted_indices_ptr = tmp_sorted_indices_ptr;
996       sorted_segment_ptr = tmp_sorted_segment_ptr;
997     }
998 
999     // Compute the gradient using a weighted SegmentReduceGPU with the segment
1000     // IDs and indices swapped.
1001     using ReduceOp = gpuprim::Sum;
1002     using Treduce = typename ReduceType<ReduceOp, T>::type;
1003     OP_REQUIRES_OK(
1004         context,
1005         SegmentReduceGPU<Treduce>(
1006             context, /*nouter=*/static_cast<SegmentId>(nouter),
1007             /*ninner=*/static_cast<SegmentId>(ninner),
1008             /*nsegments=*/noutput,
1009             /*reduce_op=*/ReduceOp(),
1010             /*initial_value=*/T(0),
1011             /*empty_segment_value=*/T(0),
1012             /*is_mean=*/false, /*is_sqrtn=*/false,
1013             /*input=*/input_flat.data(), /*segment_ids=*/sorted_indices_ptr,
1014             /*indices=*/sorted_segment_ptr, /*weights=*/weights_ptr,
1015             /*output=*/output_flat.data()));
1016   }
1017 };
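// Example (illustrative): if the forward op used indices = [1, 4] and
// segment_ids = [0, 0], the gradient above scatters grad row 0 into output
// rows 1 and 4 (scaled by segment 0's Mean/SqrtN weight, if any); swapping the
// roles of segment_ids and indices turns this scatter into a segmented reduce.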
1018 
1019 }  // namespace functor
1020 }  // namespace tensorflow
1021 
1022 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1023