/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/gpu_prim_helpers.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_solvers.h"  // For ScratchSpace

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#endif

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace unique_op_gpu {

// Returns true iff index is at the end of a segment (which is equivalent to the
// beginning of the next segment).
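// For example, sorted_input = [1, 2, 2, 3] yields indicator values
// [0, 1, 0, 1] for indices 0..3.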
template <typename T, typename TIndex>
struct SegmentIndicatorFunctor {
  const T* __restrict__ sorted_input_ptr_;
  SegmentIndicatorFunctor(const T* sorted_input_ptr)
      : sorted_input_ptr_(sorted_input_ptr) {}
  __device__ bool operator()(const TIndex& i) const {
    return i > 0 && sorted_input_ptr_[i] != sorted_input_ptr_[i - 1];
  }
};

template <typename TIndex>
__global__ void ExtractFirstOccurrenceIndicesKernel(
    int64_t input_size, int64_t uniq_size,
    const TIndex* __restrict__ sorted_input_inds,
    const TIndex* __restrict__ sorted_input_unique_ids,
    TIndex* __restrict__ unique_input_inds, TIndex* __restrict__ segment_ends) {
  GPU_1D_KERNEL_LOOP(i, input_size) {
    TIndex sorted_input_unique_id = sorted_input_unique_ids[i];
    if (i == 0 || sorted_input_unique_id != sorted_input_unique_ids[i - 1]) {
      unique_input_inds[sorted_input_unique_id] = sorted_input_inds[i];
      if (segment_ends) {
        if (i == 0) {
          // First thread writes the last element.
          segment_ends[uniq_size - 1] = input_size;
        } else {
          segment_ends[sorted_input_unique_id - 1] = i;
        }
      }
    }
  }
}

// Scatters the index of the first occurrence of each unique input value to
// unique_input_inds.
// If segment_ends is not nullptr, it is filled with the end index of each
// unique value's range in the sorted input (the last element is always set
// to input_size).
template <typename TIndex>
Status ExtractFirstOccurrenceIndices(const GPUDevice& d, int64_t input_size,
                                     int64_t uniq_size,
                                     const TIndex* sorted_input_inds,
                                     const TIndex* sorted_input_unique_ids,
                                     TIndex* unique_input_inds,
                                     TIndex* segment_ends) {
  CHECK_GT(input_size, 0);  // Crash OK
  GpuLaunchConfig config = GetGpuLaunchConfig(
      input_size, d, &ExtractFirstOccurrenceIndicesKernel<TIndex>,
      /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
  return GpuLaunchKernel(ExtractFirstOccurrenceIndicesKernel<TIndex>,
                         config.block_count, config.thread_per_block, 0,
                         d.stream(), input_size, uniq_size, sorted_input_inds,
                         sorted_input_unique_ids, unique_input_inds,
                         segment_ends);
}

template <typename T, typename TIndex>
__global__ void GatherOutputsAndInvertPermutationKernel(
    int64_t uniq_size, const T* __restrict__ input,
    const TIndex* __restrict__ sorted_unique_input_inds,
    const TIndex* __restrict__ sorted_unique_perm,
    const TIndex* __restrict__ segment_ends, T* __restrict__ output,
    TIndex* __restrict__ inv_sorted_unique_perm, TIndex* __restrict__ count) {
  GPU_1D_KERNEL_LOOP(i, uniq_size) {
    output[i] = input[sorted_unique_input_inds[i]];
    auto j = sorted_unique_perm[i];
    inv_sorted_unique_perm[j] = i;
    if (count) {
      TIndex beg = j == 0 ? 0 : segment_ends[j - 1];
      TIndex end = segment_ends[j];
      count[i] = end - beg;
    }
  }
}

// Gathers input values using sorted_unique_input_inds, and inverts the
// permutation specified by sorted_unique_perm.
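// If count is not nullptr, count[i] is set to the number of occurrences of
// the i-th unique value, i.e. the adjacent difference of segment_ends
// permuted by sorted_unique_perm.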
template <typename T, typename TIndex>
Status GatherOutputsAndInvertPermutation(const GPUDevice& d, int64_t uniq_size,
                                         const T* input,
                                         const TIndex* sorted_unique_input_inds,
                                         const TIndex* sorted_unique_perm,
                                         const TIndex* segment_ends, T* output,
                                         TIndex* inv_sorted_unique_perm,
                                         TIndex* count) {
  if (uniq_size == 0) return OkStatus();
  GpuLaunchConfig config = GetGpuLaunchConfig(
      uniq_size, d, &GatherOutputsAndInvertPermutationKernel<T, TIndex>,
      /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
  return GpuLaunchKernel(GatherOutputsAndInvertPermutationKernel<T, TIndex>,
                         config.block_count, config.thread_per_block, 0,
                         d.stream(), uniq_size, input, sorted_unique_input_inds,
                         sorted_unique_perm, segment_ends, output,
                         inv_sorted_unique_perm, count);
}

template <typename TIndex>
__global__ void LookupAndScatterUniqueIdsKernel(
    int64_t input_size, const TIndex* sorted_input_inds,
    const TIndex* __restrict__ sorted_input_unique_ids,
    const TIndex* __restrict__ inv_sorted_unique_perm,
    TIndex* __restrict__ idx) {
  GPU_1D_KERNEL_LOOP(i, input_size) {
    idx[sorted_input_inds[i]] =
        inv_sorted_unique_perm[sorted_input_unique_ids[i]];
  }
}

// Maps the values of sorted_input_unique_ids and scatters them to idx using
// sorted_input_inds.
template <typename TIndex>
Status LookupAndScatterUniqueIds(const GPUDevice& d, int64_t input_size,
                                 const TIndex* sorted_input_inds,
                                 const TIndex* sorted_input_unique_ids,
                                 const TIndex* inv_sorted_unique_perm,
                                 TIndex* idx) {
  CHECK_GT(input_size, 0);  // Crash OK
  GpuLaunchConfig config = GetGpuLaunchConfig(
      input_size, d, &LookupAndScatterUniqueIdsKernel<TIndex>,
      /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
  return GpuLaunchKernel(LookupAndScatterUniqueIdsKernel<TIndex>,
                         config.block_count, config.thread_per_block, 0,
                         d.stream(), input_size, sorted_input_inds,
                         sorted_input_unique_ids, inv_sorted_unique_perm, idx);
}

}  // namespace unique_op_gpu

// This only supports Unique[WithCounts], not Unique[WithCounts]V2.
template <typename T, typename TIndex>
class UniqueOpGPU : public AsyncOpKernel {
 public:
  explicit UniqueOpGPU(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

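  // Allocates a 1-D temporary tensor with `size` elements of type U and
  // returns a pointer to its data via *tensor_data. On failure, the context
  // status is set and `done` is invoked; callers are expected to check
  // context->status() afterwards.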
  template <typename U>
  void AllocateTemp(OpKernelContext* context, int64_t size, Tensor* tensor,
                    U** tensor_data, DoneCallback done) const {
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DataTypeToEnum<U>::value,
                                                TensorShape({size}), tensor),
                         done);
    *tensor_data = tensor->flat<U>().data();
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    // TODO(dga): Make unique polymorphic for returning int32 and int64
    // vectors to support large tensors.
    OP_REQUIRES_ASYNC(context,
                      input.NumElements() <= std::numeric_limits<int32>::max(),
                      errors::InvalidArgument(
                          "unique does not support input tensors larger than ",
                          std::numeric_limits<int32>::max(), " elements"),
                      done);

    OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input.shape()),
                      errors::InvalidArgument("unique expects a 1D vector."),
                      done);

    se::Stream* stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(context, stream,
                      errors::Internal("No GPU stream available."), done);

    int64_t input_size = input.NumElements();
    bool has_count_output = num_outputs() > 2;
    if (input_size == 0) {
      // Early exit for trivial case.
      Tensor* t = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, TensorShape({0}), &t), done);
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(1, TensorShape({0}), &t), done);
      if (has_count_output) {
        OP_REQUIRES_OK_ASYNC(
            context, context->allocate_output(2, TensorShape({0}), &t), done);
      }
      done();
      return;
    }

    // The algorithm implemented here is as follows:
    // input = [3, 5, 3, 4, 1, 4, 9, 8, 6, 3, 5, 7, 8, 8, 4, 6, 4, 2, 5, 6]
    // 1) Sort the input to group equal values together in segments.
    //      sorted_input, sorted_input_inds = sort(input)
    //    sorted_input:
    //      [1, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 8, 8, 8, 9]
    //    sorted_input_inds:
    //      [4, 17, 0, 2, 9, 3, 5, 14, 16, 1, 10, 18, 8, 15, 19, 11, 7, 12, 13, 6]
    // 2) Identify the boundaries between segments and use prefix sum to
    //    compute the unique ID for each sorted value.
    //      sorted_input_unique_ids = prefix_sum(indicator(sorted_input))
    //    indicator(sorted_input):
    //      [0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1]
    //    sorted_input_unique_ids:
    //      [0, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 7, 7, 7, 8]
    // 3) Extract the input index of the first occurrence of each unique value.
    //    If counts are required, also extract the end index of each segment.
    //      unique_input_inds[sorted_input_unique_ids] =
    //          sorted_input_inds (@ indicator)
    //      segment_ends[sorted_input_unique_ids[i] - 1] = i (@ indicator)
    //    unique_input_inds: [4, 17, 0, 3, 1, 8, 11, 7, 6]
    //    segment_ends: [1, 2, 5, 9, 12, 15, 16, 19, 20]
    // 4) Sort the extracted unique input indices to put them in order of
    //    first appearance.
    //      sorted_unique_input_inds, sorted_unique_perm =
    //          sort(unique_input_inds)
    //    sorted_unique_input_inds: [0, 1, 3, 4, 6, 7, 8, 11, 17]
    //    sorted_unique_perm: [2, 4, 3, 0, 8, 7, 5, 6, 1]
    // 5) Gather the sorted unique input values to produce output, and invert
    //    the second sort permutation to produce an inverse ID mapping. If
    //    counts are required, also take the adjacent difference between
    //    segment_ends indices to produce counts.
    //      output = input[sorted_unique_input_inds]
    //      inv_sorted_unique_perm[sorted_unique_perm[i]] = i
    //      counts = adjacent_difference(segment_ends)
    //    output: [3, 5, 4, 1, 9, 8, 6, 7, 2]
    //    inv_sorted_unique_perm: [3, 8, 0, 2, 1, 6, 7, 5, 4]
    //    counts: [3, 3, 4, 1, 1, 3, 3, 1, 1]
    // 6) Look up unique IDs via the inverse ID mapping and scatter them using
    //    the original sort permutation to produce the indices output.
    //      idx[sorted_input_inds] =
    //          inv_sorted_unique_perm[sorted_input_unique_ids]
    //    idx: [0, 1, 0, 2, 3, 2, 4, 5, 6, 0, 1, 7, 5, 5, 2, 6, 2, 8, 1, 6]

    Tensor sorted_input_inds;
    TIndex* sorted_input_inds_ptr = nullptr;
    AllocateTemp(context, input_size, &sorted_input_inds,
                 &sorted_input_inds_ptr, done);
    if (!context->status().ok()) return;

    Tensor sorted_input;
    T* sorted_input_ptr = nullptr;
    AllocateTemp(context, input_size, &sorted_input, &sorted_input_ptr, done);
    if (!context->status().ok()) return;

    const T* input_ptr = input.flat<T>().data();
    OP_REQUIRES_OK_ASYNC(
        context,
        GpuRadixSort(context, input_size, /*keys_in=*/input_ptr,
                     /*keys_out=*/sorted_input_ptr,
                     /*indices_in=*/static_cast<const TIndex*>(nullptr),
                     /*indices_out=*/sorted_input_inds_ptr),
        done);

    using namespace unique_op_gpu;

    // Create a fancy input iterator to indicate segment boundaries.
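    // segment_indicator_iter[i] is 1 at the start of each new segment of
    // equal values in sorted_input (and 0 at i == 0), so its inclusive prefix
    // sum below assigns a zero-based unique ID to every sorted element.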
    gpuprim::CountingInputIterator<TIndex> counting_iter(0);
    gpuprim::TransformInputIterator<TIndex, SegmentIndicatorFunctor<T, TIndex>,
                                    gpuprim::CountingInputIterator<TIndex>>
        segment_indicator_iter(counting_iter, {sorted_input_ptr});

    Tensor sorted_input_unique_ids;
    TIndex* sorted_input_unique_ids_ptr = nullptr;
    AllocateTemp(context, input_size, &sorted_input_unique_ids,
                 &sorted_input_unique_ids_ptr, done);
    if (!context->status().ok()) return;

    OP_REQUIRES_OK_ASYNC(
        context,
        GpuInclusivePrefixSum(context, input_size, segment_indicator_iter,
                              sorted_input_unique_ids_ptr),
        done);

    // Copy the last element of sorted_input_unique_ids back to the host to
    // obtain uniq_size.
    ScratchSpace<TIndex> last_idx_host(context, 1, /*on_host=*/true);
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(last_idx_host.mutable_data(),
                         se::DeviceMemoryBase(
                             const_cast<TIndex*>(sorted_input_unique_ids_ptr) +
                                 (input_size - 1),
                             sizeof(*last_idx_host.data())),
                         sizeof(*last_idx_host.data()))
            .ok(),
        errors::Internal("Failed to copy last_idx to host"), done);

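    // The copy above is only enqueued on the stream; the remainder of the
    // computation runs in the callback below, which the event manager fires
    // after the stream's pending work (including the copy) has completed, so
    // last_idx_host is valid by the time it is read.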
    auto async_finish_computation = [this, context, input_size, input_ptr,
                                     sorted_input_inds, sorted_input_inds_ptr,
                                     sorted_input_unique_ids,
                                     sorted_input_unique_ids_ptr, last_idx_host,
                                     has_count_output, done]() -> void {
      const GPUDevice& device = context->eigen_gpu_device();
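      // The last inclusive-prefix-sum value is the zero-based ID of the final
      // unique element, so the number of unique values is that value plus one.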
      int64 uniq_size = (*last_idx_host.data()) + 1;

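      // The callback runs on an event-manager thread, so make the GPU context
      // active on this thread before allocating or launching anything else.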
      se::gpu::ScopedActivateExecutorContext scoped_activation{
          context->op_device_context()->stream()->parent()};

      Tensor unique_input_inds;
      TIndex* unique_input_inds_ptr = nullptr;
      AllocateTemp(context, uniq_size, &unique_input_inds,
                   &unique_input_inds_ptr, done);
      if (!context->status().ok()) return;

      Tensor segment_ends;
      TIndex* segment_ends_ptr = nullptr;
      if (has_count_output) {
        AllocateTemp(context, uniq_size, &segment_ends, &segment_ends_ptr,
                     done);
        if (!context->status().ok()) return;
      }

      OP_REQUIRES_OK_ASYNC(
          context,
          ExtractFirstOccurrenceIndices(
              device, input_size, uniq_size, sorted_input_inds_ptr,
              sorted_input_unique_ids_ptr, unique_input_inds_ptr,
              segment_ends_ptr),
          done);

      Tensor sorted_unique_input_inds;
      TIndex* sorted_unique_input_inds_ptr = nullptr;
      AllocateTemp(context, uniq_size, &sorted_unique_input_inds,
                   &sorted_unique_input_inds_ptr, done);
      if (!context->status().ok()) return;

      Tensor sorted_unique_perm;
      TIndex* sorted_unique_perm_ptr = nullptr;
      AllocateTemp(context, uniq_size, &sorted_unique_perm,
                   &sorted_unique_perm_ptr, done);
      if (!context->status().ok()) return;

      // Sort by input index so that output is in order of appearance.
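      // The keys here are input indices in [0, input_size), so only
      // Log2Ceiling(input_size) bits need to be sorted.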
      OP_REQUIRES_OK_ASYNC(
          context,
          GpuRadixSort(context, uniq_size,
                       /*keys_in=*/unique_input_inds_ptr,
                       /*keys_out=*/sorted_unique_input_inds_ptr,
                       /*indices_in=*/static_cast<const TIndex*>(nullptr),
                       /*indices_out=*/sorted_unique_perm_ptr,
                       /*num_bits=*/Log2Ceiling(input_size)),
          done);

      // Free temporary tensor that is no longer needed.
      unique_input_inds = Tensor();
      unique_input_inds_ptr = nullptr;

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context,
          context->allocate_output(0, TensorShape({uniq_size}), &output), done);
      T* output_ptr = output->flat<T>().data();

      Tensor inv_sorted_unique_perm;
      TIndex* inv_sorted_unique_perm_ptr = nullptr;
      AllocateTemp(context, uniq_size, &inv_sorted_unique_perm,
                   &inv_sorted_unique_perm_ptr, done);
      if (!context->status().ok()) return;

      TIndex* count_ptr = nullptr;
      if (has_count_output) {
        Tensor* count = nullptr;
        OP_REQUIRES_OK_ASYNC(
            context,
            context->allocate_output(2, TensorShape({uniq_size}), &count),
            done);
        count_ptr = count->flat<TIndex>().data();
      }

      // Compute output and counts (if necessary).
      OP_REQUIRES_OK_ASYNC(
          context,
          GatherOutputsAndInvertPermutation(
              device, uniq_size, input_ptr, sorted_unique_input_inds_ptr,
              sorted_unique_perm_ptr, segment_ends_ptr, output_ptr,
              inv_sorted_unique_perm_ptr, count_ptr),
          done);

      // Free temporary tensors that are no longer needed.
      sorted_unique_perm = Tensor();
      sorted_unique_perm_ptr = nullptr;
      sorted_unique_input_inds = Tensor();
      sorted_unique_input_inds_ptr = nullptr;
      segment_ends = Tensor();
      segment_ends_ptr = nullptr;

      Tensor* idx = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(1, TensorShape({input_size}), &idx),
          done);
      TIndex* idx_ptr = idx->flat<TIndex>().data();

      // Compute indices output.
      OP_REQUIRES_OK_ASYNC(
          context,
          LookupAndScatterUniqueIds(device, input_size, sorted_input_inds_ptr,
                                    sorted_input_unique_ids_ptr,
                                    inv_sorted_unique_perm_ptr, idx_ptr),
          done);

      done();
    };

    context->device()
        ->tensorflow_accelerator_device_info()
        ->event_mgr->ThenExecute(stream, async_finish_computation);
  }
};

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_