1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17
18 #define EIGEN_USE_GPU
19
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/kernels/gpu_prim.h"
22 #include "tensorflow/core/kernels/gpu_prim_helpers.h"
23 #include "tensorflow/core/kernels/segment_reduction_ops.h"
24 #include "tensorflow/core/lib/core/bits.h"
25 #include "tensorflow/core/util/determinism.h"
26 #include "tensorflow/core/util/env_var.h"
27 #include "tensorflow/core/util/gpu_device_functions.h"
28 #include "tensorflow/core/util/gpu_kernel_helper.h"
29 #include "tensorflow/core/util/permutation_input_iterator.h"
30
31 namespace tensorflow {
32
33 using GPUDevice = Eigen::GpuDevice;
34
35 // Atomic and non-atomic reduction functors for the GPU.
36 #define DEFINE_REDUCE_UPDATE_OP_GPU(name, func) \
37 struct name##OpGpu { \
38 template <typename T> \
39 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, \
40 const T& value) { \
41 func; \
42 } \
43 };
44 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicSum, GpuAtomicAdd(dest, value))
45 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicProd, GpuAtomicMul(dest, value))
46 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMax, GpuAtomicMax(dest, value))
47 DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMin, GpuAtomicMin(dest, value))
48 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicSum, *dest += value)
49 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicProd, *dest *= value)
50 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMax, *dest = max(*dest, value))
51 DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMin, *dest = min(*dest, value))
52 #undef DEFINE_REDUCE_UPDATE_OP_GPU
53
54 template <typename ReduceOp>
55 struct ReduceUpdateOpFor {};
56
57 #define DEFINE_REDUCE_UPDATE_OP_FOR(reduce_op, atomic, nonatomic) \
58 template <> \
59 struct ReduceUpdateOpFor<reduce_op> { \
60 using atomic_op = atomic; \
61 using nonatomic_op = nonatomic; \
62 };
63 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Sum, AtomicSumOpGpu, NonAtomicSumOpGpu)
64 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Prod, AtomicProdOpGpu, NonAtomicProdOpGpu)
65 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Max, AtomicMaxOpGpu, NonAtomicMaxOpGpu)
66 DEFINE_REDUCE_UPDATE_OP_FOR(functor::Min, AtomicMinOpGpu, NonAtomicMinOpGpu)
67 #undef DEFINE_REDUCE_UPDATE_OP_FOR
68
69 // The SortedSegmentReductionCustomKernel reduces input data just as the
70 // UnsortedSegmentCustomKernel does, except that the input data is
71 // partitioned along the outer reduction dimension. This is because
72 // consecutive rows (elements in a row share the same outer dimension
73 // index) in the flattened 2D input data are likely to belong to the same
74 // segment in a sorted segment reduction.
75 // This partitioning strategy therefore has two advantages over the
76 // unsorted kernel:
77 // 1. Each thread reduces across multiple rows before writing its result
78 //    to global memory, so reduction results are written to global memory
79 //    less often.
80 // 2. Because the segment ids are non-decreasing, we often know that the
81 //    current thread is the only contributor to an output element. In
82 //    such cases, we do not need atomic operations to write the result to
83 //    global memory.
84 // In the flattened view of the input data (with only the outer and inner
85 // dimensions), every thread processes a strip of input of size
86 // OuterDimTileSize x 1. This strip runs across multiple rows of the
87 // input data, and all elements of the reduction share one inner
88 // dimension index.
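// For example, with OuterDimTileSize = 4, inner_dim_size = 1 and
// segment_ids = [0, 0, 1, 1, 2, 2, 2, 2], thread 0 processes rows 0..3 and
// thread 1 processes rows 4..7. Thread 0 writes its partial result for
// segment 0 atomically (the kernel cannot rule out contributions to that
// segment from a preceding strip) and its final partial result for segment 1
// atomically; thread 1 sees only segment 2, so it just issues the mandatory
// atomic write of its final result.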
89 template <typename T, typename Index, int OuterDimTileSize, typename ReductionF,
90 typename AtomicReductionF>
91 __global__ void SortedSegmentReductionCustomKernel(
92 const Index input_outer_dim_size, const Index inner_dim_size,
93 const Index output_outer_dim_size, const Index* __restrict__ segment_ids,
94 const T* __restrict__ input, T* __restrict__ output,
95 const Index total_stripe_count, const T initial_value) {
96 for (int stripe_index : GpuGridRangeX(total_stripe_count)) {
97 const Index segment_offset = stripe_index % inner_dim_size;
98 const Index input_outer_dim_index_base =
99 stripe_index / inner_dim_size * Index(OuterDimTileSize);
100
101 T reduce_res = initial_value;
102 Index first_segment_id = segment_ids[input_outer_dim_index_base];
103 Index last_output_segment_id = output_outer_dim_size;
104
105 const Index actual_stripe_height =
106 min(Index(OuterDimTileSize),
107 input_outer_dim_size - input_outer_dim_index_base);
108 for (Index j = 0; j < actual_stripe_height; j++) {
109 Index current_output_segment_id =
110 segment_ids[input_outer_dim_index_base + j];
111 // Decide whether to write result to global memory. Result is only written
112 // to global memory if we move to another segment. Otherwise we can keep
113 // accumulating locally.
114 if (current_output_segment_id > last_output_segment_id) {
115 const Index output_index =
116 last_output_segment_id * inner_dim_size + segment_offset;
117 // Decide whether to write result to global memory using atomic
118 // operations.
119 if (last_output_segment_id == first_segment_id) {
120 AtomicReductionF()(output + output_index, reduce_res);
121 } else {
122 ReductionF()(output + output_index, reduce_res);
123 }
124 reduce_res = initial_value;
125 }
126 ReductionF()(
127 &reduce_res,
128 ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
129 segment_offset));
130 last_output_segment_id = current_output_segment_id;
131 }
132 // For the last result in a strip, always write using atomic operations
133 // due to possible race conditions with threads computing
134 // the following strip.
135 const Index output_index =
136 last_output_segment_id * inner_dim_size + segment_offset;
137 AtomicReductionF()(output + output_index, reduce_res);
138 }
139 }
140
141 template <typename SegmentId, typename Index, typename T>
142 __global__ void SegmentMeanNormalizeKernel(
143 SegmentId nsegments, Index ninner,
144 const Index* __restrict__ segment_offsets, // [nsegments + 1]
145 T* __restrict__ output) { // [nsegments, ninner]
146 for (SegmentId seg : GpuGridRangeY(nsegments)) {
147     Index segment_size = segment_offsets[seg + 1] - segment_offsets[seg];
148 segment_size = max(segment_size, Index(1)); // Avoid division by zero
149 T inv_norm = T(1) / static_cast<T>(segment_size);
150 for (Index i : GpuGridRangeX(ninner)) {
151 output[seg * ninner + i] *= inv_norm;
152 }
153 }
154 }
155
156 template <typename SegmentId, typename Index, typename T>
157 Status LaunchSegmentMeanNormalizeKernel(
158 const GPUDevice& d, SegmentId nsegments, Index ninner,
159 const Index* __restrict__ segment_offsets, // [nsegments + 1]
160 T* __restrict__ output) { // [nsegments, ninner]
161 Gpu2DLaunchConfig config = GetGpu2DLaunchConfig(
162 ninner, nsegments, d, SegmentMeanNormalizeKernel<SegmentId, Index, T>,
163 /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
164 return GpuLaunchKernel(SegmentMeanNormalizeKernel<SegmentId, Index, T>,
165 config.block_count, config.thread_per_block, 0,
166 d.stream(), nsegments, ninner, segment_offsets,
167 output);
168 }
169
170 // UnsortedSegmentCustomKernel processes 'input_total_size' elements.
171 // Each element is mapped from input to output by a combination of its
172 // 'segment_ids' mapping and 'inner_dim_size'.
173 template <typename T, typename Index, typename KernelReductionFunctor>
174 __global__ void UnsortedSegmentCustomKernel(
175 const int64_t input_outer_dim_size, const int64_t inner_dim_size,
176 const int64_t output_outer_dim_size, const Index* __restrict__ segment_ids,
177 const T* __restrict__ input, T* __restrict__ output) {
178 const int64_t input_total_size = input_outer_dim_size * inner_dim_size;
179 for (int64_t input_index : GpuGridRangeX(input_total_size)) {
180 const int64_t input_segment_index = input_index / inner_dim_size;
181 const int64_t segment_offset = input_index % inner_dim_size;
182 const Index output_segment_index = segment_ids[input_segment_index];
183 if (output_segment_index < 0 ||
184 output_segment_index >= output_outer_dim_size) {
185 continue;
186 }
187 const int64_t output_index =
188 output_segment_index * inner_dim_size + segment_offset;
189 KernelReductionFunctor()(output + output_index, ldg(input + input_index));
190 }
191 }
192
193 template <typename Tindex, typename Tsegmentids>
194 __global__ void SegmentOffsetsKernel(
195 Tindex size, Tsegmentids nsegments,
196 const Tsegmentids* __restrict__ segment_ids, // [size]
197 Tindex* __restrict__ segment_offsets) { // [nsegments + 1]
198 GPU_1D_KERNEL_LOOP(i, size + 1) {
199 // IDs are clipped to [-1, nsegments] so that out-of-bounds IDs are ignored.
200 // Note that we can't report invalid IDs from the GPU without incurring
201 // additional overhead.
202 auto clip = [&](Tsegmentids id) {
203 return min(max(Tsegmentids(-1), id), nsegments);
204 };
205 const Tsegmentids cur_id = (i < size) ? clip(segment_ids[i]) : nsegments;
206 const Tsegmentids prev_id =
207 (i == 0) ? Tsegmentids(-1) : clip(segment_ids[i - 1]);
208 // At segment boundaries, write the offset for this ID and any missing IDs
209 // since the previous one.
210 for (Tsegmentids id = prev_id + 1; id <= cur_id; ++id) {
211 segment_offsets[id] = i;
212 }
213 }
214 }
215
216 // Finds the start offset of each segment in the given sorted segment_ids
217 // vector. Missing IDs are given the same offset as the next ID so that they
218 // represent empty ranges. Invalid IDs (those that are outside the range
219 // [0, nsegments)) are ignored. The value at segment_offsets[0] is set to the
220 // start index of the first valid ID (e.g., 0 if all IDs are valid), and the
221 // value at segment_offsets[nsegments] is set to the end index of the last valid
222 // ID (e.g., size if all IDs are valid).
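// For example, with size = 5, nsegments = 4 and segment_ids = [0, 0, 2, 2, 3],
// the result is segment_offsets = [0, 2, 2, 4, 5]: the missing ID 1 gets the
// same offset as ID 2 (an empty range), and segment_offsets[nsegments] is the
// total number of elements, 5.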
223 template <typename Tindex, typename Tsegmentids>
224 Status LaunchSegmentOffsetsKernel(const GPUDevice& d, Tindex size,
225 Tsegmentids nsegments,
226 const Tsegmentids* segment_ids, // [size]
227 Tindex* segment_offsets) { // [nsegments + 1]
228 GpuLaunchConfig config = GetGpuLaunchConfig(
229 size + 1, d, &SegmentOffsetsKernel<Tindex, Tsegmentids>,
230 /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
231 return GpuLaunchKernel(SegmentOffsetsKernel<Tindex, Tsegmentids>,
232 config.block_count, config.thread_per_block, 0,
233 d.stream(), size, nsegments, segment_ids,
234 segment_offsets);
235 }
236
237 template <typename T>
238 struct RealTypeIfComplex {
239 using type = T;
240 };
241
242 template <typename Real>
243 struct RealTypeIfComplex<std::complex<Real>> {
244 using type = Real;
245 };
246
247 // Reduces along columns of the thread block, returning the result in the first
248 // row of threads.
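// For example, with blockDim.y = 8 the loop below runs with k = 4, 2, 1,
// halving the number of participating rows at each step until row 0 holds the
// block-wide result for its column.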
249 template <typename T, typename ReduceOp>
250 __device__ T ReduceBlockAlongCols(ReduceOp reduce_op, const T& value,
251 bool is_valid) {
252 GPU_DYNAMIC_SHARED_MEM_DECL(/*ALIGN=*/16, char, shared_memory_raw);
253 T* const shared_partial_reduction =
254 reinterpret_cast<T*>(shared_memory_raw); // [blockDim.y, blockDim.x]
255 const int x = threadIdx.x;
256 const int y = threadIdx.y;
257 T reduced = value;
258 // Reduce over the y dimension of the block.
259 for (unsigned k = blockDim.y / 2; k > 0; k /= 2) {
260 if (is_valid && y < 2 * k) {
261 shared_partial_reduction[y * blockDim.x + x] = reduced;
262 }
263 __syncthreads();
264 if (is_valid && y < k) {
265 reduced = reduce_op(reduced,
266 shared_partial_reduction[(y + k) * blockDim.x + x]);
267 }
268 __syncthreads();
269 }
270 return reduced;
271 }
272
273 // This kernel uses a 2D thread decomposition. The x dimension maps to the inner
274 // dimension of the input/output. The y grid dimension maps to segments, and y
275 // threads within a block cooperate to reduce over the block's segment.
276 // Note that Tinit is needed because Tvec and Treducevec may be vector types,
277 // but Tinit is always a scalar type.
278 // Note that the first dimension of input_vec is nouter if indices is not
279 // provided; otherwise it is indexed indirectly via indices and can have any
280 // size (as long as it spans at least the maximum value in indices). This also
281 // applies to the weights vector.
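// For example, with blockDim = (32, 8), each thread block covers 32
// consecutive elements of the inner dimension; the 8 rows of threads stride
// together through the rows of one segment, ReduceBlockAlongCols folds their
// partial results into row 0, and row 0 accumulates them and writes the final
// output row.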
282 template <typename Treducevec, typename Tvec, typename Tindex,
283 typename Tsegmentids, typename ReduceOp, typename Tinit>
284 __global__ void SegmentReduceVectorKernel(
285 Tindex nouter, Tindex ninner_vec, Tsegmentids nsegments, ReduceOp reduce_op,
286 Tinit initial_value, Tinit empty_segment_value, bool is_mean, bool is_sqrtn,
287 const Tvec* __restrict__ input_vec, // [nouter or any, ninner_vec]
288 const Tindex* __restrict__ segment_offsets, // [nsegments + 1]
289 const Tindex* __restrict__ indices, // [nouter] (optional)
290 const Tinit* __restrict__ weights, // [nouter or any] (optional)
291 Tvec* __restrict__ output_vec) { // [nsegments, ninner_vec]
292 const int num_blocks_x = (ninner_vec - 1) / blockDim.x + 1;
293 // Grid-stride loop over inner dimension blocks.
294 for (Tindex blk_x = blockIdx.x; blk_x < num_blocks_x; blk_x += gridDim.x) {
295 const Tindex x = threadIdx.x + blk_x * blockDim.x;
296 const Tindex y = threadIdx.y;
297 const bool x_ok = x < ninner_vec;
298 // Grid-stride loop over segment blocks, each processing one segment.
299 for (Tsegmentids seg = blockIdx.y; seg < nsegments; seg += gridDim.y) {
300 // Load segment range.
301 const Tindex begin = segment_offsets[seg];
302 const Tindex end = segment_offsets[seg + 1];
303 // Reduce over the segment.
304 Treducevec result = Treducevec(initial_value);
305 // Loop over the segment, reducing blockDim.y elements at a time.
306 for (Tindex y_offset = begin; y_offset < end; y_offset += blockDim.y) {
307 const bool y_ok = (y_offset + y) < end;
308 // Perform indirect lookup if required.
309 const Tindex y_idx =
310 indices && y_ok ? indices[y_offset + y] : y_offset + y;
311 const int64_t input_idx = static_cast<int64_t>(y_idx) * ninner_vec + x;
312 // Load the input row from global mem.
313 Treducevec block_result =
314 x_ok && y_ok ? input_vec[input_idx] : Tvec(initial_value);
315 // Apply weights if provided.
316 if (weights && y_ok) block_result *= Tvec(weights[y_idx]);
317 // Reduce along the columns of the block, returning result in first row.
318 block_result = ReduceBlockAlongCols(reduce_op, block_result, x_ok);
319 if (y == 0 && x_ok) {
320 result = reduce_op(result, block_result);
321 }
322 }
323 // First row of the block stores the result to global memory.
324 if (y == 0 && x_ok) {
325 if (begin == end) {
326 // Empty segment.
327 result = Treducevec(empty_segment_value);
328 } else {
329 typename RealTypeIfComplex<Tinit>::type total_weight(end - begin);
330 // Normalize the results if necessary.
331 if (is_mean) {
332 result /= Treducevec(total_weight);
333 } else if (is_sqrtn) {
334 result /= Treducevec(sqrt(total_weight));
335 }
336 }
337 // Cast from Treducevec to Tvec.
338 const int64_t output_idx = static_cast<int64_t>(seg) * ninner_vec + x;
339 output_vec[output_idx] = static_cast<Tvec>(result);
340 }
341 }
342 }
343 }
344
345 // Reduces input matrix within segments over the outer dimension. Empty segments
346 // always output empty_segment_value.
347 // If is_mean or is_sqrtn is true, the results are normalized using the
348 // corresponding function.
349 // If indices is not nullptr, input rows are accessed indirectly as
350 // input[indices[i]], instead of input[i].
351 // Note: Treducevec is to allow reducing in higher precision than Tvec.
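// (For example, Treducevec may be AlignedVector<float, vec_size> while Tvec is
// AlignedVector<Eigen::half, vec_size>; see ReduceType below.)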
352 template <typename Treducevec, typename Tvec, typename Tindex,
353 typename Tsegmentids, typename ReduceOp, typename Tinit>
354 Status LaunchSegmentReduceVectorKernel(
355 const GPUDevice& d, Tindex nouter, Tindex ninner_vec, Tsegmentids nsegments,
356 ReduceOp reduce_op, Tinit initial_value, Tinit empty_segment_value,
357 bool is_mean, bool is_sqrtn,
358 const Tvec* input_vec, // [nouter or any, ninner_vec]
359 const Tindex* segment_offsets, // [nsegments + 1]
360 const Tindex* indices, // [nouter] (optional)
361 const Tinit* weights, // [nouter or any] (optional)
362 Tvec* output_vec) { // [nsegments, ninner_vec]
363 static constexpr const int kMaxGridX = (1u << 31) - 1;
364 static constexpr const int kMaxGridY = (1u << 16) - 1;
365 const int max_block_size = 1024; // Can be tuned for perf (<= 1024)
366 const int min_block_size = 64; // Can be tuned for perf
367 const Tindex ninner_pow2 = Tindex(1) << Log2Ceiling64(ninner_vec);
368 // This is a heuristic that first allocates threads in the block to the inner
369 // (x) dimension (which is most efficient) and then allocates the rest to the
370 // reduction (y) dimension (which is less efficient but increases
371 // parallelism).
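  // For example, with ninner_vec = 20 and an average of 16 input rows per
  // segment, the code below picks ninner_pow2 = 32, block_x = 32, and
  // block.y = min(divup(64, 32), 16) = 2, i.e. a 32x2 thread block.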
372 int block_x = std::min(ninner_pow2, static_cast<Tindex>(max_block_size));
373 const Tindex avg_reduce_size =
374 Eigen::divup(nouter, static_cast<Tindex>(nsegments));
375 const Tindex avg_reduce_size_pow2 = Tindex(1)
376 << Log2Ceiling64(avg_reduce_size);
377 dim3 block(
378 block_x,
379 std::min(static_cast<Tindex>(Eigen::divup(min_block_size, block_x)),
380 avg_reduce_size_pow2));
381 dim3 grid(std::min(Eigen::divup(ninner_vec, static_cast<Tindex>(block.x)),
382 static_cast<Tindex>(kMaxGridX)),
383 std::min(nsegments, static_cast<Tsegmentids>(kMaxGridY)));
384 unsigned shared_memory_bytes = block.x * block.y * sizeof(Treducevec);
385 return GpuLaunchKernel(
386 SegmentReduceVectorKernel<Treducevec, Tvec, Tindex, Tsegmentids, ReduceOp,
387 Tinit>,
388 grid, block, shared_memory_bytes, d.stream(), nouter, ninner_vec,
389 nsegments, reduce_op, initial_value, empty_segment_value, is_mean,
390 is_sqrtn, input_vec, segment_offsets, indices, weights, output_vec);
391 }
392
393 template <typename Tvec, typename Treducevec, typename Tindex,
394 typename Tsegmentids, typename Tinit>
395 __global__ void SegmentReduceEpilogueKernel(
396 Tsegmentids nsegments, Tinit empty_segment_value, bool is_mean,
397 bool is_sqrtn,
398 const Treducevec* __restrict__ output_raw, // [nsegments]
399 const Tindex* __restrict__ segment_offsets, // [nsegments + 1]
400 Tvec* __restrict__ output) { // [nsegments]
401 GPU_1D_KERNEL_LOOP(seg, nsegments) {
402 Tindex segment_size = segment_offsets[seg + 1] - segment_offsets[seg];
403 Treducevec val = output_raw[seg];
404 if (segment_size == 0) {
405 // Empty segment.
406 val = Treducevec(empty_segment_value);
407 } else if (is_mean) {
408 val /= Treducevec(segment_size);
409 } else if (is_sqrtn) {
410 val /= Treducevec(
411 sqrt(typename RealTypeIfComplex<Tinit>::type(segment_size)));
412 }
413 // Cast from Treducevec to Tvec.
414 output[seg] = static_cast<Tvec>(val);
415 }
416 }
417
418 // Normalizes output_raw based on segment size and casts from Treducevec to
419 // Tvec. If Tvec == Treducevec, this is safe to call with output_raw == output.
420 // Note that Treducevec is the type that was used for the reduction, which may
421 // be a higher-precision type than the output type Tvec (e.g., float vs. half).
422 template <typename Tvec, typename Treducevec, typename Tindex,
423 typename Tsegmentids, typename Tinit>
424 Status LaunchSegmentReduceEpilogueKernel(
425 const GPUDevice& d, Tsegmentids nsegments, Tinit empty_segment_value,
426 bool is_mean, bool is_sqrtn,
427 const Treducevec* output_raw, // [nsegments]
428 const Tindex* segment_offsets, // [nsegments + 1]
429 Tvec* output) { // [nsegments]
430 GpuLaunchConfig config = GetGpuLaunchConfig(
431 nsegments, d,
432 &SegmentReduceEpilogueKernel<Tvec, Treducevec, Tindex, Tsegmentids,
433 Tinit>,
434 /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
435 return GpuLaunchKernel(
436 SegmentReduceEpilogueKernel<Tvec, Treducevec, Tindex, Tsegmentids, Tinit>,
437 config.block_count, config.thread_per_block, 0, d.stream(), nsegments,
438 empty_segment_value, is_mean, is_sqrtn, output_raw, segment_offsets,
439 output);
440 }
441
442 template <typename Tto>
443 struct CastFunctor {
444 template <typename T>
445 __device__ Tto operator()(const T& val) const {
446 return static_cast<Tto>(val);
447 }
448 };
449
450 template <typename Treducevec, typename Tvec, typename Tindex, typename Tinit>
451 struct LookupAndScaleAndCastInputsFunctor {
452 LookupAndScaleAndCastInputsFunctor(const Tvec* input_vec,
453 const Tindex* indices,
454 const Tinit* weights)
455 : input_vec_(input_vec), indices_(indices), weights_(weights) {}
456
457 __device__ Treducevec operator()(Tindex idx) const {
458 if (indices_) idx = indices_[idx];
459 Treducevec result = static_cast<Treducevec>(input_vec_[idx]);
460 if (weights_) result *= Tvec(weights_[idx]);
461 return result;
462 }
463
464 private:
465 const Tvec* __restrict__ input_vec_;
466 const Tindex* __restrict__ indices_;
467 const Tinit* __restrict__ weights_;
468 };
469
470 template <typename Treducevec, typename Tvec, typename Tindex, typename Tinit>
471 struct CastIterator {
472 using FunctorTy =
473 LookupAndScaleAndCastInputsFunctor<Treducevec, Tvec, Tindex, Tinit>;
474 using InputIteratorTy = gpuprim::CountingInputIterator<Tindex>;
475 using IteratorTy =
476 gpuprim::TransformInputIterator<Treducevec, FunctorTy, InputIteratorTy>;
477 };
478
479 template <typename Treducevec, typename Tvec, typename Tindex, typename Tinit>
480 typename CastIterator<Treducevec, Tvec, Tindex, Tinit>::IteratorTy
481 MakeLookupAndScaleAndCastInputsIterator(const Tvec* input_vec,
482 const Tindex* indices,
483 const Tinit* weights) {
484 using CastIteratorTy = CastIterator<Treducevec, Tvec, Tindex, Tinit>;
485 typename CastIteratorTy::FunctorTy functor(input_vec, indices, weights);
486 return typename CastIteratorTy::IteratorTy(
487 typename CastIteratorTy::InputIteratorTy(Tindex(0)), functor);
488 }
489
490 template <typename Treducevec, typename Tvec, typename Tindex,
491 typename Tsegmentids, typename ReduceOp, typename Tinit>
492 Status SegmentReduceGPUImplNoInnerDim(
493 OpKernelContext* ctx, Tindex nouter, Tsegmentids nsegments,
494 ReduceOp reduce_op, Tinit initial_value, Tinit empty_segment_value,
495 bool is_mean, bool is_sqrtn,
496 const Tvec* input_vec, // [nouter or any]
497 const Tindex* segment_offsets, // [nsegments + 1]
498 const Tindex* indices, // [nouter] (optional)
499 const Tinit* weights, // [nouter or any] (optional)
500 Tvec* output_vec) { // [nsegments]
501 // Here we use gpuprim::DeviceSegmentedReduce (which is optimized for this
502 // shape) and add the additional required functionality using fancy input
503 // iterators and an epilogue kernel.
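  // Concretely, for element i the transform iterator below yields
  // Treducevec(input_vec[indices[i]]) * weights[indices[i]] (the lookup and
  // weighting are skipped when the corresponding pointer is null), so the
  // indirection, scaling and up-cast all happen on the fly inside the reduce.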
504
505 // Note: This reinterpret cast is only needed to avoid compilation error
506 // when Tvec != Treducevec; the result is only used if Tvec == Treducevec.
507 Treducevec* output_raw_ptr = reinterpret_cast<Treducevec*>(output_vec);
508 Tensor output_raw;
509 bool need_temp_output = !std::is_same<Tvec, Treducevec>::value;
510 if (need_temp_output) {
511 // Note: We must allocate and reinterpret as bytes because Treducevec may
512 // be a vector type and they are not supported as Tensor dtypes.
513 TF_RETURN_IF_ERROR(ctx->allocate_temp(
514 DT_INT8,
515 TensorShape({static_cast<int64_t>(nsegments * sizeof(Treducevec))}),
516 &output_raw));
517 output_raw_ptr =
518 reinterpret_cast<Treducevec*>(output_raw.flat<int8>().data());
519 }
520 auto input_iter = MakeLookupAndScaleAndCastInputsIterator<Treducevec>(
521 input_vec, indices, weights);
522 TF_RETURN_IF_ERROR(GpuSegmentedReduce(ctx, nsegments, reduce_op,
523 Treducevec(initial_value), input_iter,
524 segment_offsets, output_raw_ptr));
525 bool need_epilogue = !std::is_same<Tvec, Treducevec>::value ||
526 initial_value != empty_segment_value || is_mean ||
527 is_sqrtn;
528 if (need_epilogue) {
529 const GPUDevice& device = ctx->eigen_gpu_device();
530 // Normalize based on the segment size and cast results back to T.
531 TF_RETURN_IF_ERROR(LaunchSegmentReduceEpilogueKernel(
532 device, nsegments, empty_segment_value, is_mean, is_sqrtn,
533 output_raw_ptr, segment_offsets, output_vec));
534 }
535 return OkStatus();
536 }
537
538 template <typename Treducevec, typename Tvec, typename Tindex,
539 typename Tsegmentids, typename ReduceOp, typename Tinit>
540 Status SegmentReduceGPUImpl(
541 OpKernelContext* ctx, Tindex nouter, Tindex ninner_vec,
542 Tsegmentids nsegments, ReduceOp reduce_op, Tinit initial_value,
543 Tinit empty_segment_value, bool is_mean, bool is_sqrtn,
544 const Tvec* input_vec, // [nouter or any, ninner_vec]
545 const Tsegmentids* segment_ids, // [nouter]
546 const Tindex* indices, // [nouter] (optional)
547 const Tinit* weights, // [nouter or any] (optional)
548 Tvec* output_vec) { // [nsegments, ninner_vec]
549 const GPUDevice& device = ctx->eigen_gpu_device();
550
551 if (nouter == 0) {
552 // Just set output to empty_segment_value.
553 GPUDevice d = ctx->template eigen_device<GPUDevice>();
554 int64_t output_size = static_cast<int64_t>(nsegments) * ninner_vec;
555 GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d);
556 return GpuLaunchKernel(SetToValue<Tvec, Tinit>, config.block_count,
557 config.thread_per_block, 0, d.stream(), output_size,
558 output_vec, empty_segment_value);
559 }
560
561 // Allocate and compute segment_offsets.
562 Tensor segment_offsets;
563 TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<Tindex>::value,
564 TensorShape({nsegments + 1}),
565 &segment_offsets));
566 Tindex* segment_offsets_ptr = segment_offsets.flat<Tindex>().data();
567 TF_RETURN_IF_ERROR(LaunchSegmentOffsetsKernel(
568 device, nouter, nsegments, segment_ids, segment_offsets_ptr));
569
570 const Tindex avg_reduce_size =
571 Eigen::divup(nouter, static_cast<Tindex>(nsegments));
572 // This avg_reduce_size threshold is a performance heuristic.
573 if (ninner_vec == 1 && avg_reduce_size >= 512) {
574 // Here we use a gpuprim-based implementation that doesn't support an
575 // inner dimension but can be significantly faster for large reductions.
576 return SegmentReduceGPUImplNoInnerDim<Treducevec>(
577 ctx, nouter, nsegments, reduce_op, initial_value, empty_segment_value,
578 is_mean, is_sqrtn, input_vec, segment_offsets_ptr, indices, weights,
579 output_vec);
580 }
581 // Here we use a custom kernel that is optimized for ninner_vec >= ~64 and
582 // gives decent performance for smaller cases. It also handles indices,
583 // casting to/from Treducevec, and normalizing the output.
584 return LaunchSegmentReduceVectorKernel<Treducevec>(
585 device, nouter, ninner_vec, nsegments, reduce_op, initial_value,
586 empty_segment_value, is_mean, is_sqrtn, input_vec, segment_offsets_ptr,
587 indices, weights, output_vec);
588 }
589
590 template <typename Treduce>
591 struct SegmentReduceGPUVectorized {
592 template <int vec_size>
593 struct Impl {
594 template <typename T, typename Tindex, typename Tsegmentids,
595 typename ReduceOp>
596 Status operator()(OpKernelContext* ctx, Tindex nouter, Tindex ninner,
597 Tsegmentids nsegments, ReduceOp reduce_op,
598 T initial_value, T empty_segment_value, bool is_mean,
599 bool is_sqrtn, const T* input,
600 const Tsegmentids* segment_ids, const Tindex* indices,
601 const T* weights, T* output) {
602 DCHECK_EQ(ninner % vec_size, 0);
603 DCHECK_EQ(reinterpret_cast<std::uintptr_t>(input) % vec_size, 0);
604 DCHECK_EQ(reinterpret_cast<std::uintptr_t>(output) % vec_size, 0);
605 Tindex ninner_vec = ninner / vec_size;
606 using Tvec = AlignedVector<T, vec_size>;
607 using Treducevec = AlignedVector<Treduce, vec_size>;
608 const Tvec* input_vec = reinterpret_cast<const Tvec*>(input);
609 Tvec* output_vec = reinterpret_cast<Tvec*>(output);
610
611 return SegmentReduceGPUImpl<Treducevec>(
612 ctx, nouter, ninner_vec, nsegments, reduce_op, initial_value,
613 empty_segment_value, is_mean, is_sqrtn, input_vec, segment_ids,
614 indices, weights, output_vec);
615 }
616 };
617 };
618
619 // Reduces input matrix within segments over the outer dimension. Empty segments
620 // always output empty_segment_value.
621 // The segment_ids vector must be sorted.
622 // If is_mean or is_sqrtn is true, the results are normalized using the
623 // corresponding function.
624 // If indices is not nullptr, input rows are accessed indirectly as
625 // input[indices[i]], instead of input[i].
626 // The implementation is deterministic.
627 // Note: Treduce is to allow reducing in higher precision than T.
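// For example, a sorted segment Sum with input = [[1, 2], [3, 4], [5, 6]],
// segment_ids = [0, 0, 2] and nsegments = 3 produces
// output = [[4, 6], [e, e], [5, 6]], where e is empty_segment_value.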
628 template <typename Treduce, typename T, typename Tindex, typename Tsegmentids,
629 typename ReduceOp>
630 Status SegmentReduceGPU(OpKernelContext* ctx, Tindex nouter, Tindex ninner,
631 Tsegmentids nsegments, ReduceOp reduce_op,
632 T initial_value, T empty_segment_value, bool is_mean,
633 bool is_sqrtn,
634 const T* input, // [nouter or any, ninner]
635 const Tsegmentids* segment_ids, // [nouter]
636 const Tindex* indices, // [nouter] (optional)
637 const T* weights, // [nouter or any] (optional)
638 T* output) { // [nsegments, ninner]
639 if (ninner == 0 || nsegments == 0) return OkStatus();
640 return DispatchToVectorized<
641 T, SegmentReduceGPUVectorized<Treduce>::template Impl>(
642 MinAlignmentOf(input, output, ninner), ctx, nouter, ninner, nsegments,
643 reduce_op, initial_value, empty_segment_value, is_mean, is_sqrtn, input,
644 segment_ids, indices, weights, output);
645 }
646
647 template <typename SegmentId, typename Index, typename T>
648 __global__ void SegmentWeightsKernel(
649 SegmentId nsegments, SparseSegmentReductionOperation operation,
650 const Index* __restrict__ segment_offsets, // [nsegments + 1]
651 T* __restrict__ weights) { // [nsegments]
652 GPU_1D_KERNEL_LOOP(i, nsegments) {
653 Index segment_size = segment_offsets[i + 1] - segment_offsets[i];
654 segment_size = max(segment_size, Index(1)); // Avoid division by zero
655 if (operation == SparseSegmentReductionOperation::kMean) {
656 weights[i] = T(1) / static_cast<T>(segment_size);
657 } else if (operation == SparseSegmentReductionOperation::kSqrtN) {
658 weights[i] = T(1) / sqrt(static_cast<T>(segment_size));
659 }
660 }
661 }
662
663 template <typename SegmentId, typename Index, typename T>
664 Status LaunchSegmentWeightsKernel(
665 const GPUDevice& d, SegmentId nsegments,
666 SparseSegmentReductionOperation operation,
667 const Index* segment_offsets, // [nsegments + 1]
668 T* weights) { // [nsegments]
669 GpuLaunchConfig config = GetGpuLaunchConfig(
670 nsegments, d, &SegmentWeightsKernel<SegmentId, Index, T>,
671 /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
672 return GpuLaunchKernel(SegmentWeightsKernel<SegmentId, Index, T>,
673 config.block_count, config.thread_per_block, 0,
674 d.stream(), nsegments, operation, segment_offsets,
675 weights);
676 }
677
678 template <typename ReduceOp, typename T>
679 struct ReduceType {
680 using type = T;
681 };
682
683 // Sum fp16 values using an fp32 accumulator to avoid numerical issues.
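// For example, accumulating 4096 half-precision ones directly stalls at 2048
// (adding 1 no longer changes the accumulator), whereas a float accumulator
// yields the exact 4096.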
684 template <>
685 struct ReduceType<functor::Sum, Eigen::half> {
686 using type = float;
687 };
688
689 namespace functor {
690
691 template <typename T, typename Index, typename InitialValueF,
692 typename EmptySegmentValueF, typename ReductionF>
693 void SegmentReductionFunctor<
694 T, Index, InitialValueF, EmptySegmentValueF,
695 ReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
696 const Index output_rows,
697 const TensorShape& segment_ids_shape, bool is_mean,
698 typename TTypes<Index>::ConstFlat segment_ids,
699 const Index data_size, const T* data,
700 typename TTypes<T, 2>::Tensor output) {
701 if (output.size() == 0) {
702 return;
703 }
704
705 // Launch kernel(s) to compute sorted segment reduction.
706 // Notes:
707 // *) 'input_total_size' is the total number of elements to process.
708 // *) 'segment_ids.shape' is a prefix of data's shape.
709   // *) 'input_outer_dim_size' is the number of rows (segment ids) to process.
710 const Index input_total_size = data_size;
711 const Index input_outer_dim_size = segment_ids.dimension(0);
712 const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
713 const Index num_segments = output.size() / input_inner_dim_size;
714
715 bool use_deterministic_kernels =
716 #if defined(PLATFORM_WINDOWS)
717 // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
718 // build error.
719 false;
720 #else
721 UseDeterministicSegmentReductions() ||
722 (OpDeterminismRequired() && !ReduceOpIsAssociative<ReductionF, T>::value);
723 #endif
724
725 // TODO(benbarsdell): If there are no performance concerns with the new
726 // deterministic kernels, remove this runtime check and only compile the old
727 // non-deterministic kernels on Windows (as a workaround for the build failure
728 // issue).
729 if (!use_deterministic_kernels) {
730 // Set 'output' to initial value.
731 GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
732 const T InitialValue = InitialValueF()();
733 TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
734 config.thread_per_block, 0, d.stream(),
735 output.size(), output.data(), InitialValue));
736 if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
737 return;
738 }
739
740 const int OuterDimTileSize = 8;
741
742 const Index input_outer_dim_num_stripe =
743 Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
744
745 const Index total_stripe_count =
746 input_inner_dim_size * input_outer_dim_num_stripe;
747
748 config = GetGpuLaunchConfig(total_stripe_count, d);
749 TF_CHECK_OK(GpuLaunchKernel(
750 SortedSegmentReductionCustomKernel<
751 T, Index, OuterDimTileSize,
752 typename ReduceUpdateOpFor<ReductionF>::nonatomic_op,
753 typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
754 config.block_count, config.thread_per_block, 0, d.stream(),
755 input_outer_dim_size, input_inner_dim_size, output_rows,
756 segment_ids.data(), data, output.data(), total_stripe_count,
757 InitialValue));
758
759 if (is_mean) {
760 Tensor segment_offsets;
761 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<Index>::value,
762 TensorShape({num_segments + 1}),
763 &segment_offsets));
764 Index* segment_offsets_ptr = segment_offsets.flat<Index>().data();
765 OP_REQUIRES_OK(ctx, LaunchSegmentOffsetsKernel(
766 d, input_outer_dim_size, num_segments,
767 segment_ids.data(), segment_offsets_ptr));
768
769 OP_REQUIRES_OK(ctx, LaunchSegmentMeanNormalizeKernel(
770 d, num_segments, input_inner_dim_size,
771 segment_offsets_ptr, output.data()));
772 }
773 } else {
774 // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
775 // build error.
776 #if !defined(PLATFORM_WINDOWS)
777 using Treduce = typename ReduceType<ReductionF, T>::type;
778 OP_REQUIRES_OK(
779 ctx,
780 SegmentReduceGPU<Treduce>(
781 ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
782 ReductionF(), InitialValueF()(), EmptySegmentValueF()(),
783 /*is_mean=*/is_mean, /*is_sqrtn=*/false, data, segment_ids.data(),
784 /*indices=*/static_cast<const Index*>(nullptr),
785 /*weights=*/static_cast<T*>(nullptr), output.data()));
786 #else
787 // Note: Shouldn't reach here because use_deterministic_kernels is always
788 // false on Windows.
789 OP_REQUIRES(ctx, false,
790 errors::Unimplemented("Deterministic segment reductions are "
791 "not implemented on Windows."));
792 #endif
793 }
794 }
795
796 template <typename T, typename Index, typename InitialValueF,
797 typename ReductionF>
798 struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
799 void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
800 typename TTypes<Index>::ConstFlat unsorted_segment_ids,
801 typename TTypes<T, 2>::ConstTensor data,
802 typename TTypes<T, 2>::Tensor output) {
803 if (output.size() == 0) {
804 return;
805 }
806
807 bool use_deterministic_kernels =
808 #if defined(PLATFORM_WINDOWS)
809 // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
810 // build error.
811 false;
812 #else
813 UseDeterministicSegmentReductions() ||
814 (!ReduceOpIsAssociative<ReductionF, T>::value &&
815 OpDeterminismRequired());
816 #endif
817
818 bool determinism_requirement_met =
819 use_deterministic_kernels ||
820 ReduceOpIsAssociative<ReductionF, T>::value ||
821 !OpDeterminismRequired() ||
822 DisableSegmentReductionOpDeterminismExceptions();
823 OP_REQUIRES(
824 ctx, determinism_requirement_met,
825 errors::Unimplemented(
826 "Deterministic GPU implementation of unsorted segment reduction op"
827 " not available."));
828
829 // Launch kernel(s) to compute unsorted segment reduction.
830 // Notes:
831 // *) 'data_size' is the total number of elements to process.
832 // *) 'segment_ids.shape' is a prefix of data's shape.
833     // *) 'input_outer_dim_size' is the number of rows (segment ids) to process.
834 const Index input_outer_dim_size = unsorted_segment_ids.dimension(0);
835 const Index input_inner_dim_size = data.dimension(1);
836 const Index output_outer_dim_size = output.dimension(0);
837 const Index num_segments = output.size() / input_inner_dim_size;
838
839 // TODO(benbarsdell): If there are no performance concerns with the new
840 // deterministic kernels, remove this runtime check and only compile the old
841 // non-deterministic kernels on Windows (as a workaround for the build
842 // failure issue).
843 if (!use_deterministic_kernels) {
844 // Set 'output' to initial value.
845 GPUDevice d = ctx->template eigen_device<GPUDevice>();
846 GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
847 TF_CHECK_OK(GpuLaunchKernel(
848 SetToValue<T>, config.block_count, config.thread_per_block, 0,
849 d.stream(), output.size(), output.data(), InitialValueF()()));
850 const int64_t data_size = data.size();
851 if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
852 return;
853 }
854 config = GetGpuLaunchConfig(data_size, d);
855 TF_CHECK_OK(GpuLaunchKernel(
856 UnsortedSegmentCustomKernel<
857 T, Index, typename ReduceUpdateOpFor<ReductionF>::atomic_op>,
858 config.block_count, config.thread_per_block, 0, d.stream(),
859 input_outer_dim_size, input_inner_dim_size, output_outer_dim_size,
860 unsorted_segment_ids.data(), data.data(), output.data()));
861 } else {
862 // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows CI
863 // build error.
864 #if !defined(PLATFORM_WINDOWS)
865       // Allocate temporary space and sort segment_ids, then call the sorted
866       // implementation.
867 Tensor segment_ids;
868 OP_REQUIRES_OK(
869 ctx, ctx->allocate_temp(
870 DataTypeToEnum<Index>::value,
871 TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
872 &segment_ids));
873 Index* segment_ids_ptr = segment_ids.flat<Index>().data();
874 Tensor sorted_indices;
875 OP_REQUIRES_OK(
876 ctx, ctx->allocate_temp(
877 DataTypeToEnum<Index>::value,
878 TensorShape({static_cast<int64_t>(input_outer_dim_size)}),
879 &sorted_indices));
880 Index* sorted_indices_ptr = sorted_indices.flat<Index>().data();
881 // Note: We must sort using all bits here because unsorted_segment_ids
882 // may contain negative values.
883 OP_REQUIRES_OK(
884 ctx, GpuRadixSort(ctx, input_outer_dim_size,
885 /*keys_in=*/unsorted_segment_ids.data(),
886 /*keys_out=*/segment_ids_ptr,
887 /*indices_in=*/static_cast<const Index*>(nullptr),
888 /*indices_out=*/sorted_indices_ptr));
889 using Treduce = typename ReduceType<ReductionF, T>::type;
890 OP_REQUIRES_OK(
891 ctx,
892 SegmentReduceGPU<Treduce>(
893 ctx, input_outer_dim_size, input_inner_dim_size, num_segments,
894 ReductionF(), /*initial_value=*/InitialValueF()(),
895 /*empty_segment_value=*/InitialValueF()(), /*is_mean=*/false,
896 /*is_sqrtn=*/false, /*input=*/data.data(),
897 /*segment_ids=*/segment_ids_ptr, /*indices=*/sorted_indices_ptr,
898 /*weights=*/static_cast<T*>(nullptr), output.data()));
899 #else
900 // Note: Shouldn't reach here because use_deterministic_kernels is always
901 // false on Windows.
902 OP_REQUIRES(
903 ctx, false,
904 errors::Unimplemented("Deterministic unsorted segment reductions are "
905 "not implemented on Windows."));
906 #endif
907 }
908 }
909 };
910
911 template <typename T, typename Index, typename SegmentId>
912 Status SparseSegmentReductionFunctor<T, Index, SegmentId>::operator()(
913 OpKernelContext* context, bool is_mean, bool is_sqrtn, T default_value,
914 typename TTypes<T, 2>::ConstTensor input,
915 typename TTypes<Index>::ConstVec indices,
916 typename TTypes<SegmentId>::ConstVec segment_ids,
917 typename TTypes<T, 2>::Tensor output) {
918 using ReduceOp = functor::Sum;
919 using Treduce = typename ReduceType<ReduceOp, T>::type;
920 Index nouter = segment_ids.size();
921 Index ninner = input.dimension(1);
922 SegmentId nsegments = output.dimension(0);
923 return SegmentReduceGPU<Treduce>(
924 context, /*nouter=*/nouter, /*ninner=*/ninner,
925 /*nsegments=*/nsegments, /*reduce_op=*/ReduceOp(),
926 /*initial_value=*/T(0),
927 /*empty_segment_value=*/default_value,
928 /*is_mean=*/is_mean, /*is_sqrtn=*/is_sqrtn,
929 /*input=*/input.data(), /*segment_ids=*/segment_ids.data(),
930 /*indices=*/indices.data(), /*weights=*/static_cast<T*>(nullptr),
931 /*output=*/output.data());
932 }
933
934 template <typename T, typename Index, typename SegmentId>
935 struct SparseSegmentGradFunctor<GPUDevice, T, Index, SegmentId> {
936 void operator()(OpKernelContext* context,
937 SparseSegmentReductionOperation operation,
938 typename TTypes<T>::ConstMatrix input_flat,
939 typename TTypes<Index>::ConstVec indices_vec,
940 typename TTypes<SegmentId>::ConstVec segment_vec,
941 typename TTypes<T>::Matrix output_flat) {
942 const GPUDevice& device = context->eigen_gpu_device();
943
944 const SegmentId nsegments = input_flat.dimension(0);
945 const Index ninner = input_flat.dimension(1);
946 const Index nouter = indices_vec.dimension(0);
947 const Index noutput = output_flat.dimension(0);
948
949 // Allocate and compute segment weights (for Mean/SqrtN operations only).
950 Tensor weights;
951 T* weights_ptr = nullptr;
952 if (operation != SparseSegmentReductionOperation::kSum) {
953 OP_REQUIRES_OK(
954 context, context->allocate_temp(DataTypeToEnum<T>::value,
955 TensorShape({nsegments}), &weights));
956 weights_ptr = weights.flat<T>().data();
957 // Allocate and compute segment_offsets.
958 Tensor segment_offsets;
959 OP_REQUIRES_OK(context,
960 context->allocate_temp(DataTypeToEnum<Index>::value,
961 TensorShape({nsegments + 1}),
962 &segment_offsets));
963 Index* segment_offsets_ptr = segment_offsets.flat<Index>().data();
964 OP_REQUIRES_OK(context, LaunchSegmentOffsetsKernel(
965 device, nouter, nsegments, segment_vec.data(),
966 segment_offsets_ptr));
967 // Compute the weights based on the segment sizes using segment_offsets.
968 OP_REQUIRES_OK(context, LaunchSegmentWeightsKernel(
969 device, nsegments, operation,
970 segment_offsets_ptr, weights_ptr));
971 }
972
973 const Index* sorted_indices_ptr = indices_vec.data();
974 const SegmentId* sorted_segment_ptr = segment_vec.data();
975 Tensor tmp_sorted_indices;
976 Tensor tmp_sorted_segment;
977 if (noutput > 1) {
978 // Sort indices and permute segments.
979 OP_REQUIRES_OK(context, context->allocate_temp(
980 DataTypeToEnum<Index>::value,
981 TensorShape({nouter}), &tmp_sorted_indices));
982 Index* tmp_sorted_indices_ptr = tmp_sorted_indices.flat<Index>().data();
983 OP_REQUIRES_OK(context, context->allocate_temp(
984 DataTypeToEnum<SegmentId>::value,
985 TensorShape({nouter}), &tmp_sorted_segment));
986 SegmentId* tmp_sorted_segment_ptr =
987 tmp_sorted_segment.flat<SegmentId>().data();
988 OP_REQUIRES_OK(context,
989 GpuRadixSort(context, nouter,
990 /*keys_in=*/indices_vec.data(),
991 /*keys_out=*/tmp_sorted_indices_ptr,
992 /*indices_in=*/segment_vec.data(),
993 /*indices_out=*/tmp_sorted_segment_ptr,
994 /*num_bits=*/Log2Ceiling64(noutput)));
995 sorted_indices_ptr = tmp_sorted_indices_ptr;
996 sorted_segment_ptr = tmp_sorted_segment_ptr;
997 }
998
999 // Compute the gradient using a weighted SegmentReduceGPU with the segment
1000 // IDs and indices swapped.
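    // That is, output_flat[i, :] accumulates input_flat[segment_vec[j], :]
    // (scaled by weights[segment_vec[j]] for Mean/SqrtN) over all j with
    // indices_vec[j] == i, which is exactly the gradient of the forward
    // sparse segment reduction.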
1001 using ReduceOp = gpuprim::Sum;
1002 using Treduce = typename ReduceType<ReduceOp, T>::type;
1003 OP_REQUIRES_OK(
1004 context,
1005 SegmentReduceGPU<Treduce>(
1006 context, /*nouter=*/static_cast<SegmentId>(nouter),
1007 /*ninner=*/static_cast<SegmentId>(ninner),
1008 /*nsegments=*/noutput,
1009 /*reduce_op=*/ReduceOp(),
1010 /*initial_value=*/T(0),
1011 /*empty_segment_value=*/T(0),
1012 /*is_mean=*/false, /*is_sqrtn=*/false,
1013 /*input=*/input_flat.data(), /*segment_ids=*/sorted_indices_ptr,
1014 /*indices=*/sorted_segment_ptr, /*weights=*/weights_ptr,
1015 /*output=*/output_flat.data()));
1016 }
1017 };
1018
1019 } // namespace functor
1020 } // namespace tensorflow
1021
1022 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1023