/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
#define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <cmath>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/gpu_prim_helpers.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace impl {

enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };

template <typename T>
struct Entry {
  int index;
  T value;

  // Test-only.
  static bool greater(const Entry<T>& a, const Entry<T>& b) {
    if (a.value == b.value) {
      return a.index < b.index;
    }
    return a.value > b.value;
  }
};

template <typename T>
struct LinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const { return data[i].index; }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
};

template <typename T>
struct IndirectLinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const {
    return backing_data[data[i].index].index;
  }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
  Entry* const backing_data;
};

template <typename T>
struct StridedData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const {
    return data[index * blockDim.x + threadIdx.x];
  }

  __device__ int get_index(int i) const { return (*this)[i].index; }
  __device__ T get_value(int i) const { return (*this)[i].value; }

  Entry* const data;
};
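
// Note on the three accessors above: LinearData stores one heap contiguously;
// IndirectLinearData stores a heap whose entries index into a separate backing
// Entry array; StridedData interleaves the heaps of all threads in a block, so
// heap slot i of thread t lives at data[i * blockDim.x + t].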

// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
struct IndexedHeap {
  typedef typename Data<T>::Entry Entry;
  const Data<T> data;
  __device__ IndexedHeap(const Data<T>& d) : data(d) {}

  __device__ bool is_above(int left, int right) {
    T left_value = data.get_value(left);
    T right_value = data.get_value(right);
    if (left_value == right_value) {
      if (preferIndices == PreferIndices::kLower) {
        return data.get_index(left) < data.get_index(right);
      } else {
        return data.get_index(left) > data.get_index(right);
      }
    }
    if (heapType == HeapType::kMinHeap) {
      return left_value < right_value;
    } else {
      return left_value > right_value;
    }
  }

  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }

  __device__ void push_up(int i) {
    int child = i;
    int parent;
    for (; child > 0; child = parent) {
      parent = (child - 1) / 2;
      if (!is_above(child, parent)) {
        // Heap property satisfied.
        break;
      }
      swap(child, parent);
    }
  }

  __device__ void swap(int a, int b) {
    auto tmp = data[b];
    data[b] = data[a];
    data[a] = tmp;
  }

  __device__ void push_root_down(int k) { push_down(0, k); }

  // MAX-HEAPIFY in Cormen
  __device__ void push_down(int node, int k) {
    while (true) {
      const int left = 2 * node + 1;
      const int right = left + 1;
      int smallest = node;
      if (left < k && is_above(left, smallest)) {
        smallest = left;
      }
      if (right < k && is_above(right, smallest)) {
        smallest = right;
      }
      if (smallest == node) {
        break;
      }
      swap(smallest, node);
      node = smallest;
    }
  }

  // BUILD-MAX-HEAP in Cormen
  __device__ void build(int k) {
    for (int node = (k - 1) / 2; node >= 0; node--) {
      push_down(node, k);
    }
  }

  // HEAP-EXTRACT-MAX in Cormen
  __device__ void remove_root(int k) {
    data[0] = data[k - 1];
    push_root_down(k - 1);
  }

  // in-place HEAPSORT in Cormen
  // This method destroys the heap property.
  __device__ void sort(int k) {
    for (int slot = k - 1; slot > 0; slot--) {
      // This is like remove_root but we insert the element at the end.
      swap(slot, 0);
      // The heap is now one element smaller.
      push_root_down(/*k=*/slot);
    }
  }

  __device__ void replace_root(const Entry& entry, int k) {
    data[0] = entry;
    push_root_down(k);
  }

  __device__ const Entry& root() { return data[0]; }
};

template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
    typename Data<T>::Entry* data) {
  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}
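
// Illustrative device-side sketch (not used elsewhere in this file; kK is a
// hypothetical compile-time constant): building a min-heap over k contiguous
// entries and reading back its smallest element.
//
//   Entry<float> entries[kK];  // filled with {index, value} pairs
//   auto heap = make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher,
//                                 LinearData, float>(entries);
//   heap.build(kK);  // heapify in O(kK)
//   const Entry<float>& smallest = heap.root();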

// heapTopK walks over [input, input+length) with `step_size` stride starting
// at `start_index`. It builds a top-`k` heap that is stored in `heap_entries`
// using `Data` to access elements in `heap_entries`. If sorted=true, the
// elements will be sorted at the end.
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
                         Entry<T>* __restrict__ heap_entries,
                         bool sorted = false, int start_index = 0,
                         int step_size = 1) {
  assert(k <= length);

  auto heap =
      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
          heap_entries);

  int heap_end_index = start_index + k * step_size;
  if (heap_end_index > length) {
    heap_end_index = length;
  }
  // Initialize the min-heap.
  for (int index = start_index, slot = 0; index < heap_end_index;
       index += step_size, slot++) {
    heap.assign(slot, {index, input[index]});
  }

  heap.build(k);

  // Now iterate over the remaining items.
  // If an item is smaller than the min element, it is not amongst the top k.
  // Otherwise, replace the min element with it and push upwards.
  for (int index = heap_end_index; index < length; index += step_size) {
    // We prefer elements with lower indices. This is given here.
    // Later elements automatically have higher indices, so can be discarded.
    if (input[index] > heap.root().value) {
      // This element should replace the min.
      heap.replace_root({index, input[index]}, k);
    }
  }

  // Sort if wanted.
  if (sorted) {
    heap.sort(k);
  }
}
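
// Illustrative usage (a sketch of what TopKKernel below does): with
// Data = StridedData, start_index = threadIdx.x and step_size = blockDim.x,
// each thread of a block computes the top-k of its own strided slice of
// `input`, and the per-thread heaps interleave in shared memory.
//
//   extern __shared__ char smem[];
//   Entry<float>* heap_entries = reinterpret_cast<Entry<float>*>(smem);
//   heapTopK<float, StridedData>(input, length, k, heap_entries,
//                                /*sorted=*/true, threadIdx.x, blockDim.x);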

// mergeShards performs a top-k merge on `num_shards` sorted streams that are
// stored in `entries` in a strided way:
// |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to `top_k_indices`.
// `top_k_heap` is used as temporary storage for the merge heap.
template <typename T>
__device__ void mergeShards(int num_shards, int k,
                            Entry<T>* __restrict__ entries,
                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
                            int* top_k_indices) {
  // If k < num_shards, we can use a min-heap with k elements to get the top k
  // of the sorted blocks.
  // If k > num_shards, we can initialize a min-heap with the top element from
  // each sorted block.
  const int heap_size = k < num_shards ? k : num_shards;

  // Min-heap part.
  {
    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Initialize the heap as a min-heap.
    for (int slot = 0; slot < heap_size; slot++) {
      min_heap.assign(slot, {slot, entries[slot].value});
    }
    min_heap.build(heap_size);

    // Now perform top k with the remaining shards (if num_shards > heap_size).
    for (int shard = heap_size; shard < num_shards; shard++) {
      const auto entry = entries[shard];
      const auto root = min_heap.root();
      if (entry.value < root.value) {
        continue;
      }
      if (entry.value == root.value &&
          entry.index > entries[root.index].index) {
        continue;
      }
      // This element should replace the min.
      min_heap.replace_root({shard, entry.value}, heap_size);
    }
  }

  // Max-heap part.
  {
    // Turn the min-heap into a max-heap in-place.
    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Heapify into a max heap.
    max_heap.build(heap_size);

    // Now extract the maximum k-1 times.
    // The last (k-th) extraction is handled separately below.
    const int last_k = k - 1;
    for (int rank = 0; rank < last_k; rank++) {
      const Entry<T>& max_element = max_heap.root();
      top_k_values[rank] = max_element.value;
      int shard_index = max_element.index;
      top_k_indices[rank] = entries[shard_index].index;
      int next_shard_index = shard_index + num_shards;
      // For rank < k-1, each top k heap still contains at least 1 element,
      // so we can draw a replacement.
      max_heap.replace_root({next_shard_index, entries[next_shard_index].value},
                            heap_size);
    }

    // rank == last_k.
    const Entry<T>& max_element = max_heap.root();
    top_k_values[last_k] = max_element.value;
    int shard_index = max_element.index;
    top_k_indices[last_k] = entries[shard_index].index;
  }
}
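
// Layout note: with the strided layout documented above mergeShards, the j-th
// best element of shard s lives at entries[j * num_shards + s]; this is why
// replace_root advances by `num_shards` to fetch the next element of the same
// shard.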

#if GOOGLE_CUDA
extern __shared__ char shared_memory[];
#endif  // GOOGLE_CUDA

template <typename T>
#if TENSORFLOW_USE_ROCM
__attribute__((amdgpu_flat_work_group_size(1, 256)))
#endif  // TENSORFLOW_USE_ROCM
__global__ void
TopKKernel(const T* __restrict__ input, int length, int k, bool sorted,
           T* __restrict__ output, int* __restrict__ indices) {
#if TENSORFLOW_USE_ROCM
  HIP_DYNAMIC_SHARED(char, shared_memory);
#endif  // TENSORFLOW_USE_ROCM

  const int batch_index = blockIdx.x;
  const T* batch_input = input + batch_index * length;

  const int thread_index = threadIdx.x;
  const int thread_count = blockDim.x;

  Entry<T>* shared_entries = (Entry<T>*)shared_memory;

  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
                           thread_index, thread_count);

  __syncthreads();
  if (thread_index == 0) {
    const int offset = batch_index * k;
    auto batch_output = output + offset;
    auto batch_indices = indices + offset;
    Entry<T>* top_k_heap = shared_entries + thread_count * k;

    // TODO(blackhc): Erich says: Performance can likely be improved
    // significantly by having the merge be done by multiple threads rather than
    // just one.  ModernGPU has some nice primitives that could help with this.
    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
                batch_indices);
  }
}

template <typename T>
cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
                           const T* input, int batch_size, int length, int k,
                           bool sorted, T* output, int* indices) {
  // This code assumes that k is small enough that the computation
  // fits inside shared memory (hard coded to 48KB).  In practice this
  // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
  // The calculation is:
  //   k <= shared_memory_size / (2 * (sizeof(int) + sizeof(T))).
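  // Worked example (using the 48 KB budget hard coded below): for T = float,
  // sizeof(int) + sizeof(float) == 8 bytes, so k can be at most
  // 48 * 1024 / (2 * 8) = 3072 before the heaps no longer fit.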

  // Use as many shards as possible.
  if (num_shards <= 0) {
    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
    const auto heap_size = k * sizeof(Entry<T>);
    // shared_memory_size = (num_shards + 1) * heap_size <=>
    num_shards = shared_memory_size / heap_size - 1;
    if (num_shards <= 0) {
      num_shards = 1;
    }
    auto shard_size = length / num_shards;
    auto min_shard_size = 2 * k;
    if (shard_size < min_shard_size) {
      num_shards = length / min_shard_size;
    }
    if (num_shards <= 0) {
      num_shards = 1;
#if GOOGLE_CUDA
    } else if (num_shards > 1024) {
      num_shards = 1024;
    }
#elif TENSORFLOW_USE_ROCM
    // ROCm can't execute with 1024 threads per block and requires an explicit
    // amdgpu_flat_work_group_size attribute to go above 256.
    } else if (num_shards > 256) {
      num_shards = 256;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  }
  // We are limited by the amount of shared memory we have per block.
  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);

  TF_CHECK_OK(GpuLaunchKernel(TopKKernel<T>, batch_size, num_shards,
                              shared_memory_size, stream, input, length, k,
                              sorted, output, indices));
  return cudaGetLastError();
}
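
// Illustrative host-side sketch (hypothetical device pointers and sizes, not a
// drop-in snippet): launching the heap kernel on a batch of 32 rows of 10000
// elements each, keeping the top 64 per row.
//
//   cudaError err = LaunchTopKKernel(stream, /*num_shards=*/0, d_input,
//                                    /*batch_size=*/32, /*length=*/10000,
//                                    /*k=*/64, /*sorted=*/true, d_values,
//                                    d_indices);
//
// Passing num_shards <= 0 lets the launcher derive the shard count from the
// 48 KB shared-memory budget above.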

struct SegmentOffsetCreator {
  EIGEN_DEVICE_FUNC
  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

struct ColumnIndexCreator {
  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};
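
// Illustrative note: with num_cols = 4, SegmentOffsetCreator maps 0, 1, 2, ...
// to the segment boundaries 0, 4, 8, ..., and ColumnIndexCreator maps a flat
// element index to its column (e.g. 6 -> 2). Together they let the segmented
// radix sort below treat each row as one segment and sort (value, column)
// pairs within it.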

template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                        int num_cols, int k,
                        typename TTypes<T, 2>::Tensor values,
                        TTypes<int, 2>::Tensor indices) {
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  const auto& cu_stream = GetGpuStream(ctx);
  size_t temp_storage_bytes = -1;

  // TODO(ebrevdo): Once gpuprim supports iterators for ValueT replace that
  // tensor with an iterator that directly returns the correct value.
  Tensor input_indices;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
  auto input_indices_t = To32Bit(input_indices.flat<int32>());
  input_indices_t.device(d) =
      input_indices_t.generate(ColumnIndexCreator(num_cols));

  gpuprim::CountingInputIterator<int> counting_iter(0);
  gpuprim::TransformInputIterator<int, SegmentOffsetCreator,
                                  gpuprim::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));

  Tensor temp_values;
  Tensor temp_indices;
  T* sorted_values_ptr;
  int* sorted_indices_ptr;
  if (k == num_cols) {
    // Doing a full sort, no intermediate values needed.
    sorted_values_ptr = values.data();
    sorted_indices_ptr = indices.data();
  } else {
    // Need to create intermediate values for sorting.
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({num_rows, num_cols}),
                                          &temp_values));
    sorted_indices_ptr = temp_indices.flat<int32>().data();
    sorted_values_ptr = temp_values.flat<T>().data();
  }

  bool ran_nonsegmented_version = false;
  if (num_rows == 1) {
#if GOOGLE_CUDA
    constexpr bool is_supported = true;
#else
    // GpuRadixSortDescending is not supported on ROCm for fp16.
    constexpr bool is_supported = !std::is_same<T, Eigen::half>::value;
#endif
    if constexpr (is_supported) {
      // Note: DeviceSegmentedRadixSort is very slow when num_segments=1 because
      // it only uses 1 SM per segment. Calling the un-segmented version is much
      // faster in this case.
      TF_RETURN_IF_ERROR(
          GpuRadixSortDescending(ctx, num_cols, /*keys_in=*/input,
                                 /*keys_out=*/sorted_values_ptr,
                                 /*indices_in=*/input_indices_t.data(),
                                 /*indices_out=*/sorted_indices_ptr,
                                 /*num_bits=*/sizeof(T) * 8));
      ran_nonsegmented_version = true;
    }
  }
  if (!ran_nonsegmented_version) {
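    // Standard two-pass gpuprim/CUB pattern: the first call below passes a
    // null d_temp_storage and only computes temp_storage_bytes; the second
    // call performs the actual segmented descending sort.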
    auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
        /* d_temp_storage */ nullptr,
        /* temp_storage_bytes */ temp_storage_bytes,
        /* d_keys_in */ input,
        /* d_keys_out */ sorted_values_ptr,
        /* d_values_in */ input_indices_t.data(),
        /* d_values_out */ sorted_indices_ptr,
        /* num_items */ num_cols * num_rows,
        /* num_segments */ num_rows,
        /* d_begin_offsets */ segment_offsets_t,
        /* d_end_offsets */ segment_offsets_t + 1,
        /* begin_bit */ 0,
        /* end_bit */ sizeof(T) * 8,
        /* stream */ cu_stream);
    if (err != cudaSuccess) {
      return errors::Internal(
          "TopKOp: Could not launch "
          "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
          "temp_storage_bytes, status: ",
          cudaGetErrorString(err));
    }
    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
        &temp_storage));
    err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
        /* d_temp_storage */ temp_storage.flat<int8>().data(),
        /* temp_storage_bytes */ temp_storage_bytes,
        /* d_keys_in */ input,
        /* d_keys_out */ sorted_values_ptr,
        /* d_values_in */ input_indices_t.data(),
        /* d_values_out */ sorted_indices_ptr,
        /* num_items */ num_cols * num_rows,
        /* num_segments */ num_rows,
        /* d_begin_offsets */ segment_offsets_t,
        /* d_end_offsets */ segment_offsets_t + 1,
        /* begin_bit */ 0,
        /* end_bit */ sizeof(T) * 8,
        /* stream */ cu_stream);
    if (err != cudaSuccess) {
      return errors::Internal(
          "TopKOp: Could not launch "
          "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort "
          "input, "
          "temp_storage_bytes: ",
          temp_storage_bytes, ", status: ", cudaGetErrorString(err));
    }
  }
  if (k < num_cols) {
    // Need to copy subsets of sorted_indices and sorted_outputs to
    // indices and outputs.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    To32Bit(indices).device(d) =
        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
    To32Bit(values).device(d) =
        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
  }
  return Status::OK();
}

}  // end namespace impl

namespace functor {

template <typename T>
struct TopKFunctor<GPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // For small k, use the heap implementation.  For larger k, use
    // the in-place gpuprim sort.  For k == num_cols, always use the
    // in-place gpuprim sort.  The thresholds for n and k were determined
    // empirically.
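    // For example, num_cols = 10000 with k = 8 takes the heap kernel path,
    // while num_cols = 10000 with k = 500 (or any num_cols <= 1000) falls
    // back to the gpuprim sort below.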
    if (num_cols <= 1000 || k == num_cols || k >= 100) {
      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                    k, values, indices);
    } else {
      const auto& cu_stream = GetGpuStream(context);
      auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                        input.data(), num_rows, num_cols, k,
                                        sorted, values.data(), indices.data());
      if (err != cudaSuccess) {
        return errors::Internal(
            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
      } else {
        return Status::OK();
      }
    }
  }
};

}  // end namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_