/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
#define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <cmath>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/gpu_prim_helpers.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace impl {

enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };

template <typename T>
struct Entry {
  int index;
  T value;

  // Test-only.
  static bool greater(const Entry<T>& a, const Entry<T>& b) {
    if (a.value == b.value) {
      return a.index < b.index;
    }
    return a.value > b.value;
  }
};

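// The heap below is parameterized over one of three accessors that describe
// how Entry<T> elements are laid out:
//  - LinearData: entries stored contiguously in an array.
//  - IndirectLinearData: entries whose `index` field points back into a
//    separate `backing_data` array (used when merging per-shard results).
//  - StridedData: entries interleaved across threads in shared memory, so
//    thread t owns slots t, t + blockDim.x, t + 2 * blockDim.x, ...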
template <typename T>
struct LinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const { return data[i].index; }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
};

template <typename T>
struct IndirectLinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const {
    return backing_data[data[i].index].index;
  }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
  Entry* const backing_data;
};

template <typename T>
struct StridedData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const {
    return data[index * blockDim.x + threadIdx.x];
  }

  __device__ int get_index(int i) const { return (*this)[i].index; }
  __device__ T get_value(int i) const { return (*this)[i].value; }

  Entry* const data;
};

// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
struct IndexedHeap {
  typedef typename Data<T>::Entry Entry;
  const Data<T> data;
  __device__ IndexedHeap(const Data<T>& d) : data(d) {}

  __device__ bool is_above(int left, int right) {
    T left_value = data.get_value(left);
    T right_value = data.get_value(right);
    if (left_value == right_value) {
      if (preferIndices == PreferIndices::kLower) {
        return data.get_index(left) < data.get_index(right);
      } else {
        return data.get_index(left) > data.get_index(right);
      }
    }
    if (heapType == HeapType::kMinHeap) {
      return left_value < right_value;
    } else {
      return left_value > right_value;
    }
  }

  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }

  __device__ void push_up(int i) {
    int child = i;
    int parent;
    for (; child > 0; child = parent) {
      parent = (child - 1) / 2;
      if (!is_above(child, parent)) {
        // Heap property satisfied.
        break;
      }
      swap(child, parent);
    }
  }

  __device__ void swap(int a, int b) {
    auto tmp = data[b];
    data[b] = data[a];
    data[a] = tmp;
  }

  __device__ void push_root_down(int k) { push_down(0, k); }

  // MAX-HEAPIFY in Cormen.
  __device__ void push_down(int node, int k) {
    while (true) {
      const int left = 2 * node + 1;
      const int right = left + 1;
      int smallest = node;
      if (left < k && is_above(left, smallest)) {
        smallest = left;
      }
      if (right < k && is_above(right, smallest)) {
        smallest = right;
      }
      if (smallest == node) {
        break;
      }
      swap(smallest, node);
      node = smallest;
    }
  }

  // BUILD-MAX-HEAP in Cormen.
  __device__ void build(int k) {
    for (int node = (k - 1) / 2; node >= 0; node--) {
      push_down(node, k);
    }
  }

  // HEAP-EXTRACT-MAX in Cormen.
  __device__ void remove_root(int k) {
    data[0] = data[k - 1];
    push_root_down(k - 1);
  }

  // In-place HEAPSORT in Cormen.
  // This method destroys the heap property.
  __device__ void sort(int k) {
    for (int slot = k - 1; slot > 0; slot--) {
      // This is like remove_root, but the removed element is placed at the
      // end of the array.
      swap(slot, 0);
      // The heap is now one element smaller.
      push_root_down(/*k=*/slot);
    }
  }

  __device__ void replace_root(const Entry& entry, int k) {
    data[0] = entry;
    push_root_down(k);
  }

  __device__ const Entry& root() { return data[0]; }
};

template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
    typename Data<T>::Entry* data) {
  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}

// heapTopK walks over [input, input+length) with a stride of `step_size`,
// starting at `start_index`. It builds a top-`k` heap that is stored in
// `heap_entries`, using `Data` to access elements in `heap_entries`. If
// sorted=true, the elements will be sorted at the end.
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
                         Entry<T>* __restrict__ heap_entries,
                         bool sorted = false, int start_index = 0,
                         int step_size = 1) {
  assert(k <= length);

  auto heap =
      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
          heap_entries);

  int heap_end_index = start_index + k * step_size;
  if (heap_end_index > length) {
    heap_end_index = length;
  }
  // Initialize the min-heap.
  for (int index = start_index, slot = 0; index < heap_end_index;
       index += step_size, slot++) {
    heap.assign(slot, {index, input[index]});
  }

  heap.build(k);

  // Now iterate over the remaining items.
  // If an item is smaller than the min element, it is not amongst the top k.
  // Otherwise, replace the min element with it and push upwards.
  for (int index = heap_end_index; index < length; index += step_size) {
    // Ties are broken in favor of lower indices; later elements always have
    // higher indices, so on equality they can simply be discarded.
    if (input[index] > heap.root().value) {
      // This element should replace the min.
      heap.replace_root({index, input[index]}, k);
    }
  }

  // Sort if requested.
  if (sorted) {
    heap.sort(k);
  }
}

// mergeShards performs a top-k merge on `num_shards` many sorted streams that
// are sorted and stored in `entries` in a strided way:
//   |s_1 1st|s_2 1st|...|s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to `top_k_indices`.
// `top_k_heap` is used as temporary storage for the merge heap.
template <typename T>
__device__ void mergeShards(int num_shards, int k,
                            Entry<T>* __restrict__ entries,
                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
                            int* top_k_indices) {
  // If k < num_shards, we can use a min-heap with k elements to get the top k
  // of the sorted blocks.
  // If k > num_shards, we can initialize a min-heap with the top element from
  // each sorted block.
  const int heap_size = k < num_shards ? k : num_shards;

  // Min-heap part.
  {
    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Initialize the heap as a min-heap.
    for (int slot = 0; slot < heap_size; slot++) {
      min_heap.assign(slot, {slot, entries[slot].value});
    }
    min_heap.build(heap_size);

    // Now perform top k with the remaining shards (if num_shards > heap_size).
    for (int shard = heap_size; shard < num_shards; shard++) {
      const auto entry = entries[shard];
      const auto root = min_heap.root();
      if (entry.value < root.value) {
        continue;
      }
      if (entry.value == root.value &&
          entry.index > entries[root.index].index) {
        continue;
      }
      // This element should replace the min.
      min_heap.replace_root({shard, entry.value}, heap_size);
    }
  }

  // Max-heap part.
  {
    // Turn the min-heap into a max-heap in-place.
    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Heapify into a max-heap.
    max_heap.build(heap_size);

    // Now extract the maximum k-1 times; the final element (rank == k - 1) is
    // handled specially after the loop.
    const int last_k = k - 1;
    for (int rank = 0; rank < last_k; rank++) {
      const Entry<T>& max_element = max_heap.root();
      top_k_values[rank] = max_element.value;
      int shard_index = max_element.index;
      top_k_indices[rank] = entries[shard_index].index;
      int next_shard_index = shard_index + num_shards;
      // For rank < k-1, each per-shard top-k heap still contains at least one
      // element, so we can draw a replacement.
      max_heap.replace_root({next_shard_index, entries[next_shard_index].value},
                            heap_size);
    }

    // rank == last_k.
    const Entry<T>& max_element = max_heap.root();
    top_k_values[last_k] = max_element.value;
    int shard_index = max_element.index;
    top_k_indices[last_k] = entries[shard_index].index;
  }
}

#if GOOGLE_CUDA
extern __shared__ char shared_memory[];
#endif  // GOOGLE_CUDA

template <typename T>
#if TENSORFLOW_USE_ROCM
__attribute__((amdgpu_flat_work_group_size(1, 256)))
#endif  // TENSORFLOW_USE_ROCM
__global__ void
TopKKernel(const T* __restrict__ input, int length, int k, bool sorted,
           T* __restrict__ output, int* __restrict__ indices) {
#if TENSORFLOW_USE_ROCM
  HIP_DYNAMIC_SHARED(char, shared_memory);
#endif  // TENSORFLOW_USE_ROCM

  const int batch_index = blockIdx.x;
  const T* batch_input = input + batch_index * length;

  const int thread_index = threadIdx.x;
  const int thread_count = blockDim.x;

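  // Shared-memory layout: blockDim.x per-thread top-k heaps of k entries each,
  // interleaved across threads (see StridedData), followed by k more entries
  // of scratch space for the final merge (`top_k_heap` below).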
  Entry<T>* shared_entries = (Entry<T>*)shared_memory;

  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
                           thread_index, thread_count);

  __syncthreads();
  if (thread_index == 0) {
    const int offset = batch_index * k;
    auto batch_output = output + offset;
    auto batch_indices = indices + offset;
    Entry<T>* top_k_heap = shared_entries + thread_count * k;

    // TODO(blackhc): Erich says: Performance can likely be improved
    // significantly by having the merge be done by multiple threads rather
    // than just one. ModernGPU has some nice primitives that could help with
    // this.
    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
                batch_indices);
  }
}

template <typename T>
cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
                           const T* input, int batch_size, int length, int k,
                           bool sorted, T* output, int* indices) {
  // This code assumes that k is small enough that the computation
  // fits inside shared memory (hard coded to 48KB). In practice this
  // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
  // The constraint is:
  //   k <= shared_memory_size / (2 * (sizeof(int) + sizeof(T))).

  // Use as many shards as possible.
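  // For example (illustrative numbers): with T=float and k=64, each shard's
  // heap needs 64 * 8 = 512 bytes, so at most 49152 / 512 - 1 = 95 shards fit
  // alongside the k-entry merge heap, before the caps below are applied.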
  if (num_shards <= 0) {
    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
    const auto heap_size = k * sizeof(Entry<T>);
    // shared_memory_size = (num_shards + 1) * heap_size <=>
    num_shards = shared_memory_size / heap_size - 1;
    if (num_shards <= 0) {
      num_shards = 1;
    }
    auto shard_size = length / num_shards;
    auto min_shard_size = 2 * k;
    if (shard_size < min_shard_size) {
      num_shards = length / min_shard_size;
    }
    if (num_shards <= 0) {
      num_shards = 1;
#if GOOGLE_CUDA
    } else if (num_shards > 1024) {
      num_shards = 1024;
    }
#elif TENSORFLOW_USE_ROCM
    // ROCm can't execute with 1024 threads per block and requires an explicit
    // amdgpu_flat_work_group_size attribute for more than 256.
    } else if (num_shards > 256) {
      num_shards = 256;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  }
  // We are limited by the amount of shared memory we have per block.
  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);

  TF_CHECK_OK(GpuLaunchKernel(TopKKernel<T>, batch_size, num_shards,
                              shared_memory_size, stream, input, length, k,
                              sorted, output, indices));
  return cudaGetLastError();
}

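// Maps a segment (row) index to its begin offset in the flattened
// [num_rows, num_cols] input, for use as the begin/end offsets of the
// segmented radix sort below.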
struct SegmentOffsetCreator {
  EIGEN_DEVICE_FUNC
  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

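// Maps a flattened element index to its column index (ix[0] % num_cols); used
// to generate the initial index values that are carried along by the sort.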
struct ColumnIndexCreator {
  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};

template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                        int num_cols, int k,
                        typename TTypes<T, 2>::Tensor values,
                        TTypes<int, 2>::Tensor indices) {
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  const auto& cu_stream = GetGpuStream(ctx);
  size_t temp_storage_bytes = -1;

  // TODO(ebrevdo): Once gpuprim supports iterators for ValueT, replace this
  // tensor with an iterator that directly returns the correct value.
  Tensor input_indices;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
  auto input_indices_t = To32Bit(input_indices.flat<int32>());
  input_indices_t.device(d) =
      input_indices_t.generate(ColumnIndexCreator(num_cols));

  gpuprim::CountingInputIterator<int> counting_iter(0);
  gpuprim::TransformInputIterator<int, SegmentOffsetCreator,
                                  gpuprim::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));

  Tensor temp_values;
  Tensor temp_indices;
  T* sorted_values_ptr;
  int* sorted_indices_ptr;
  if (k == num_cols) {
    // Doing a full sort; no intermediate values needed.
    sorted_values_ptr = values.data();
    sorted_indices_ptr = indices.data();
  } else {
    // Need to create intermediate values for sorting.
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({num_rows, num_cols}),
                                          &temp_values));
    sorted_indices_ptr = temp_indices.flat<int32>().data();
    sorted_values_ptr = temp_values.flat<T>().data();
  }

  bool ran_nonsegmented_version = false;
  if (num_rows == 1) {
#if GOOGLE_CUDA
    constexpr bool is_supported = true;
#else
    // GpuRadixSortDescending is not supported on ROCm for fp16.
    constexpr bool is_supported = !std::is_same<T, Eigen::half>::value;
#endif
    if constexpr (is_supported) {
      // Note: DeviceSegmentedRadixSort is very slow when num_segments=1
      // because it only uses 1 SM per segment. Calling the un-segmented
      // version is much faster in this case.
      TF_RETURN_IF_ERROR(
          GpuRadixSortDescending(ctx, num_cols, /*keys_in=*/input,
                                 /*keys_out=*/sorted_values_ptr,
                                 /*indices_in=*/input_indices_t.data(),
                                 /*indices_out=*/sorted_indices_ptr,
                                 /*num_bits=*/sizeof(T) * 8));
      ran_nonsegmented_version = true;
    }
  }
  if (!ran_nonsegmented_version) {
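    // gpuprim follows the usual CUB two-pass pattern: the first call, with a
    // null d_temp_storage, only computes temp_storage_bytes; the second call
    // performs the actual segmented sort.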
    auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
        /* d_temp_storage */ nullptr,
        /* temp_storage_bytes */ temp_storage_bytes,
        /* d_keys_in */ input,
        /* d_keys_out */ sorted_values_ptr,
        /* d_values_in */ input_indices_t.data(),
        /* d_values_out */ sorted_indices_ptr,
        /* num_items */ num_cols * num_rows,
        /* num_segments */ num_rows,
        /* d_begin_offsets */ segment_offsets_t,
        /* d_end_offsets */ segment_offsets_t + 1,
        /* begin_bit */ 0,
        /* end_bit */ sizeof(T) * 8,
        /* stream */ cu_stream);
    if (err != cudaSuccess) {
      return errors::Internal(
          "TopKOp: Could not launch "
          "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
          "temp_storage_bytes, status: ",
          cudaGetErrorString(err));
    }
    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
        &temp_storage));
    err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
        /* d_temp_storage */ temp_storage.flat<int8>().data(),
        /* temp_storage_bytes */ temp_storage_bytes,
        /* d_keys_in */ input,
        /* d_keys_out */ sorted_values_ptr,
        /* d_values_in */ input_indices_t.data(),
        /* d_values_out */ sorted_indices_ptr,
        /* num_items */ num_cols * num_rows,
        /* num_segments */ num_rows,
        /* d_begin_offsets */ segment_offsets_t,
        /* d_end_offsets */ segment_offsets_t + 1,
        /* begin_bit */ 0,
        /* end_bit */ sizeof(T) * 8,
        /* stream */ cu_stream);
    if (err != cudaSuccess) {
      return errors::Internal(
          "TopKOp: Could not launch "
          "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort "
          "input, temp_storage_bytes: ",
          temp_storage_bytes, ", status: ", cudaGetErrorString(err));
    }
  }
  if (k < num_cols) {
    // Need to copy subsets of temp_indices and temp_values to
    // indices and values.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    To32Bit(indices).device(d) =
        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
    To32Bit(values).device(d) =
        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
  }
  return Status::OK();
}

}  // end namespace impl

namespace functor {

template <typename T>
struct TopKFunctor<GPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // For small k, use the heap implementation. For larger k, use
    // the in-place gpuprim sort. For k == num_cols, always use the
    // in-place gpuprim sort. The thresholds for n and k were determined
    // empirically.
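    // For example, num_cols = 10000 with k = 8 takes the heap kernel, while
    // k >= 100 (or k == num_cols, or num_cols <= 1000) takes the sort path.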
    if (num_cols <= 1000 || k == num_cols || k >= 100) {
      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                    k, values, indices);
    } else {
      const auto& cu_stream = GetGpuStream(context);
      auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                        input.data(), num_rows, num_cols, k,
                                        sorted, values.data(), indices.data());
      if (err != cudaSuccess) {
        return errors::Internal(
            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
      } else {
        return Status::OK();
      }
    }
  }
};

}  // end namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_