#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/TensorTopK.h>
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/ScanUtils.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
#include <ATen/cuda/cub.cuh>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/detail/KernelUtils.h>

#include <c10/macros/Macros.h>

using namespace at::native;

namespace at::native {

// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk() {
  return CUB_SUPPORTS_SCAN_BY_KEY();
}

namespace sbtopk { // single_block_topk

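// AddOp is the binary functor passed to at::cuda::exclusiveBinaryPrefixScan in
// gatherTopK below: it sums the per-thread `hasTopK` flags so that each thread
// obtains its write slot and the block obtains the total count (the "carry").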
template <typename T>
struct AddOp {
  __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) {
    return (lhs + rhs);
  }
};

template <typename T, typename IndexType, int Dim, bool WithKthValues>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> input,
                           IndexType inputSliceSize,
                           IndexType outputSliceSize, // aka `k`
                           bool largest,

                           IndexType numInputSlices,
                           IndexType inputWithinSliceStride,

                           at::cuda::detail::TensorInfo<T, IndexType> topK,
                           IndexType topKWithinSliceStride,

                           at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
                           IndexType indicesWithinSliceStride,
                           T* kthValues) {
  // Indices are limited to integer fp precision, so counts can fit in
  // int32, regardless of IndexType
#if defined(USE_ROCM)
  __shared__ int smem[64];
#else
  __shared__ int smem[32]; // one per each warp, up to warp limit
#endif
  IndexType slice = getLinearBlockId<IndexType>();
  if (slice >= numInputSlices) {
    return;
  }

  // Find the start offset for our slice
  IndexType sliceStartIndex =
    at::cuda::detail::IndexToOffset<const T, IndexType, Dim>::get(slice, input);
  IndexType topKSliceStartIndex =
    at::cuda::detail::IndexToOffset<T, IndexType, Dim>::get(slice, topK);
  IndexType indicesSliceStartIndex =
    at::cuda::detail::IndexToOffset<int64_t, IndexType, Dim>::get(slice, indices);

  const T* inputSliceStart = &input.data[sliceStartIndex];
  T* topKSliceStart = &topK.data[topKSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];

  // Find the k-th highest element in our input
  T topKValue;
  if (WithKthValues){
    topKValue = kthValues[slice];
  } else {
    topKValue = static_cast<T>(0);
    radixSelect<T, typename TopKTypeConfig<T>::RadixType, IndexType>(
      inputSliceStart, outputSliceSize, largest,
      inputSliceSize, inputWithinSliceStride,
      smem, &topKValue);
  }
  const auto topKConverted = at::native::TopKTypeConfig<T>::convert(topKValue);

  // Every value that is strictly less/greater than the k-th value
  // (depending on sort direction), compared in converted integer format,
  // is in the top-K. The top-K value itself might not be unique.
  //
  // Since there is a variable number of elements that we see that
  // are within the top-k, we don't know at what index to write out
  // the resulting values.
  // In order to get this, we perform an exclusive prefix sum of
  // `hasTopK`. This will return the resulting index into which we
  // need to write the result, if a thread has a result.

  // All threads need to participate in the loop and the prefix sum,
  // but not necessarily in the load; hence the loop bounds being rounded
  // up to a multiple of the block dim.
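  // Illustrative example: if four consecutive threads see hasTopK = [1, 0, 1, 1],
  // the exclusive prefix sum yields index = [0, 1, 1, 2] and carry = 3, so the
  // three qualifying threads (0, 2, 3) write to writeIndexStart + {0, 1, 2},
  // and writeIndexStart then advances by 3 for the next iteration.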
  IndexType numIterations = round_up(inputSliceSize, (IndexType) blockDim.x);
  IndexType writeIndexStart = 0;

  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : static_cast<T>(0);
    const auto convertedV = at::native::TopKTypeConfig<T>::convert(v);
    bool hasTopK;
    if (largest) {
      hasTopK = inRange && (convertedV > topKConverted);
    } else {
      hasTopK = inRange && (convertedV < topKConverted);
    }

    int index;
    int carry;
    at::cuda::exclusiveBinaryPrefixScan<int, true>(
      smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK) {
      int writeIndex = writeIndexStart + index;
      CUDA_KERNEL_ASSERT(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    writeIndexStart += carry;
  }

  // We need to fill in the rest with values that are == the top-K value.
  // The number that we need is outputSliceSize - writeIndexStart. There
  // might be more such values available than that, in which case we take
  // the first ones encountered. We again use a prefix sum to calculate the
  // indices for writing the results.
  CUDA_KERNEL_ASSERT(outputSliceSize >= writeIndexStart);
  IndexType topKRemaining = (outputSliceSize - writeIndexStart);
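  // Example: with k == 5 and 3 values strictly in the top-k written above,
  // topKRemaining == 2, so the first two occurrences equal to the k-th value
  // that are seen below fill the remaining output slots.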

  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : static_cast<T>(0);
    const auto convertedV = at::native::TopKTypeConfig<T>::convert(v);
    bool hasTopK = inRange && (convertedV == topKConverted);

    int index;
    int carry;
    at::cuda::exclusiveBinaryPrefixScan<int, true>(
      smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK && index < topKRemaining) {
      int writeIndex = writeIndexStart + index;
      CUDA_KERNEL_ASSERT(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    if (carry >= topKRemaining) {
      break;
    }

    topKRemaining -= carry;
    writeIndexStart += carry;
  }

};

template <typename T, typename IndexType, int Dim>
void launch(
    at::cuda::detail::TensorInfo<const T, IndexType> input,
    IndexType inputSliceSize,
    IndexType outputSliceSize, // aka `k`
    bool largest,

    IndexType numInputSlices,
    IndexType inputWithinSliceStride,

    at::cuda::detail::TensorInfo<T, IndexType> topK,
    IndexType topKWithinSliceStride,

    at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
    IndexType indicesWithinSliceStride) {

  dim3 grid;
  TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
  int warp_size = at::cuda::warp_size();
  dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
  gatherTopK<T, IndexType, Dim, /* WithKthValues= */false><<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
      input,
      inputSliceSize,
      outputSliceSize,
      largest,
      numInputSlices,
      inputWithinSliceStride,
      topK,
      topKWithinSliceStride,
      indices,
      indicesWithinSliceStride,
      nullptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace sbtopk

namespace mbtopk { // multi_block_topk

// Assumptions:
// The number of elements can be larger than UINT32_MAX, but
// the number of total blocks cannot be larger than UINT32_MAX.
// So we cannot have more than UINT32_MAX slices. The actual limit
// for the number of slices could be a few fold smaller than UINT32_MAX,
// because we could be using multiple blocks per slice.
// Furthermore, the size of each input slice is also assumed to be
// smaller than UINT32_MAX.

constexpr int BLOCK_THREADS = 256;

// Over what radix we are selecting values
constexpr int RADIX_BITS = 8;
constexpr int RADIX_DIGITS = 1 << RADIX_BITS; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_DIGITS - 1);
static_assert(RADIX_DIGITS <= BLOCK_THREADS, "radixFindKthValues kernel requires RADIX_DIGITS <= BLOCK_THREADS");
constexpr int MIN_ITEMS_PER_THREAD = 4;
constexpr int MAX_ITEMS_PER_THREAD = 64;

template <typename T, typename IndexType>
__global__ void fill(T* x, T value, IndexType size) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (IndexType i = idx; i < size; i += gridDim.x * blockDim.x) {
    x[i] = value;
  }
}

// find the kth smallest value,
// for largest topk, k_to_find = slice_size - k + 1
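// e.g. with slice_size == 10 and k == 3 (largest), k_to_find == 8: the 8th
// smallest element of the slice is exactly the 3rd largest.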
template <typename T, typename IndexType, typename Bitwise, int Dim>
C10_LAUNCH_BOUNDS_1(BLOCK_THREADS)
__global__ void radixFindKthValues(
    at::cuda::detail::TensorInfo<const T, IndexType> input,
    uint32_t slice_size,
    uint32_t* ks_to_find, // size: num_slices

    uint32_t num_slices,
    IndexType withinSliceStride,

    int current_bit,
    int items_per_thread,
    uint32_t blocks_per_slice,
    Bitwise desiredMask,

    // outputs
    uint32_t* semaphores, // size: num_slices
    Bitwise* desires, // size: num_slices
    short* counts, // size: num_slices * blocks_per_slice * radix_digits
    T* kthValues // size: num_slices, only write when current_bit reaches 0
  ) {
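  // Each pass handles RADIX_BITS bits of the key, from the most significant
  // bits down to bit 0. Blocks histogram the current digit of the elements
  // whose already-decided high bits match `desired`; the last block to finish
  // a slice reduces the per-block histograms, prefix-sums them, and picks the
  // digit bucket containing the k-th value, narrowing `desired` (and
  // `ks_to_find`) for the next pass.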

  int items_per_block = items_per_thread * BLOCK_THREADS;
  int tidx = threadIdx.x;
  uint32_t block_idx = getLinearBlockId<uint32_t>();
  uint32_t slice_idx = block_idx / blocks_per_slice;
  uint32_t blk_idx_in_slice = block_idx % blocks_per_slice;
  if (slice_idx >= num_slices) {
    return;
  }

  Bitwise desired = desires[slice_idx];
  uint32_t k_to_find = ks_to_find[slice_idx];
  IndexType slice_start_index = at::cuda::detail::IndexToOffset<const T, IndexType, Dim>::get(slice_idx, input);
  const T* data = &input.data[slice_start_index];

  typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;
  static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < std::numeric_limits<short>::max(),
    "blockwise counter too large");
  union __align__(16) TempStorage {
    uint32_t digit_counters[RADIX_DIGITS];
    uint32_t digit_count_cumsum[RADIX_DIGITS]; // only used if this is the last block for this slice
    typename BlockScan::TempStorage scan_storage;
  };
  __shared__ TempStorage temp_storage;

  // fill digit_counters with zeros
  if (tidx < RADIX_DIGITS) {
    temp_storage.digit_counters[tidx] = 0;
  }
  __syncthreads();

  items_per_thread = (blk_idx_in_slice + 1 < blocks_per_slice)
      ? items_per_thread
      : at::ceil_div((int64_t)(slice_size - blk_idx_in_slice * items_per_block), (int64_t)BLOCK_THREADS);

  // collect digit counts and store in shared memory
  for (int i = 0; i < items_per_thread; ++i) {
    // Find the start offset for this slice
    IndexType idx = blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx;
    if (idx < slice_size) {
      idx *= withinSliceStride;
      Bitwise val = TopKTypeConfig<T>::convert(doLdg(&data[idx]));
      bool has_val = ((val & desiredMask) == (desired & desiredMask));
      Bitwise digit = at::cuda::Bitfield<Bitwise>::getBitfield(val, current_bit, RADIX_BITS);
      if (has_val) {
        atomicAdd(&temp_storage.digit_counters[digit], 1);
      }
    }
  }

  __syncthreads();

  // load the digit counts into registers, one digit per thread
  static_assert(RADIX_DIGITS <= BLOCK_THREADS, "this kernel requires RADIX_DIGITS <= BLOCK_THREADS");
  uint32_t digit_count = 0;
  if (tidx < RADIX_DIGITS) {
    digit_count = temp_storage.digit_counters[tidx];
  }

  // We always write out counts regardless of whether blocks_per_slice == 1,
  // because they will be used to compute offsets for `gatherTopK`.
  if (tidx < RADIX_DIGITS) {
    counts[block_idx * RADIX_DIGITS + tidx] = digit_count;
  }
  // if blocks_per_slice == 1, there is no need to do cross-block reduction;
  // in this case we use the counts kept in registers directly
  if (blocks_per_slice > 1) {
    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // make sure all writes are finished before updating semaphores
  }

  // the last block of each slice accumulates counters from multiple blocks and updates desired and ks_to_find
  __shared__ bool s_is_last_block_done;

  if (tidx == 0) {
    if (blocks_per_slice == 1) {
      s_is_last_block_done = true;
    } else {
      uint32_t blocks_finished_old = atomicAdd(&semaphores[slice_idx], 1);
      s_is_last_block_done = (blocks_finished_old == blocks_per_slice - 1);
    }
  }

  __syncthreads();

  if (!s_is_last_block_done)
    return;

  // accumulate counters from multiple blocks
  if (tidx < RADIX_DIGITS && blocks_per_slice > 1) {
    digit_count = 0;
    for (int blk = 0; blk < blocks_per_slice; ++blk) {
      digit_count += counts[(slice_idx * blocks_per_slice + blk) * RADIX_DIGITS + tidx];
    }
  }

  // compute the block-wide inclusive prefix sum
  uint32_t digit_count_cumsum;
  BlockScan(temp_storage.scan_storage).InclusiveSum(digit_count, digit_count_cumsum);
  __syncthreads();
  // every thread also needs the prefix sum of its left neighbor's value for comparison, so save a copy in shared mem
  if (tidx < RADIX_DIGITS) {
    temp_storage.digit_count_cumsum[tidx] = digit_count_cumsum;
  }
  __syncthreads();

  if (tidx < RADIX_DIGITS) {
    uint32_t digit_count_cumsum_left = (tidx == 0) ? 0 : temp_storage.digit_count_cumsum[tidx - 1];

    // if not the last pass: update desired and ks_to_find
    // if last pass: write out the kth value
    if (digit_count_cumsum_left < k_to_find && k_to_find <= digit_count_cumsum) {
      desired = at::cuda::Bitfield<Bitwise>::setBitfield(desired, tidx, current_bit, RADIX_BITS);
      desires[slice_idx] = desired;
      if (current_bit > 0) {
        ks_to_find[slice_idx] = k_to_find - digit_count_cumsum_left;
      } else {
        kthValues[slice_idx] = TopKTypeConfig<T>::deconvert(desired);
      }
    }
  }

  // reset semaphores for the next pass
  if (tidx == 0) {
    semaphores[slice_idx] = 0;
  }
}

#if CUB_SUPPORTS_SCAN_BY_KEY()
// Assumption: k cannot be larger than UINT32_MAX
template <typename Bitwise>
C10_LAUNCH_BOUNDS_1(RADIX_DIGITS) // one thread per digit
__global__ void computeBlockwiseWithinKCounts(
    Bitwise* desires, // size: num_slices
    short* counts, // size: num_slices * blocks_per_slice * radix_digits
    uint32_t blocks_per_slice,
    int current_bit,
    bool largest,
    // outputs:
    uint32_t* withinKCounts, // size: num_slices * blocks_per_slice == num_blocks
    uint32_t num_blocks
  ) {
  // This kernel should be launched with the same number of blocks as the `radixFindKthValues` kernel.
  int tidx = threadIdx.x;
  uint32_t block_idx = getLinearBlockId<uint32_t>();
  uint32_t slice_idx = block_idx / blocks_per_slice;

  // The grid is computed by `getGridFromTiles`. When there are lots of
  // elements, it uses both blockIdx.x and blockIdx.y, and maybe blockIdx.z;
  // when this is the case, the number of blocks that we are launching can be
  // more than the number of blocks we need, so we need to check the range of
  // `block_idx`.
  if (block_idx >= num_blocks) {
    return;
  }

  Bitwise desired = doLdg(desires + slice_idx);
  Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS);

  // if largest, then only threads that have tidx > desired_digit are active
  // if !largest, then only threads that have tidx < desired_digit are active
  // each active thread reads the count for its corresponding digit, then
  // does a warp reduction followed by a shared-memory reduction to get the total count;
  // inactive threads should not load, and inactive warps should not reduce.
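  // With RADIX_DIGITS == 256 this is 8 warps on CUDA (warp size 32) or 4 on
  // ROCm (wavefront size 64), so the per-warp partial sums always fit within
  // a single warp for the final reduction (see the static_assert below).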
  bool warp_is_active, thread_is_active;
  int warp = tidx / C10_WARP_SIZE;
  if (largest) {
    int end_of_warp = warp * C10_WARP_SIZE + C10_WARP_SIZE - 1;
    warp_is_active = end_of_warp > desired_digit;
    thread_is_active = tidx > desired_digit;
  } else {
    int start_of_warp = warp * C10_WARP_SIZE;
    warp_is_active = start_of_warp < desired_digit;
    thread_is_active = tidx < desired_digit;
  }
  uint32_t count = 0;
  if (warp_is_active) {
    if (thread_is_active) {
      count = doLdg(counts + block_idx * RADIX_DIGITS + tidx);
    }
    for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
      count += WARP_SHFL_DOWN(count, offset);
    }
  }

  constexpr int num_warps = RADIX_DIGITS / C10_WARP_SIZE;
  __shared__ uint32_t warp_counts[num_warps];
  if (tidx % C10_WARP_SIZE == 0) {
    warp_counts[warp] = count;
  }
  __syncthreads();
  static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
    "Assuming only 1 warp is needed for final reduction");
  if (warp != 0) {
    return;
  }
  count = 0;
  if (tidx < num_warps) {
    count = warp_counts[tidx];
  }
  for (int offset = num_warps / 2; offset > 0; offset /= 2) {
    count += WARP_SHFL_DOWN(count, offset);
  }
  if (tidx == 0) {
    withinKCounts[block_idx] += count;
  }
}

// Assumption: slice_size cannot be larger than UINT32_MAX
template <typename Bitwise>
__global__ void computeBlockwiseKthCounts(
    Bitwise* desires, // size: num_slices
    short* counts, // size: num_slices * blocks_per_slice * radix_digits
    uint32_t num_blocks, // the number of blocks used by `radixFindKthValues` kernel
    uint32_t blocks_per_slice,
    // outputs:
    uint32_t* kthCounts // size: num_slices * blocks_per_slice == num_blocks
  ) {
  CUDA_KERNEL_LOOP_TYPE(idx, num_blocks, uint32_t) {
    uint32_t slice_idx = idx / blocks_per_slice;
    Bitwise desired = doLdg(desires + slice_idx);
    Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, 0, RADIX_BITS);
    kthCounts[idx] = doLdg(counts + idx * RADIX_DIGITS + desired_digit);
  }
}

template <typename T, typename IndexType, int Dim>
C10_LAUNCH_BOUNDS_1(BLOCK_THREADS)
__global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> input,
                           IndexType inputSliceSize,
                           IndexType outputSliceSize, // aka `k`
                           bool largest,

                           uint32_t numInputSlices,
                           IndexType inputWithinSliceStride,

                           at::cuda::detail::TensorInfo<T, IndexType> topK,
                           IndexType topKWithinSliceStride,

                           at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
                           IndexType indicesWithinSliceStride,

                           uint32_t items_per_thread,
                           uint32_t blocks_per_slice,

                           T *kthValues,
                           uint32_t* withinKCounts,
                           uint32_t* kthCounts,
                           uint32_t num_blocks) {

  uint32_t items_per_block = items_per_thread * BLOCK_THREADS;
  uint32_t tidx = threadIdx.x;
  uint32_t block_idx = getLinearBlockId<uint32_t>();

  // The grid is computed by `getGridFromTiles`. When there are lots of
  // elements, it uses both blockIdx.x and blockIdx.y, and maybe blockIdx.z;
  // when this is the case, the number of blocks that we are launching can be
  // more than the number of blocks we need, so we need to check the range of
  // `block_idx`.
  if (block_idx >= num_blocks) {
    return;
  }

  uint32_t slice_idx = block_idx / blocks_per_slice;
  uint32_t blk_idx_in_slice = block_idx % blocks_per_slice;

  items_per_thread = (blk_idx_in_slice + 1 < blocks_per_slice)
      ? items_per_thread
      : at::ceil_div((int64_t)(inputSliceSize - blk_idx_in_slice * items_per_block), (int64_t)BLOCK_THREADS);

  // Find the start offset for our slice
  IndexType sliceStartIndex =
    at::cuda::detail::IndexToOffset<const T, IndexType, Dim>::get(slice_idx, input);
  IndexType topKSliceStartIndex =
    at::cuda::detail::IndexToOffset<T, IndexType, Dim>::get(slice_idx, topK);
  IndexType indicesSliceStartIndex =
    at::cuda::detail::IndexToOffset<int64_t, IndexType, Dim>::get(slice_idx, indices);

  const T* inputSliceStart = &input.data[sliceStartIndex];
  T* topKSliceStart = &topK.data[topKSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];

  // Find the k-th highest element in our input
  T kthValue = kthValues[slice_idx];
  const auto kthValueConverted = at::native::TopKTypeConfig<T>::convert(kthValue);

  // Find the start index in output tensor of this block
  uint32_t startWithinK = 0;
  if (blk_idx_in_slice > 0) {
    startWithinK = withinKCounts[block_idx - 1];
  }
  uint32_t startKth = withinKCounts[slice_idx * blocks_per_slice + blocks_per_slice - 1];
  if (blk_idx_in_slice > 0) {
    startKth += kthCounts[block_idx - 1];
  }
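  // Note: by this point `withinKCounts` and `kthCounts` hold *inclusive*
  // per-slice prefix sums (computed via inclusive_sum_by_key in launch()), so
  // the previous block's entry is exactly this block's starting write offset,
  // and the last block's `withinKCounts` entry is the total number of elements
  // strictly within the top-k for the slice.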

  // Read input, select topk out and write
  typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  for (int i = 0; i < items_per_thread; ++i) {
    // Find the start offset for this slice
    IndexType idx = blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx;
    T val;
    int withinK = 0;
    int kth = 0;
    if (idx < inputSliceSize) {
      val = doLdg(inputSliceStart + idx * inputWithinSliceStride);
      const auto valConverted = at::native::TopKTypeConfig<T>::convert(val);
      withinK = (largest ? valConverted > kthValueConverted : valConverted < kthValueConverted);
      kth = (valConverted == kthValueConverted);
    }

    uint32_t withinKIndex;
    uint32_t numWithinK;
    BlockScan(temp_storage).ExclusiveSum(withinK, withinKIndex, numWithinK);
    __syncthreads();
    if (withinK) {
      uint32_t offset = withinKIndex + startWithinK;
      topKSliceStart[offset * topKWithinSliceStride] = val;
      indicesSliceStart[offset * indicesWithinSliceStride] = idx;
    }
    startWithinK += numWithinK;

    if (startKth < outputSliceSize) {
      uint32_t kthIndex;
      uint32_t numKth;
      BlockScan(temp_storage).ExclusiveSum(kth, kthIndex, numKth);
      __syncthreads();
      if (kth) {
        uint32_t offset = kthIndex + startKth;
        if (offset < outputSliceSize) {
          topKSliceStart[offset * topKWithinSliceStride] = val;
          indicesSliceStart[offset * indicesWithinSliceStride] = idx;
        }
      }
      startKth += numKth;
    }
  }
}
#endif

int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
  // occupancy of this kernel is limited by registers per thread
  constexpr int REGS_PER_THREAD = 40; // from nsight launch statistics
  constexpr int REGS_PER_BLOCK = REGS_PER_THREAD * BLOCK_THREADS;
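  // Rough reasoning: the register budget bounds how many blocks fit on one SM,
  // and items_per_thread is then chosen so that a single wave of resident
  // blocks (mpc * blocks_per_mp * BLOCK_THREADS threads) covers the whole
  // input, clamped to [MIN_ITEMS_PER_THREAD, MAX_ITEMS_PER_THREAD].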
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  int mpc = prop->multiProcessorCount;
#if defined(USE_ROCM)
  int regs_per_mp = prop->regsPerBlock;
  int max_blocks_per_mp = 32;
#else
  int regs_per_mp = prop->regsPerMultiprocessor;
#if !defined(USE_ROCM)
  int max_blocks_per_mp = prop->maxBlocksPerMultiProcessor;
#else
  int max_blocks_per_mp = 32;
#endif
#endif
  int blocks_per_mp = std::min(regs_per_mp / REGS_PER_BLOCK, max_blocks_per_mp);
  int64_t items_per_thread = at::ceil_div((int64_t)(slice_size * num_slices), (int64_t)(mpc * blocks_per_mp * BLOCK_THREADS));
  items_per_thread = std::max(MIN_ITEMS_PER_THREAD, std::min((int)items_per_thread, MAX_ITEMS_PER_THREAD)); // clamp to (4, 64)
  return items_per_thread;
}

class BlockIdxToKey {
  uint32_t blocks_per_slice;
public:
  BlockIdxToKey(uint32_t blocks_per_slice): blocks_per_slice(blocks_per_slice) {}
  __device__ __forceinline__ uint32_t operator()(uint32_t blk) const {
    return blk / blocks_per_slice;
  }
};
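// In launch() below, BlockIdxToKey is used with cub::TransformInputIterator so
// that the *_by_key scans restart at slice boundaries; e.g., with
// blocks_per_slice == 3, block indices 0..5 map to keys 0, 0, 0, 1, 1, 1.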

template <typename T, typename IndexType, int Dim>
void launch(
    at::cuda::detail::TensorInfo<const T, IndexType> input,
    IndexType inputSliceSize,
    IndexType outputSliceSize, // aka `k`
    bool largest,

    uint32_t numInputSlices,
    IndexType inputWithinSliceStride,

    at::cuda::detail::TensorInfo<T, IndexType> topK,
    IndexType topKWithinSliceStride,

    at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
    IndexType indicesWithinSliceStride) {
  auto stream = c10::cuda::getCurrentCUDAStream();

  // configure items_per_thread based on device architecture and input size
  int items_per_thread = get_items_per_thread(numInputSlices, inputSliceSize);
  int items_per_block = items_per_thread * BLOCK_THREADS;

  using Bitwise = typename TopKTypeConfig<T>::RadixType;
  uint32_t blocks_per_slice = at::ceil_div((int64_t)inputSliceSize, (int64_t)items_per_block);
  uint32_t num_blocks = numInputSlices * blocks_per_slice;

  // temporary storage
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  auto kthValues_buffer = allocator.allocate(numInputSlices * sizeof(T));
  T* kthValues = reinterpret_cast<T*>(kthValues_buffer.get());

  TORCH_CHECK(blocks_per_slice <= std::numeric_limits<uint32_t>::max(), "blocks_per_slice larger than uint32 maximum is not supported");
  auto semaphores_buffer = allocator.allocate(numInputSlices * sizeof(uint32_t));
  uint32_t* semaphores = reinterpret_cast<uint32_t*>(semaphores_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(semaphores, 0, numInputSlices * sizeof(uint32_t), stream));

  auto ks_to_find_buffer = allocator.allocate(numInputSlices * sizeof(uint32_t));
  uint32_t* ks_to_find = reinterpret_cast<uint32_t*>(ks_to_find_buffer.get());
  uint32_t k_to_find = largest ? inputSliceSize - outputSliceSize + 1 : outputSliceSize;
  fill<uint32_t><<<std::min(((int64_t)numInputSlices + 511) / 512, (int64_t)1073741824), 512, 0, stream>>>(
    ks_to_find, k_to_find, numInputSlices);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  auto desired_buffer = allocator.allocate(numInputSlices * sizeof(Bitwise));
  Bitwise* desired = reinterpret_cast<Bitwise*>(desired_buffer.get());

  auto counts_buffer = allocator.allocate(num_blocks * RADIX_DIGITS * sizeof(short));
  short* counts = reinterpret_cast<short*>(counts_buffer.get());
  static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < std::numeric_limits<short>::max(),
    "blockwise counter too large");

#if CUB_SUPPORTS_SCAN_BY_KEY()
  auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
  AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));

  auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
  uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get());
#endif

  Bitwise desiredMask = 0;
  dim3 grid;
  TORCH_INTERNAL_ASSERT(getGridFromTiles(num_blocks, grid), "Too many slices for topk");
  dim3 block(BLOCK_THREADS);

  // iterate radix bits for multiple passes
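  // For a 4-byte element type with RADIX_BITS == 8 this runs four passes over
  // bits [24..31], [16..23], [8..15], [0..7]; each pass fixes one more byte of
  // `desired` until the exact k-th value is written out on the last pass.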
  for (int current_bit = sizeof(T) * 8 - RADIX_BITS; current_bit >= 0; current_bit -= RADIX_BITS) {
    radixFindKthValues<T, IndexType, Bitwise, Dim><<<grid, block, 0, stream>>>(
        input,
        inputSliceSize,
        ks_to_find,
        numInputSlices,
        inputWithinSliceStride,
        current_bit,
        items_per_thread,
        blocks_per_slice,
        desiredMask,
        semaphores,
        desired,
        counts,
        kthValues);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if CUB_SUPPORTS_SCAN_BY_KEY()
    computeBlockwiseWithinKCounts<Bitwise><<<grid, RADIX_DIGITS, 0, stream>>>(
        desired, counts, blocks_per_slice, current_bit, largest, withinKCounts, num_blocks);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif
    desiredMask = at::cuda::Bitfield<Bitwise>::setBitfield(desiredMask, RADIX_MASK, current_bit, RADIX_BITS);
  }

#if CUB_SUPPORTS_SCAN_BY_KEY()
  computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>(
      desired, counts, num_blocks, blocks_per_slice, kthCounts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  // Do a prefix scan of withinKCounts and kthCounts using slice_idx as keys to get the starting index of each block
  using counting_iter_t = cub::CountingInputIterator<uint32_t, uint32_t>;
  using slice_idx_iter_t = cub::TransformInputIterator<uint32_t, BlockIdxToKey, counting_iter_t>;
  slice_idx_iter_t slice_idx_iter(counting_iter_t(0), BlockIdxToKey(blocks_per_slice));
  at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, withinKCounts, withinKCounts, num_blocks);
  at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, kthCounts, kthCounts, num_blocks);
  // copy topk values to output tensor
  gatherTopK<T, IndexType, Dim><<<grid, block, 0, stream>>>(
      input, inputSliceSize, outputSliceSize, largest, numInputSlices, inputWithinSliceStride,
      topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread,
      blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
  // Find topk values based on kth values
  {
    dim3 grid;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
    int warp_size = at::cuda::warp_size();
    dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
    sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>(
        input,
        inputSliceSize,
        outputSliceSize,
        largest,
        numInputSlices,
        inputWithinSliceStride,
        topK,
        topKWithinSliceStride,
        indices,
        indicesWithinSliceStride,
        kthValues);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
#endif
}

} // namespace mbtopk

bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
  if (num_slices > std::numeric_limits<uint32_t>::max() ||
      slice_size > std::numeric_limits<uint32_t>::max()) return false;
#if CUB_SUPPORTS_SCAN_BY_KEY()
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
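  // For example, 100 slices of 6000 elements each take the multi-block path
  // (num_slices < 200 and slice_size >= 5000), while 100 slices of 1000
  // elements fall back to the single-block kernel.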
  return (num_slices <= 20 && slice_size >= 20000) ||
         (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
         (num_slices > 40 && num_slices <= 80 && slice_size >= 8000) ||
         (num_slices > 80 && num_slices < 200 && slice_size >= 5000) ||
         (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
         (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
         (num_slices > 4000 && slice_size >= 400);
#else
  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
  return (num_slices <= 400 && slice_size >= 5000) ||
         (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
         (num_slices >= 4000 && slice_size >= 300);
#endif
}

void launch_gather_topk_kernel(
    const TensorBase& self, int64_t k, int64_t dim, bool largest,
    const TensorBase& values, const TensorBase& indices) {
  int numDims = self.dim();
  numDims = numDims == 0 ? 1 : numDims;
  TORCH_CHECK(numDims <= MAX_DIMS, "input tensor has too many dimensions");
  int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);

  auto input = self.contiguous();
  // static_cast is required to ensure that the correct type (INDEX_T)
  // is provided to the kernel for the arguments.
#define RUN_K(INDEX_T, DIM, LAUNCH_FUNCTION_NAME)                       \
  LAUNCH_FUNCTION_NAME<scalar_t, INDEX_T, DIM>(                         \
      inputInfo,                                                        \
      static_cast<INDEX_T>(sliceSize),                                  \
      static_cast<INDEX_T>(k),                                          \
      largest,                                                          \
      static_cast<INDEX_T>(numInputSlices),                             \
      /* The actual dimension that the k-selection is running in */     \
      /* may have changed from collapseDims() */                        \
      static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]),        \
      topKInfo,                                                         \
      static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]),          \
      indicesInfo,                                                      \
      static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim]));

#define RUN_MB(INDEX_T, DIM)                                            \
  if (should_use_multiblock(numInputSlices, sliceSize)) {               \
    RUN_K(INDEX_T, DIM, mbtopk::launch);                                \
  } else {                                                              \
    RUN_K(INDEX_T, DIM, sbtopk::launch);                                \
  }

#define RUN_DIM(INDEX_T)                                                \
  if (allDims == 1) {                                                   \
    RUN_MB(INDEX_T, 1);                                                 \
  } else if (allDims == 2) {                                            \
    RUN_MB(INDEX_T, 2);                                                 \
  } else if (allDims == 3) {                                            \
    RUN_MB(INDEX_T, 3);                                                 \
  } else {                                                              \
    RUN_MB(INDEX_T, -1);                                                \
  }

#define RUN_T(INDEX_T)                                                  \
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "topk_out_cuda", [&] { \
    at::cuda::detail::TensorInfo<const scalar_t, INDEX_T> inputInfo =   \
      at::cuda::detail::getTensorInfo<const scalar_t, INDEX_T>(input);  \
    at::cuda::detail::TensorInfo<scalar_t, INDEX_T> topKInfo =          \
      at::cuda::detail::getTensorInfo<scalar_t, INDEX_T>(values);       \
    at::cuda::detail::TensorInfo<int64_t, INDEX_T> indicesInfo =        \
      at::cuda::detail::getTensorInfo<int64_t, INDEX_T>(indices);       \
    /* tensorInfoLegacyIfScalar */                                      \
    if (!input.dim()) {                                                 \
      inputInfo.dims = 1;                                               \
      inputInfo.sizes[0] = 1;                                           \
      inputInfo.strides[0] = 1;                                         \
      topKInfo.dims = 1;                                                \
      topKInfo.sizes[0] = 1;                                            \
      topKInfo.strides[0] = 1;                                          \
      indicesInfo.dims = 1;                                             \
      indicesInfo.sizes[0] = 1;                                         \
      indicesInfo.strides[0] = 1;                                       \
    }                                                                   \
    /* We use these structures solely to find the offset to */          \
    /* each slice we are operating on */                                \
    inputInfo.sizes[dim] = 1;                                           \
    topKInfo.sizes[dim] = 1;                                            \
    indicesInfo.sizes[dim] = 1;                                         \
    /* stash the stride of dim because it can be accidentally collapsed */ \
    auto strideTopK = topKInfo.strides[dim];                            \
    auto strideIndices = indicesInfo.strides[dim];                      \
    /* Collapse all other dims */                                       \
    int collapseInputDim = inputInfo.collapseDims(dim);                 \
    int collapseTopKDim = topKInfo.collapseDims(dim);                   \
    int collapseIndicesDim = indicesInfo.collapseDims(dim);             \
    /* restore stride in case it was collapsed */                       \
    topKInfo.strides[collapseTopKDim] = strideTopK;                     \
    indicesInfo.strides[collapseIndicesDim] = strideIndices;            \
    int64_t numInputSlices = 1;                                         \
    for (int i = 0; i < inputInfo.dims; ++i) {                          \
      numInputSlices *= inputInfo.sizes[i];                             \
    }                                                                   \
                                                                        \
    /* This is used as a template parameter to calculate indices. */    \
    /* We only specialize it if all collapsed dim sizes are the */      \
    /* same; otherwise, we use -1 which is the specialization */        \
    /* parameter for arbitrary dimensions */                            \
    int allDims = inputInfo.dims;                                       \
    if (topKInfo.dims != allDims || indicesInfo.dims != allDims) {      \
      allDims = -1;                                                     \
    }                                                                   \
                                                                        \
    RUN_DIM(INDEX_T);                                                   \
  });

  // the below is safe with 0-dimensional tensors because it is based on
  // TensorInfo which implicitly expands to 1-dimensional.
  if (input.numel() > 0) {
    // Based on required index size, run the algorithm with the
    // appropriate index type
    if (at::cuda::detail::canUse32BitIndexMath(input) &&
        at::cuda::detail::canUse32BitIndexMath(values) &&
        at::cuda::detail::canUse32BitIndexMath(indices)) {
      RUN_T(uint32_t);
    } else {
      RUN_T(uint64_t);
    }
  }
#undef RUN_T
#undef RUN_DIM
#undef RUN_K
}

} // namespace at::native