/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/gpu_prim_helpers.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_solvers.h"  // For ScratchSpace

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#endif

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace unique_op_gpu {

// Returns true iff index i starts a new segment in the sorted input, i.e.,
// i > 0 and sorted_input[i] differs from sorted_input[i - 1] (equivalently,
// i is the exclusive end of the previous segment).
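// For example, for sorted input [1, 3, 3, 4] the indicator values at indices
// 0..3 are [0, 1, 0, 1]. In UniqueOpGPU::ComputeAsync below, this functor is
// composed with a gpuprim::CountingInputIterator via
// gpuprim::TransformInputIterator so that the indicator sequence is computed
// on the fly rather than materialized.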
template <typename T, typename TIndex>
struct SegmentIndicatorFunctor {
  const T* __restrict__ sorted_input_ptr_;
  SegmentIndicatorFunctor(const T* sorted_input_ptr)
      : sorted_input_ptr_(sorted_input_ptr) {}
  __device__ bool operator()(const TIndex& i) const {
    return i > 0 && sorted_input_ptr_[i] != sorted_input_ptr_[i - 1];
  }
};

template <typename TIndex>
__global__ void ExtractFirstOccurrenceIndicesKernel(
    int64_t input_size, int64_t uniq_size,
    const TIndex* __restrict__ sorted_input_inds,
    const TIndex* __restrict__ sorted_input_unique_ids,
    TIndex* __restrict__ unique_input_inds, TIndex* __restrict__ segment_ends) {
  GPU_1D_KERNEL_LOOP(i, input_size) {
    TIndex sorted_input_unique_id = sorted_input_unique_ids[i];
    if (i == 0 || sorted_input_unique_id != sorted_input_unique_ids[i - 1]) {
      unique_input_inds[sorted_input_unique_id] = sorted_input_inds[i];
      if (segment_ends) {
        if (i == 0) {
          // First thread writes the last element.
          segment_ends[uniq_size - 1] = input_size;
        } else {
          segment_ends[sorted_input_unique_id - 1] = i;
        }
      }
    }
  }
}

// Scatters the index of the first occurrence of each unique input value to
// unique_input_inds.
// If segment_ends is not nullptr, it is filled with the end index of each
// unique value's range in the sorted input (the last element is always set
// to input_size).
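// For example, with the walkthrough values in UniqueOpGPU::ComputeAsync below
// (input_size = 20, uniq_size = 9):
//   sorted_input_inds       = [4, 17, 0, 2, 9, ...]
//   sorted_input_unique_ids = [0, 1, 2, 2, 2, ...]
// produces
//   unique_input_inds = [4, 17, 0, 3, 1, 8, 11, 7, 6]
//   segment_ends      = [1, 2, 5, 9, 12, 15, 16, 19, 20]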
template <typename TIndex>
Status ExtractFirstOccurrenceIndices(const GPUDevice& d, int64_t input_size,
                                     int64_t uniq_size,
                                     const TIndex* sorted_input_inds,
                                     const TIndex* sorted_input_unique_ids,
                                     TIndex* unique_input_inds,
                                     TIndex* segment_ends) {
  CHECK_GT(input_size, 0);  // Crash OK
  GpuLaunchConfig config = GetGpuLaunchConfig(
      input_size, d, &ExtractFirstOccurrenceIndicesKernel<TIndex>,
      /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
  return GpuLaunchKernel(ExtractFirstOccurrenceIndicesKernel<TIndex>,
                         config.block_count, config.thread_per_block, 0,
                         d.stream(), input_size, uniq_size, sorted_input_inds,
                         sorted_input_unique_ids, unique_input_inds,
                         segment_ends);
}

template <typename T, typename TIndex>
__global__ void GatherOutputsAndInvertPermutationKernel(
    int64_t uniq_size, const T* __restrict__ input,
    const TIndex* __restrict__ sorted_unique_input_inds,
    const TIndex* __restrict__ sorted_unique_perm,
    const TIndex* __restrict__ segment_ends, T* __restrict__ output,
    TIndex* __restrict__ inv_sorted_unique_perm, TIndex* __restrict__ count) {
  GPU_1D_KERNEL_LOOP(i, uniq_size) {
    output[i] = input[sorted_unique_input_inds[i]];
    auto j = sorted_unique_perm[i];
    inv_sorted_unique_perm[j] = i;
    if (count) {
      TIndex beg = j == 0 ? 0 : segment_ends[j - 1];
      TIndex end = segment_ends[j];
      count[i] = end - beg;
    }
  }
}

// Gathers input values using sorted_unique_input_inds, and inverts the
// permutation specified by sorted_unique_perm.
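// For example, with the walkthrough values in UniqueOpGPU::ComputeAsync below
// (uniq_size = 9):
//   sorted_unique_input_inds = [0, 1, 3, 4, 6, 7, 8, 11, 17]
//   sorted_unique_perm       = [2, 4, 3, 0, 8, 7, 5, 6, 1]
//   segment_ends             = [1, 2, 5, 9, 12, 15, 16, 19, 20]
// produces
//   output                 = [3, 5, 4, 1, 9, 8, 6, 7, 2]
//   inv_sorted_unique_perm = [3, 8, 0, 2, 1, 6, 7, 5, 4]
//   count                  = [3, 3, 4, 1, 1, 3, 3, 1, 1]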
template <typename T, typename TIndex>
Status GatherOutputsAndInvertPermutation(const GPUDevice& d, int64_t uniq_size,
                                         const T* input,
                                         const TIndex* sorted_unique_input_inds,
                                         const TIndex* sorted_unique_perm,
                                         const TIndex* segment_ends, T* output,
                                         TIndex* inv_sorted_unique_perm,
                                         TIndex* count) {
  if (uniq_size == 0) return OkStatus();
  GpuLaunchConfig config = GetGpuLaunchConfig(
      uniq_size, d, &GatherOutputsAndInvertPermutationKernel<T, TIndex>,
      /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
  return GpuLaunchKernel(GatherOutputsAndInvertPermutationKernel<T, TIndex>,
                         config.block_count, config.thread_per_block, 0,
                         d.stream(), uniq_size, input, sorted_unique_input_inds,
                         sorted_unique_perm, segment_ends, output,
                         inv_sorted_unique_perm, count);
}

template <typename TIndex>
__global__ void LookupAndScatterUniqueIdsKernel(
    int64_t input_size, const TIndex* sorted_input_inds,
    const TIndex* __restrict__ sorted_input_unique_ids,
    const TIndex* __restrict__ inv_sorted_unique_perm,
    TIndex* __restrict__ idx) {
  GPU_1D_KERNEL_LOOP(i, input_size) {
    idx[sorted_input_inds[i]] =
        inv_sorted_unique_perm[sorted_input_unique_ids[i]];
  }
}

// Maps the values of sorted_input_unique_ids and scatters them to idx using
// sorted_input_inds.
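// For example, with the walkthrough values in UniqueOpGPU::ComputeAsync below,
//   inv_sorted_unique_perm = [3, 8, 0, 2, 1, 6, 7, 5, 4]
// produces
//   idx = [0, 1, 0, 2, 3, 2, 4, 5, 6, 0, 1, 7, 5, 5, 2, 6, 2, 8, 1, 6]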
template <typename TIndex>
Status LookupAndScatterUniqueIds(const GPUDevice& d, int64_t input_size,
                                 const TIndex* sorted_input_inds,
                                 const TIndex* sorted_input_unique_ids,
                                 const TIndex* inv_sorted_unique_perm,
                                 TIndex* idx) {
  CHECK_GT(input_size, 0);  // Crash OK
  GpuLaunchConfig config = GetGpuLaunchConfig(
      input_size, d, &LookupAndScatterUniqueIdsKernel<TIndex>,
      /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
  return GpuLaunchKernel(LookupAndScatterUniqueIdsKernel<TIndex>,
                         config.block_count, config.thread_per_block, 0,
                         d.stream(), input_size, sorted_input_inds,
                         sorted_input_unique_ids, inv_sorted_unique_perm, idx);
}

}  // namespace unique_op_gpu

// This only supports Unique[WithCounts], not Unique[WithCounts]V2.
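// A kernel registration for this op might look roughly like the following
// sketch (illustrative only; the real registrations live elsewhere in the
// kernels library and cover more types):
//
//   REGISTER_KERNEL_BUILDER(Name("Unique")
//                               .Device(DEVICE_GPU)
//                               .TypeConstraint<float>("T")
//                               .TypeConstraint<int32>("out_idx"),
//                           UniqueOpGPU<float, int32>);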
template <typename T, typename TIndex>
class UniqueOpGPU : public AsyncOpKernel {
 public:
  explicit UniqueOpGPU(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  template <typename U>
  void AllocateTemp(OpKernelContext* context, int64_t size, Tensor* tensor,
                    U** tensor_data, DoneCallback done) const {
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DataTypeToEnum<U>::value,
                                                TensorShape({size}), tensor),
                         done);
    *tensor_data = tensor->flat<U>().data();
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    // TODO(dga):  Make unique polymorphic for returning int32 and int64
    // vectors to support large tensors.
    OP_REQUIRES_ASYNC(context,
                      input.NumElements() <= std::numeric_limits<int32>::max(),
                      errors::InvalidArgument(
                          "unique does not support input tensors larger than ",
                          std::numeric_limits<int32>::max(), " elements"),
                      done);

    OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input.shape()),
                      errors::InvalidArgument("unique expects a 1D vector."),
                      done);

    se::Stream* stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(context, stream,
                      errors::Internal("No GPU stream available."), done);

    int64_t input_size = input.NumElements();
    bool has_count_output = num_outputs() > 2;
    if (input_size == 0) {
      // Early exit for trivial case.
      Tensor* t = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, TensorShape({0}), &t), done);
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(1, TensorShape({0}), &t), done);
      if (has_count_output) {
        OP_REQUIRES_OK_ASYNC(
            context, context->allocate_output(2, TensorShape({0}), &t), done);
      }
      done();
      return;
    }

    // The algorithm implemented here is as follows:
    // input = [3, 5, 3, 4, 1, 4, 9, 8, 6, 3, 5, 7, 8, 8, 4, 6, 4, 2, 5, 6]
    // 1) Sort the input to group equal values together in segments.
    //      sorted_input, sorted_input_inds = sort(input)
    // sorted_input:
    //   [1, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 8, 8, 8, 9]
    // sorted_input_inds:
    //   [4, 17, 0, 2, 9, 3, 5, 14, 16, 1, 10, 18, 8, 15, 19, 11, 7, 12, 13, 6]
    // 2) Identify the boundaries between segments and use prefix sum to
    //    compute the unique ID for each sorted value.
    //      sorted_input_unique_ids = prefix_sum(indicator(sorted_input))
    // indicator(sorted_input):
    //   [0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1]
    // sorted_input_unique_ids:
    //   [0, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 7, 7, 7, 8]
    // 3) Extract the input index of the first occurrence of each unique value.
    //    If counts are required, also extract the end index of each segment.
    //      unique_input_inds[sorted_input_unique_ids] =
    //          sorted_input_inds (@ indicator)
    //      segment_ends[sorted_input_unique_ids[i] - 1] = i (@ indicator)
    // unique_input_inds: [4, 17, 0, 3, 1, 8, 11, 7, 6]
    // segment_ends: [1, 2, 5, 9, 12, 15, 16, 19, 20]
    // 4) Sort the extracted unique input indices to put them in order of
    //    first appearance.
    //      sorted_unique_input_inds, sorted_unique_perm =
    //          sort(unique_input_inds)
    // sorted_unique_input_inds: [0, 1, 3, 4, 6, 7, 8, 11, 17]
    // sorted_unique_perm: [2, 4, 3, 0, 8, 7, 5, 6, 1]
    // 5) Gather the sorted unique input values to produce output, and invert
    //    the second sort permutation to produce an inverse ID mapping. If
    //    counts are required, also take the adjacent difference between
    //    segment_ends indices to produce counts.
    //      output = input[sorted_unique_input_inds]
    //      inv_sorted_unique_perm[sorted_unique_perm[i]] = i
    //      counts = adjacent_difference(segment_ends)
    // output: [3, 5, 4, 1, 9, 8, 6, 7, 2]
    // inv_sorted_unique_perm: [3, 8, 0, 2, 1, 6, 7, 5, 4]
    // counts: [3, 3, 4, 1, 1, 3, 3, 1, 1]
    // 6) Look up unique IDs via the inverse ID mapping and scatter them using
    //    the original sort permutation to produce the indices output.
    //      idx[sorted_input_inds] =
    //          inv_sorted_unique_perm[sorted_input_unique_ids]
    // idx: [0, 1, 0, 2, 3, 2, 4, 5, 6, 0, 1, 7, 5, 5, 2, 6, 2, 8, 1, 6]

    Tensor sorted_input_inds;
    TIndex* sorted_input_inds_ptr = nullptr;
    AllocateTemp(context, input_size, &sorted_input_inds,
                 &sorted_input_inds_ptr, done);
    if (!context->status().ok()) return;

    Tensor sorted_input;
    T* sorted_input_ptr = nullptr;
    AllocateTemp(context, input_size, &sorted_input, &sorted_input_ptr, done);
    if (!context->status().ok()) return;

    const T* input_ptr = input.flat<T>().data();
    OP_REQUIRES_OK_ASYNC(
        context,
        GpuRadixSort(context, input_size, /*keys_in=*/input_ptr,
                     /*keys_out=*/sorted_input_ptr,
                     /*indices_in=*/static_cast<const TIndex*>(nullptr),
                     /*indices_out=*/sorted_input_inds_ptr),
        done);

    using namespace unique_op_gpu;

    // Create a fancy input iterator to indicate segment boundaries.
    gpuprim::CountingInputIterator<TIndex> counting_iter(0);
    gpuprim::TransformInputIterator<TIndex, SegmentIndicatorFunctor<T, TIndex>,
                                    gpuprim::CountingInputIterator<TIndex>>
        segment_indicator_iter(counting_iter, {sorted_input_ptr});

    Tensor sorted_input_unique_ids;
    TIndex* sorted_input_unique_ids_ptr = nullptr;
    AllocateTemp(context, input_size, &sorted_input_unique_ids,
                 &sorted_input_unique_ids_ptr, done);
    if (!context->status().ok()) return;

    OP_REQUIRES_OK_ASYNC(
        context,
        GpuInclusivePrefixSum(context, input_size, segment_indicator_iter,
                              sorted_input_unique_ids_ptr),
        done);

    // Copy the last element of sorted_input_unique_ids back to the host to
    // obtain uniq_size.
    ScratchSpace<TIndex> last_idx_host(context, 1, /*on_host=*/true);
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(last_idx_host.mutable_data(),
                         se::DeviceMemoryBase(
                             const_cast<TIndex*>(sorted_input_unique_ids_ptr) +
                                 (input_size - 1),
                             sizeof(*last_idx_host.data())),
                         sizeof(*last_idx_host.data()))
            .ok(),
        errors::Internal("Failed to copy last_idx to host"), done);

    auto async_finish_computation = [this, context, input_size, input_ptr,
                                     sorted_input_inds, sorted_input_inds_ptr,
                                     sorted_input_unique_ids,
                                     sorted_input_unique_ids_ptr, last_idx_host,
                                     has_count_output, done]() -> void {
      const GPUDevice& device = context->eigen_gpu_device();
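      // sorted_input_unique_ids is an inclusive prefix sum of 0/1 indicators
      // starting at ID 0, so the last value is uniq_size - 1; add one to get
      // the number of unique elements.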
      int64 uniq_size = (*last_idx_host.data()) + 1;

      se::gpu::ScopedActivateExecutorContext scoped_activation{
          context->op_device_context()->stream()->parent()};

      Tensor unique_input_inds;
      TIndex* unique_input_inds_ptr = nullptr;
      AllocateTemp(context, uniq_size, &unique_input_inds,
                   &unique_input_inds_ptr, done);
      if (!context->status().ok()) return;

      Tensor segment_ends;
      TIndex* segment_ends_ptr = nullptr;
      if (has_count_output) {
        AllocateTemp(context, uniq_size, &segment_ends, &segment_ends_ptr,
                     done);
        if (!context->status().ok()) return;
      }

      OP_REQUIRES_OK_ASYNC(
          context,
          ExtractFirstOccurrenceIndices(
              device, input_size, uniq_size, sorted_input_inds_ptr,
              sorted_input_unique_ids_ptr, unique_input_inds_ptr,
              segment_ends_ptr),
          done);

      Tensor sorted_unique_input_inds;
      TIndex* sorted_unique_input_inds_ptr = nullptr;
      AllocateTemp(context, uniq_size, &sorted_unique_input_inds,
                   &sorted_unique_input_inds_ptr, done);
      if (!context->status().ok()) return;

      Tensor sorted_unique_perm;
      TIndex* sorted_unique_perm_ptr = nullptr;
      AllocateTemp(context, uniq_size, &sorted_unique_perm,
                   &sorted_unique_perm_ptr, done);
      if (!context->status().ok()) return;

      // Sort by input index so that output is in order of appearance.
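      // Only Log2Ceiling(input_size) key bits are needed because the keys are
      // input indices in [0, input_size).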
      OP_REQUIRES_OK_ASYNC(
          context,
          GpuRadixSort(context, uniq_size,
                       /*keys_in=*/unique_input_inds_ptr,
                       /*keys_out=*/sorted_unique_input_inds_ptr,
                       /*indices_in=*/static_cast<const TIndex*>(nullptr),
                       /*indices_out=*/sorted_unique_perm_ptr,
                       /*num_bits=*/Log2Ceiling(input_size)),
          done);

      // Free temporary tensor that is no longer needed.
      unique_input_inds = Tensor();
      unique_input_inds_ptr = nullptr;

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context,
          context->allocate_output(0, TensorShape({uniq_size}), &output), done);
      T* output_ptr = output->flat<T>().data();

      Tensor inv_sorted_unique_perm;
      TIndex* inv_sorted_unique_perm_ptr = nullptr;
      AllocateTemp(context, uniq_size, &inv_sorted_unique_perm,
                   &inv_sorted_unique_perm_ptr, done);
      if (!context->status().ok()) return;

      TIndex* count_ptr = nullptr;
      if (has_count_output) {
        Tensor* count = nullptr;
        OP_REQUIRES_OK_ASYNC(
            context,
            context->allocate_output(2, TensorShape({uniq_size}), &count),
            done);
        count_ptr = count->flat<TIndex>().data();
      }

      // Compute output and counts (if necessary).
      OP_REQUIRES_OK_ASYNC(
          context,
          GatherOutputsAndInvertPermutation(
              device, uniq_size, input_ptr, sorted_unique_input_inds_ptr,
              sorted_unique_perm_ptr, segment_ends_ptr, output_ptr,
              inv_sorted_unique_perm_ptr, count_ptr),
          done);

      // Free temporary tensors that are no longer needed.
      sorted_unique_perm = Tensor();
      sorted_unique_perm_ptr = nullptr;
      sorted_unique_input_inds = Tensor();
      sorted_unique_input_inds_ptr = nullptr;
      segment_ends = Tensor();
      segment_ends_ptr = nullptr;

      Tensor* idx = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(1, TensorShape({input_size}), &idx),
          done);
      TIndex* idx_ptr = idx->flat<TIndex>().data();

      // Compute indices output.
      OP_REQUIRES_OK_ASYNC(
          context,
          LookupAndScatterUniqueIds(device, input_size, sorted_input_inds_ptr,
                                    sorted_input_unique_ids_ptr,
                                    inv_sorted_unique_perm_ptr, idx_ptr),
          done);

      done();
    };

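    // Run the rest of the computation on the host once the work already
    // enqueued on the stream (including the copy of the last unique ID above)
    // has completed.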
    context->device()
        ->tensorflow_accelerator_device_info()
        ->event_mgr->ThenExecute(stream, async_finish_computation);
  }
};

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_