#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/TensorTopK.h>
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/ScanUtils.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
#include <ATen/cuda/cub.cuh>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/detail/KernelUtils.h>

#include <c10/macros/Macros.h>

using namespace at::native;

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk() {
  return CUB_SUPPORTS_SCAN_BY_KEY();
}

namespace sbtopk { // single_block_topk
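
// Single-block top-k: each slice is handled by one block. The block first
// runs radixSelect to find the k-th value of the slice, then makes two passes
// over the slice: the first pass writes out every element strictly greater
// (or less, for smallest) than the k-th value, and the second pass fills the
// remaining output slots with elements equal to the k-th value.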

template <typename T>
struct AddOp {
  __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) {
    return (lhs + rhs);
  }
};

template <typename T, typename IndexType, int Dim, bool WithKthValues>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> input,
                           IndexType inputSliceSize,
                           IndexType outputSliceSize, // aka `k`
                           bool largest,

                           IndexType numInputSlices,
                           IndexType inputWithinSliceStride,

                           at::cuda::detail::TensorInfo<T, IndexType> topK,
                           IndexType topKWithinSliceStride,

                           at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
                           IndexType indicesWithinSliceStride,
                           T* kthValues) {
  // Indices are limited to integer fp precision, so counts can fit in
  // int32, regardless of IndexType
#if defined(USE_ROCM)
  __shared__ int smem[64];
#else
  __shared__ int smem[32]; // one per each warp, up to warp limit
#endif
  IndexType slice = getLinearBlockId<IndexType>();
  if (slice >= numInputSlices) {
    return;
  }

  // Find the start offset for our slice
  IndexType sliceStartIndex =
    at::cuda::detail::IndexToOffset<const T, IndexType, Dim>::get(slice, input);
  IndexType topKSliceStartIndex =
    at::cuda::detail::IndexToOffset<T, IndexType, Dim>::get(slice, topK);
  IndexType indicesSliceStartIndex =
    at::cuda::detail::IndexToOffset<int64_t, IndexType, Dim>::get(slice, indices);

  const T* inputSliceStart = &input.data[sliceStartIndex];
  T* topKSliceStart = &topK.data[topKSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];

  // Find the k-th highest element in our input
  T topKValue;
  if (WithKthValues){
    topKValue = kthValues[slice];
  } else {
    topKValue = static_cast<T>(0);
    radixSelect<T, typename TopKTypeConfig<T>::RadixType, IndexType>(
      inputSliceStart, outputSliceSize, largest,
      inputSliceSize, inputWithinSliceStride,
      smem, &topKValue);
  }
  const auto topKConverted = at::native::TopKTypeConfig<T>::convert(topKValue);

  // Every value that is strictly less/greater than `pattern`
  // (depending on sort dir) in sorted int format is in the top-K.
  // The top-K value itself might not be unique.
  //
  // Since there are a variable number of elements that we see that
  // are within the top-k, we don't know at what index to write out
  // the resulting values.
  // In order to get this, we perform an exclusive prefix sum of
  // `hasTopK`. This will return the resulting index into which we
  // need to write the result, if a thread has a result.

  // All threads need to participate in the loop and the prefix sum,
  // but not necessarily in the load; hence loop bounds being rounded
  // up to a multiple of the block dim.
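
  // For example (illustrative only): if, in one iteration, the first threads
  // of the block compute hasTopK = [1, 0, 1, 1, 0, ...], the exclusive prefix
  // scan yields index = [0, 1, 1, 2, 3, ...] and carry = 3, so the three
  // qualifying threads write to writeIndexStart + 0, + 1, and + 2, and
  // writeIndexStart then advances by 3 for the next iteration.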
  IndexType numIterations = round_up(inputSliceSize, (IndexType) blockDim.x);
  IndexType writeIndexStart = 0;

  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : static_cast<T>(0);
    const auto convertedV = at::native::TopKTypeConfig<T>::convert(v);
    bool hasTopK;
    if (largest) {
      hasTopK = inRange && (convertedV > topKConverted);
    } else {
      hasTopK = inRange && (convertedV < topKConverted);
    }

    int index;
    int carry;
    at::cuda::exclusiveBinaryPrefixScan<int, true>(
        smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK) {
      int writeIndex = writeIndexStart + index;
      CUDA_KERNEL_ASSERT(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    writeIndexStart += carry;
  }

  // We need to fill in the rest with actual == top-K values.
  // The number that we need is outputSliceSize -
  // writeIndexStart. There might be more than that number available,
  // in which case we have to choose the first seen set. We do this
  // via a prefix sum to calculate indices for writing results.
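  // For example (illustrative only): with k = 5, if the first pass wrote 3
  // strictly-greater elements, then topKRemaining = 2 and the two remaining
  // output slots are filled with the first two elements (in scan order) that
  // compare equal to the k-th value.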
  CUDA_KERNEL_ASSERT(outputSliceSize >= writeIndexStart);
  IndexType topKRemaining = (outputSliceSize - writeIndexStart);

  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : static_cast<T>(0);
    const auto convertedV = at::native::TopKTypeConfig<T>::convert(v);
    bool hasTopK = inRange && (convertedV == topKConverted);

    int index;
    int carry;
    at::cuda::exclusiveBinaryPrefixScan<int, true>(
        smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK && index < topKRemaining) {
      int writeIndex = writeIndexStart + index;
      CUDA_KERNEL_ASSERT(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    if (carry >= topKRemaining) {
      break;
    }

    topKRemaining -= carry;
    writeIndexStart += carry;
  }

}

template <typename T, typename IndexType, int Dim>
void launch(
    at::cuda::detail::TensorInfo<const T, IndexType> input,
    IndexType inputSliceSize,
    IndexType outputSliceSize, // aka `k`
    bool largest,

    IndexType numInputSlices,
    IndexType inputWithinSliceStride,

    at::cuda::detail::TensorInfo<T, IndexType> topK,
    IndexType topKWithinSliceStride,

    at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
    IndexType indicesWithinSliceStride) {

    dim3 grid;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
    int warp_size = at::cuda::warp_size();
    dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
    gatherTopK<T, IndexType, Dim, /* WithKthValues= */false><<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input,
        inputSliceSize,
        outputSliceSize,
        largest,
        numInputSlices,
        inputWithinSliceStride,
        topK,
        topKWithinSliceStride,
        indices,
        indicesWithinSliceStride,
        nullptr);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace sbtopk

namespace mbtopk { // multi_block_topk

// Assumptions:
// The number of elements can be larger than UINT32_MAX, but
// the number of total blocks can not be larger than UINT32_MAX.
// So we can not have more than UINT32_MAX slices. The actual limit
// for the number of slices could be a few fold smaller than UINT32_MAX,
// because we could be using multiple blocks per slice.
// Furthermore, the size of each input slice is also assumed to be
// smaller than UINT32_MAX
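//
// Overall structure: each slice may be processed by several blocks. The
// `radixFindKthValues` kernel is run once per RADIX_BITS of the key type to
// narrow down the k-th value of every slice; the per-block digit counts from
// those passes are then turned into per-block output offsets (when CUB
// scan-by-key is available), and a final gather kernel copies the selected
// values and indices to the output.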

constexpr int BLOCK_THREADS = 256;

// Over what radix we are selecting values
constexpr int RADIX_BITS = 8;
constexpr int RADIX_DIGITS = 1 << RADIX_BITS; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_DIGITS - 1);
static_assert(RADIX_DIGITS <= BLOCK_THREADS, "radixFindKthValues kernel requires RADIX_DIGITS <= BLOCK_THREADS");
constexpr int MIN_ITEMS_PER_THREAD = 4;
constexpr int MAX_ITEMS_PER_THREAD = 64;

template <typename T, typename IndexType>
__global__ void fill(T* x, T value, IndexType size) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (IndexType i = idx; i < size; i += gridDim.x * blockDim.x) {
    x[i] = value;
  }
}

// find the kth smallest value,
// for largest topk, k_to_find = slice_size - k + 1
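// (e.g. for slice_size = 8 and k = 2 with largest = true, k_to_find = 7:
// the 7th smallest element of the slice is its 2nd largest)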
template <typename T, typename IndexType, typename Bitwise, int Dim>
C10_LAUNCH_BOUNDS_1(BLOCK_THREADS)
__global__ void radixFindKthValues(
    at::cuda::detail::TensorInfo<const T, IndexType> input,
    uint32_t slice_size,
    uint32_t* ks_to_find,  // size: num_slices

    uint32_t num_slices,
    IndexType withinSliceStride,

    int current_bit,
    int items_per_thread,
    uint32_t blocks_per_slice,
    Bitwise desiredMask,

    // outputs
    uint32_t* semaphores,  // size: num_slices
    Bitwise* desires,      // size: num_slices
    short* counts,         // size: num_slices * blocks_per_slice * radix_digits
    T* kthValues           // size: num_slices, only write when current_bit reaches 0
  ) {

  int items_per_block = items_per_thread * BLOCK_THREADS;
  int tidx = threadIdx.x;
  uint32_t block_idx = getLinearBlockId<uint32_t>();
  uint32_t slice_idx = block_idx / blocks_per_slice;
  uint32_t blk_idx_in_slice = block_idx % blocks_per_slice;
  if (slice_idx >= num_slices) {
    return;
  }

  Bitwise desired = desires[slice_idx];
  uint32_t k_to_find = ks_to_find[slice_idx];
  IndexType slice_start_index = at::cuda::detail::IndexToOffset<const T, IndexType, Dim>::get(slice_idx, input);
  const T* data = &input.data[slice_start_index];

  typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;
  static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < std::numeric_limits<short>::max(),
    "blockwise counter too large");
  union __align__(16) TempStorage {
    uint32_t digit_counters[RADIX_DIGITS];
    uint32_t digit_count_cumsum[RADIX_DIGITS]; // only used if this is the last block for this slice
    typename BlockScan::TempStorage scan_storage;
  };
  __shared__ TempStorage temp_storage;

  // fill digit_counters with zeros
  if (tidx < RADIX_DIGITS) {
    temp_storage.digit_counters[tidx] = 0;
  }
  __syncthreads();

  items_per_thread = (blk_idx_in_slice + 1 < blocks_per_slice)
      ? items_per_thread
      : at::ceil_div((int64_t)(slice_size - blk_idx_in_slice * items_per_block), (int64_t)BLOCK_THREADS);
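  // (e.g., illustrative only: for slice_size = 2500, items_per_thread = 4 and
  // BLOCK_THREADS = 256, items_per_block = 1024 and blocks_per_slice = 3; the
  // last block covers the remaining 2500 - 2048 = 452 elements, so its
  // items_per_thread becomes ceil(452 / 256) = 2)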

  // collect digit counts and store in shared memory
  for (int i = 0; i < items_per_thread; ++i) {
    // Find the start offset for this slice
    IndexType idx = blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx;
    if (idx < slice_size) {
      idx *= withinSliceStride;
      Bitwise val = TopKTypeConfig<T>::convert(doLdg(&data[idx]));
      bool has_val = ((val & desiredMask) == (desired & desiredMask));
      Bitwise digit = at::cuda::Bitfield<Bitwise>::getBitfield(val, current_bit, RADIX_BITS);
      if (has_val) {
        atomicAdd(&temp_storage.digit_counters[digit], 1);
      }
    }
  }

  __syncthreads();

  // load digit counter to register, one digit per thread
  static_assert(RADIX_DIGITS <= BLOCK_THREADS, "this kernel requires RADIX_DIGITS <= BLOCK_THREADS");
  uint32_t digit_count = 0;
  if (tidx < RADIX_DIGITS) {
    digit_count = temp_storage.digit_counters[tidx];
  }

  // We always write out counts regardless of whether blocks_per_slice == 1,
  // because they will be used to compute offsets for `gatherTopK`.
  if (tidx < RADIX_DIGITS) {
    counts[block_idx * RADIX_DIGITS + tidx] = digit_count;
  }
  // if blocks_per_slice == 1, there is no need to do cross-block reduction
  // in this case we use counts saved at registers directly
  if (blocks_per_slice > 1) {
    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // make sure all writes are finished before updating semaphores
  }

  // the last block of each slice accumulates counters from multiple blocks and updates desired and ks_to_find
  __shared__ bool s_is_last_block_done;

  if (tidx == 0) {
    if (blocks_per_slice == 1) {
      s_is_last_block_done = true;
    } else {
      uint32_t blocks_finished_old = atomicAdd(&semaphores[slice_idx], 1);
      s_is_last_block_done = (blocks_finished_old == blocks_per_slice - 1);
    }
  }

  __syncthreads();

  if (!s_is_last_block_done)
    return;

  // accumulate counters from multiple blocks
  if (tidx < RADIX_DIGITS && blocks_per_slice > 1) {
    digit_count = 0;
    for (int blk = 0; blk < blocks_per_slice; ++blk) {
      digit_count += counts[(slice_idx * blocks_per_slice + blk) * RADIX_DIGITS + tidx];
    }
  }

  // compute the block-wide inclusive prefix sum
  uint32_t digit_count_cumsum;
  BlockScan(temp_storage.scan_storage).InclusiveSum(digit_count, digit_count_cumsum);
  __syncthreads();
  // every thread also needs the prefix sum of the value to its left for comparison, so save a copy in shared mem
  if (tidx < RADIX_DIGITS) {
    temp_storage.digit_count_cumsum[tidx] = digit_count_cumsum;
  }
  __syncthreads();

  if (tidx < RADIX_DIGITS) {
    uint32_t digit_count_cumsum_left = (tidx == 0) ? 0 : temp_storage.digit_count_cumsum[tidx - 1];

    // if not the last pass: update desired and ks_to_find
    // if last pass: write out the kth value
    if (digit_count_cumsum_left < k_to_find && k_to_find <= digit_count_cumsum) {
      desired = at::cuda::Bitfield<Bitwise>::setBitfield(desired, tidx, current_bit, RADIX_BITS);
      desires[slice_idx] = desired;
      if (current_bit > 0) {
        ks_to_find[slice_idx] = k_to_find - digit_count_cumsum_left;
      } else {
        kthValues[slice_idx] = TopKTypeConfig<T>::deconvert(desired);
      }
    }
  }
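
  // For example (illustrative only): if, for this pass, the per-digit counts
  // among the still-candidate elements are [10, 20, 5, ...] (cumsum
  // [10, 30, 35, ...]) and k_to_find = 25, then digit 1 satisfies
  // 10 < 25 <= 30, so digit 1 is fixed into `desired` and, unless this was
  // the last pass, k_to_find becomes 25 - 10 = 15 for the next pass.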

  // reset semaphores for the next pass
  if (tidx == 0) {
    semaphores[slice_idx] = 0;
  }
}

#if CUB_SUPPORTS_SCAN_BY_KEY()
// Assumption: k can not be larger than UINT32_MAX
template <typename Bitwise>
C10_LAUNCH_BOUNDS_1(RADIX_DIGITS)  // one thread per digit
__global__ void computeBlockwiseWithinKCounts(
  Bitwise* desires,          // size: num_slices
  short* counts,             // size: num_slices * blocks_per_slice * radix_digits
  uint32_t blocks_per_slice,
  int current_bit,
  bool largest,
  // outputs:
  uint32_t* withinKCounts,  // size: num_slices * blocks_per_slice == num_blocks
  uint32_t num_blocks
) {
  // This kernel should be launched with the same number of blocks as the `radixFindKthValues` kernel.
  int tidx = threadIdx.x;
  uint32_t block_idx = getLinearBlockId<uint32_t>();
  uint32_t slice_idx = block_idx / blocks_per_slice;

  // The grid is computed from `getGridFromTiles`. When there are lots of
  // elements, we will use blockIdx.x, blockIdx.y, and maybe blockIdx.z.
  // When this is the case, the number of blocks that we launch can be more
  // than the number of blocks we need, so we need to check the range of
  // `block_idx`.
  if (block_idx >= num_blocks) {
    return;
  }

  Bitwise desired = doLdg(desires + slice_idx);
  Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS);

  // if largest, then only threads with tidx > desired_digit are active
  // if !largest, then only threads with tidx < desired_digit are active
  // each active thread reads the count for its corresponding digit and does a
  // warp reduction followed by a shared memory reduction to get the total count;
  // non-active threads should not load, and non-active warps should not do the reduction.
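  // For example (illustrative only): with largest = true and desired_digit = 5,
  // digits 6..255 can only belong to elements larger than the k-th value, so
  // their counts are summed into withinKCounts for this block.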
  bool warp_is_active, thread_is_active;
  int warp = tidx / C10_WARP_SIZE;
  if (largest) {
    int end_of_warp = warp * C10_WARP_SIZE + C10_WARP_SIZE - 1;
    warp_is_active = end_of_warp > desired_digit;
    thread_is_active = tidx > desired_digit;
  } else {
    int start_of_warp = warp * C10_WARP_SIZE;
    warp_is_active = start_of_warp < desired_digit;
    thread_is_active = tidx < desired_digit;
  }
  uint32_t count = 0;
  if (warp_is_active) {
    if (thread_is_active) {
      count = doLdg(counts + block_idx * RADIX_DIGITS + tidx);
    }
    for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
      count += WARP_SHFL_DOWN(count, offset);
    }
  }

  constexpr int num_warps = RADIX_DIGITS / C10_WARP_SIZE;
  __shared__ uint32_t warp_counts[num_warps];
  if (tidx % C10_WARP_SIZE == 0) {
    warp_counts[warp] = count;
  }
  __syncthreads();
  static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
    "Assuming only 1 warp is needed for final reduction");
  if (warp != 0) {
    return;
  }
  count = 0;
  if (tidx < num_warps) {
    count = warp_counts[tidx];
  }
  for (int offset = num_warps / 2; offset > 0; offset /= 2) {
    count += WARP_SHFL_DOWN(count, offset);
  }
  if (tidx == 0) {
    withinKCounts[block_idx] += count;
  }
}

// Assumption: slice_size can not be larger than UINT32_MAX
template <typename Bitwise>
__global__ void computeBlockwiseKthCounts(
  Bitwise* desires,            // size: num_slices
  short* counts,               // size: num_slices * blocks_per_slice * radix_digits
  uint32_t num_blocks,         // the number of blocks used by `radixFindKthValues` kernel
  uint32_t blocks_per_slice,
  // outputs:
  uint32_t* kthCounts          // size: num_slices * blocks_per_slice == num_blocks
) {
  CUDA_KERNEL_LOOP_TYPE(idx, num_blocks, uint32_t) {
    uint32_t slice_idx = idx / blocks_per_slice;
    Bitwise desired = doLdg(desires + slice_idx);
    Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, 0, RADIX_BITS);
    kthCounts[idx] = doLdg(counts + idx * RADIX_DIGITS + desired_digit);
  }
}

template <typename T, typename IndexType, int Dim>
C10_LAUNCH_BOUNDS_1(BLOCK_THREADS)
__global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> input,
                           IndexType inputSliceSize,
                           IndexType outputSliceSize, // aka `k`
                           bool largest,

                           uint32_t numInputSlices,
                           IndexType inputWithinSliceStride,

                           at::cuda::detail::TensorInfo<T, IndexType> topK,
                           IndexType topKWithinSliceStride,

                           at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
                           IndexType indicesWithinSliceStride,

                           uint32_t items_per_thread,
                           uint32_t blocks_per_slice,

                           T *kthValues,
                           uint32_t* withinKCounts,
                           uint32_t* kthCounts,
                           uint32_t num_blocks) {

  uint32_t items_per_block = items_per_thread * BLOCK_THREADS;
  uint32_t tidx = threadIdx.x;
  uint32_t block_idx = getLinearBlockId<uint32_t>();

  // The grid is computed from `getGridFromTiles`. When there are lots of
  // elements, we will use blockIdx.x, blockIdx.y, and maybe blockIdx.z.
  // When this is the case, the number of blocks that we launch can be more
  // than the number of blocks we need, so we need to check the range of
  // `block_idx`.
  if (block_idx >= num_blocks) {
    return;
  }

  uint32_t slice_idx = block_idx / blocks_per_slice;
  uint32_t blk_idx_in_slice = block_idx % blocks_per_slice;

  items_per_thread = (blk_idx_in_slice + 1 < blocks_per_slice)
      ? items_per_thread
      : at::ceil_div((int64_t)(inputSliceSize - blk_idx_in_slice * items_per_block), (int64_t)BLOCK_THREADS);

  // Find the start offset for our slice
  IndexType sliceStartIndex =
    at::cuda::detail::IndexToOffset<const T, IndexType, Dim>::get(slice_idx, input);
  IndexType topKSliceStartIndex =
    at::cuda::detail::IndexToOffset<T, IndexType, Dim>::get(slice_idx, topK);
  IndexType indicesSliceStartIndex =
    at::cuda::detail::IndexToOffset<int64_t, IndexType, Dim>::get(slice_idx, indices);

  const T* inputSliceStart = &input.data[sliceStartIndex];
  T* topKSliceStart = &topK.data[topKSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];

  // Find the k-th highest element in our input
  T kthValue = kthValues[slice_idx];
  const auto kthValueConverted = at::native::TopKTypeConfig<T>::convert(kthValue);

  // Find the start index in output tensor of this block
  uint32_t startWithinK = 0;
  if (blk_idx_in_slice > 0) {
    startWithinK = withinKCounts[block_idx - 1];
  }
  uint32_t startKth = withinKCounts[slice_idx * blocks_per_slice + blocks_per_slice - 1];
  if (blk_idx_in_slice > 0) {
    startKth += kthCounts[block_idx - 1];
  }
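
  // After the inclusive scans by slice key in `launch`, withinKCounts[b] holds
  // the number of strictly-within-K elements in blocks 0..b of the same slice,
  // and kthCounts[b] the number of elements equal to the k-th value in blocks
  // 0..b. So startWithinK is the count contributed by earlier blocks of this
  // slice, and startKth starts after all within-K elements of the slice plus
  // the equal-to-kth elements of earlier blocks.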

  // Read input, select topk out and write
  typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  for (int i = 0; i < items_per_thread; ++i) {
    // Find the start offset for this slice
    IndexType idx = blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx;
    T val;
    int withinK = 0;
    int kth = 0;
    if (idx < inputSliceSize) {
      val = doLdg(inputSliceStart + idx * inputWithinSliceStride);
      const auto valConverted = at::native::TopKTypeConfig<T>::convert(val);
      withinK = (largest ? valConverted > kthValueConverted : valConverted < kthValueConverted);
      kth = (valConverted == kthValueConverted);
    }

    uint32_t withinKIndex;
    uint32_t numWithinK;
    BlockScan(temp_storage).ExclusiveSum(withinK, withinKIndex, numWithinK);
    __syncthreads();
    if (withinK) {
      uint32_t offset = withinKIndex + startWithinK;
      topKSliceStart[offset * topKWithinSliceStride] = val;
      indicesSliceStart[offset * indicesWithinSliceStride] = idx;
    }
    startWithinK += numWithinK;

    if (startKth < outputSliceSize) {
      uint32_t kthIndex;
      uint32_t numKth;
      BlockScan(temp_storage).ExclusiveSum(kth, kthIndex, numKth);
      __syncthreads();
      if (kth) {
        uint32_t offset = kthIndex + startKth;
        if (offset < outputSliceSize) {
          topKSliceStart[offset * topKWithinSliceStride] = val;
          indicesSliceStart[offset * indicesWithinSliceStride] = idx;
        }
      }
      startKth += numKth;
    }
  }
}
#endif

int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
  // occupancy of this kernel is limited by registers per thread
  constexpr int REGS_PER_THREAD = 40; // from nsight launch statistics
  constexpr int REGS_PER_BLOCK = REGS_PER_THREAD * BLOCK_THREADS;
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  int mpc = prop->multiProcessorCount;
#if defined(USE_ROCM)
  int regs_per_mp = prop->regsPerBlock;
  int max_blocks_per_mp = 32;
#else
  int regs_per_mp = prop->regsPerMultiprocessor;
#if !defined(USE_ROCM)
  int max_blocks_per_mp = prop->maxBlocksPerMultiProcessor;
#else
  int max_blocks_per_mp = 32;
#endif
#endif
  int blocks_per_mp = std::min(regs_per_mp / REGS_PER_BLOCK, max_blocks_per_mp);
  int64_t items_per_thread = at::ceil_div((int64_t)(slice_size * num_slices), (int64_t)(mpc * blocks_per_mp * BLOCK_THREADS));
  items_per_thread = std::max(MIN_ITEMS_PER_THREAD, std::min((int)items_per_thread, MAX_ITEMS_PER_THREAD)); // clamp to [4, 64]
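  // For example (illustrative numbers only): on a device with 80 SMs, 65536
  // registers per SM and REGS_PER_BLOCK = 40 * 256 = 10240, blocks_per_mp = 6,
  // so roughly 80 * 6 * 256 = 122880 threads are resident; a 10M-element
  // problem then gives items_per_thread = ceil(1e7 / 122880) = 82, clamped to 64.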
  return items_per_thread;
}

class BlockIdxToKey {
  uint32_t blocks_per_slice;
public:
  BlockIdxToKey(uint32_t blocks_per_slice): blocks_per_slice(blocks_per_slice) {}
  __device__ __forceinline__ uint32_t operator()(uint32_t blk) const {
    return blk / blocks_per_slice;
  }
};

template <typename T, typename IndexType, int Dim>
void launch(
    at::cuda::detail::TensorInfo<const T, IndexType> input,
    IndexType inputSliceSize,
    IndexType outputSliceSize, // aka `k`
    bool largest,

    uint32_t numInputSlices,
    IndexType inputWithinSliceStride,

    at::cuda::detail::TensorInfo<T, IndexType> topK,
    IndexType topKWithinSliceStride,

    at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
    IndexType indicesWithinSliceStride) {
  auto stream = c10::cuda::getCurrentCUDAStream();

  // configure items_per_thread based on device architecture and input size
  int items_per_thread = get_items_per_thread(numInputSlices, inputSliceSize);
  int items_per_block = items_per_thread * BLOCK_THREADS;

  using Bitwise = typename TopKTypeConfig<T>::RadixType;
  uint32_t blocks_per_slice = at::ceil_div((int64_t)inputSliceSize, (int64_t)items_per_block);
  uint32_t num_blocks = numInputSlices * blocks_per_slice;

  // temporary storage
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  auto kthValues_buffer = allocator.allocate(numInputSlices * sizeof(T));
  T* kthValues = reinterpret_cast<T*>(kthValues_buffer.get());

  TORCH_CHECK(blocks_per_slice <= std::numeric_limits<uint32_t>::max(), "blocks_per_slice larger than uint32 maximum is not supported");
  auto semaphores_buffer = allocator.allocate(numInputSlices * sizeof(uint32_t));
  uint32_t* semaphores = reinterpret_cast<uint32_t*>(semaphores_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(semaphores, 0, numInputSlices * sizeof(uint32_t), stream));

  auto ks_to_find_buffer = allocator.allocate(numInputSlices * sizeof(uint32_t));
  uint32_t* ks_to_find = reinterpret_cast<uint32_t*>(ks_to_find_buffer.get());
  uint32_t k_to_find = largest ? inputSliceSize - outputSliceSize + 1: outputSliceSize;
  fill<uint32_t><<<std::min(((int64_t)numInputSlices + 511) / 512, (int64_t)1073741824), 512, 0, stream>>>(
    ks_to_find, k_to_find, numInputSlices);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  auto desired_buffer = allocator.allocate(numInputSlices * sizeof(Bitwise));
  Bitwise* desired = reinterpret_cast<Bitwise*>(desired_buffer.get());

  auto counts_buffer = allocator.allocate(num_blocks * RADIX_DIGITS * sizeof(short));
  short* counts = reinterpret_cast<short*>(counts_buffer.get());
  static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < std::numeric_limits<short>::max(),
    "blockwise counter too large");

#if CUB_SUPPORTS_SCAN_BY_KEY()
  auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));

  auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get());
#endif

  Bitwise desiredMask = 0;
  dim3 grid;
  TORCH_INTERNAL_ASSERT(getGridFromTiles(num_blocks, grid), "Too many slices for topk");
  dim3 block(BLOCK_THREADS);

  // iterate radix bits for multiple passes
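  // For a 32-bit key type with RADIX_BITS = 8, this runs passes at
  // current_bit = 24, 16, 8, 0; after each pass, desiredMask is extended so
  // that the next pass only counts elements whose already-decided high bits
  // match `desired`.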
  for (int current_bit = sizeof(T) * 8 - RADIX_BITS; current_bit >= 0; current_bit -= RADIX_BITS) {
    radixFindKthValues<T, IndexType, Bitwise, Dim><<<grid, block, 0, stream>>>(
        input,
        inputSliceSize,
        ks_to_find,
        numInputSlices,
        inputWithinSliceStride,
        current_bit,
        items_per_thread,
        blocks_per_slice,
        desiredMask,
        semaphores,
        desired,
        counts,
        kthValues);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if CUB_SUPPORTS_SCAN_BY_KEY()
    computeBlockwiseWithinKCounts<Bitwise><<<grid, RADIX_DIGITS, 0, stream>>>(
      desired, counts, blocks_per_slice, current_bit, largest, withinKCounts, num_blocks);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif
    desiredMask = at::cuda::Bitfield<Bitwise>::setBitfield(desiredMask, RADIX_MASK, current_bit, RADIX_BITS);
  }

#if CUB_SUPPORTS_SCAN_BY_KEY()
  computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>(
    desired, counts, num_blocks, blocks_per_slice, kthCounts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  // Do a prefix scan of withinKCounts and kthCounts using slice_idx as keys to get the starting index of each block
  using counting_iter_t = cub::CountingInputIterator<uint32_t, uint32_t>;
  using slice_idx_iter_t = cub::TransformInputIterator<uint32_t, BlockIdxToKey, counting_iter_t>;
  slice_idx_iter_t slice_idx_iter(counting_iter_t(0), BlockIdxToKey(blocks_per_slice));
  at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, withinKCounts, withinKCounts, num_blocks);
  at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, kthCounts, kthCounts, num_blocks);
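  // For example (illustrative only): with two slices and blocks_per_slice = 2,
  // the keys produced by slice_idx_iter are [0, 0, 1, 1]; withinKCounts =
  // [3, 4, 5, 6] becomes [3, 7, 5, 11] after the inclusive sum by key, i.e.
  // a running total that restarts at each slice boundary.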
  // copy topk values to output tensor
  gatherTopK<T, IndexType, Dim><<<grid, block, 0, stream>>>(
    input, inputSliceSize, outputSliceSize, largest, numInputSlices, inputWithinSliceStride,
    topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread,
    blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
  // Find topk values based on kth values
  {
    dim3 grid;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
    int warp_size = at::cuda::warp_size();
    dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
    sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>(
        input,
        inputSliceSize,
        outputSliceSize,
        largest,
        numInputSlices,
        inputWithinSliceStride,
        topK,
        topKWithinSliceStride,
        indices,
        indicesWithinSliceStride,
        kthValues);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
#endif
}

} // namespace mbtopk

bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
  if (num_slices > std::numeric_limits<uint32_t>::max() ||
      slice_size > std::numeric_limits<uint32_t>::max()) return false;
#if CUB_SUPPORTS_SCAN_BY_KEY()
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
  return (num_slices <= 20 && slice_size >= 20000) ||
      (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
      (num_slices > 40 && num_slices <= 80 && slice_size >= 8000) ||
      (num_slices > 80 && num_slices < 200 && slice_size >= 5000) ||
      (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
      (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
      (num_slices > 4000 && slice_size >= 400);
#else
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
  return (num_slices <= 400 && slice_size >= 5000) ||
      (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
      (num_slices >= 4000 && slice_size >= 300);
#endif
}

void launch_gather_topk_kernel(
    const TensorBase& self, int64_t k, int64_t dim, bool largest,
    const TensorBase& values, const TensorBase& indices) {
  int numDims = self.dim();
  numDims = numDims == 0 ? 1 : numDims;
  TORCH_CHECK(numDims <= MAX_DIMS, "input tensor has too many dimensions");
  int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);

  auto input = self.contiguous();
  // static_cast is required to ensure that the correct type (INDEX_T)
  // is provided to the kernel for the arguments.
#define RUN_K(INDEX_T, DIM, LAUNCH_FUNCTION_NAME)                       \
  LAUNCH_FUNCTION_NAME<scalar_t, INDEX_T, DIM>(                         \
      inputInfo,                                                        \
      static_cast<INDEX_T>(sliceSize),                                  \
      static_cast<INDEX_T>(k),                                          \
      largest,                                                          \
      static_cast<INDEX_T>(numInputSlices),                             \
      /* The actual dimension that the k-selection is running in */     \
      /* may have changed from collapseDims() */                        \
      static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]),        \
      topKInfo,                                                         \
      static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]),          \
      indicesInfo,                                                      \
      static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim]));

#define RUN_MB(INDEX_T, DIM)                                            \
  if (should_use_multiblock(numInputSlices, sliceSize)) {               \
    RUN_K(INDEX_T, DIM, mbtopk::launch);                                \
  } else {                                                              \
    RUN_K(INDEX_T, DIM, sbtopk::launch);                                \
  }

#define RUN_DIM(INDEX_T)                        \
  if (allDims == 1) {                           \
    RUN_MB(INDEX_T, 1);                         \
  } else if (allDims == 2) {                    \
    RUN_MB(INDEX_T, 2);                         \
  } else if (allDims == 3) {                    \
    RUN_MB(INDEX_T, 3);                         \
  } else {                                      \
    RUN_MB(INDEX_T, -1);                        \
  }

#define RUN_T(INDEX_T)                                                    \
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "topk_out_cuda", [&] { \
    at::cuda::detail::TensorInfo<const scalar_t, INDEX_T> inputInfo =     \
      at::cuda::detail::getTensorInfo<const scalar_t, INDEX_T>(input);    \
    at::cuda::detail::TensorInfo<scalar_t, INDEX_T> topKInfo =            \
      at::cuda::detail::getTensorInfo<scalar_t, INDEX_T>(values);         \
    at::cuda::detail::TensorInfo<int64_t, INDEX_T> indicesInfo =          \
      at::cuda::detail::getTensorInfo<int64_t, INDEX_T>(indices);         \
    /* tensorInfoLegacyIfScalar*/                                         \
    if (!input.dim()) {                                                   \
      inputInfo.dims = 1;                                                 \
      inputInfo.sizes[0] = 1;                                             \
      inputInfo.strides[0] = 1;                                           \
      topKInfo.dims = 1;                                                  \
      topKInfo.sizes[0] = 1;                                              \
      topKInfo.strides[0] = 1;                                            \
      indicesInfo.dims = 1;                                               \
      indicesInfo.sizes[0] = 1;                                           \
      indicesInfo.strides[0] = 1;                                         \
    }                                                                     \
    /* We use these structures solely to find the offset to */            \
    /* each slice we are operating on */                                  \
    inputInfo.sizes[dim] = 1;                                             \
    topKInfo.sizes[dim] = 1;                                              \
    indicesInfo.sizes[dim] = 1;                                           \
    /* stash the stride of dim because it can be accidentally collapsed */ \
    auto strideTopK = topKInfo.strides[dim];                              \
    auto strideIndices = indicesInfo.strides[dim];                        \
    /* Collapse all other dims */                                         \
    int collapseInputDim = inputInfo.collapseDims(dim);                   \
    int collapseTopKDim = topKInfo.collapseDims(dim);                     \
    int collapseIndicesDim = indicesInfo.collapseDims(dim);               \
    /* restore stride in case it was collapsed */                         \
    topKInfo.strides[collapseTopKDim] = strideTopK;                       \
    indicesInfo.strides[collapseIndicesDim] = strideIndices;              \
    int64_t numInputSlices = 1;                                           \
    for (int i = 0; i < inputInfo.dims; ++i) {                            \
      numInputSlices *= inputInfo.sizes[i];                               \
    }                                                                     \
                                                                          \
    /* This is used as a template parameter to calculate indices. */      \
    /* We only specialize it if all collapsed dim sizes are the */        \
    /* same; otherwise, we use -1 which is the specialization */          \
    /* parameter for arbitrary dimensions */                              \
    int allDims = inputInfo.dims;                                         \
    if (topKInfo.dims != allDims || indicesInfo.dims != allDims) {        \
      allDims = -1;                                                       \
    }                                                                     \
                                                                          \
    RUN_DIM(INDEX_T);                                                     \
  });

  // the below is safe with 0-dimensional tensors because it is based on
  // TensorInfo which implicitly expands to 1-dimensional.
  if (input.numel() > 0) {
    // Based on required index size, run the algorithm with the
    // appropriate index type
    if (at::cuda::detail::canUse32BitIndexMath(input) &&
        at::cuda::detail::canUse32BitIndexMath(values) &&
        at::cuda::detail::canUse32BitIndexMath(indices)) {
      RUN_T(uint32_t);
    } else {
      RUN_T(uint64_t);
    }
  }
#undef RUN_T
#undef RUN_DIM
#undef RUN_K
}

} // namespace at::native