// aten/src/ATen/native/cuda/Sort.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/Sort.h>
#include <ATen/core/TensorBase.h>
#include <ATen/core/Array.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>

#include <limits>
#include <c10/core/DeviceArray.h>

namespace at::native {

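// Ask the CUDA occupancy API for the smallest grid size that still fully
// occupies the device when launching `kernel` with at most `max_block_size`
// threads per block. The callers below use this to avoid shrinking the grid
// so far that occupancy suffers for small batch sizes.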
template <typename T>
static int minimum_grid_for_occupancy(T kernel, int max_block_size) {
  int minGridSize = 0;
  int blockSize;
  C10_CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(
      &minGridSize,
      &blockSize,
      kernel,
      /*dynamicSMemSize=*/0,
      max_block_size));
  return minGridSize;
}

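// Compile-time check: does T have a quiet NaN? Falls back to an explicit list
// of c10 types (complex, BFloat16, Half) when std::numeric_limits is not
// specialized for T.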
template <typename T>
constexpr bool has_nan() {
  if constexpr (std::numeric_limits<T>::is_specialized) {
    return std::numeric_limits<T>::has_quiet_NaN;
  } else if constexpr (
      c10::is_complex<T>::value ||
      std::is_same_v<T, c10::BFloat16> ||
      std::is_same_v<T, c10::Half>) {
    return true;
  } else {
    // Fall back to "no NaN" for any other type.
    return false;
  }
}

// For very small unstable sorts (n <= 32), use bitonicSortKVInPlace
// which can sort multiple arrays within the same block of threads,
// improving occupancy.
struct SmallBitonicSort {
  template <int A, typename K, typename V, typename IndexType>
  void sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {
    constexpr int sort_size = 32;
    constexpr int max_block_y = 16;
    constexpr int items_per_thread = 2;
    static_assert(sort_size % items_per_thread == 0, "");
    constexpr int block_x = sort_size / items_per_thread;

    TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);

    // Scale batch size down if the grid would be too small
    const auto min_grid = minimum_grid_for_occupancy(
        bitonicSortKVInPlace<
            A, -1, block_x, max_block_y,
            K, V, LTOp<K, true>, IndexType>,
        block_x * max_block_y);
    const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
    const int block_y = std::min(IndexType(max_block_y), max_batch);
    dim3 block(block_x, block_y);

    dim3 grid;
    const int grid_count = (keySlices + block_y - 1) / block_y;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
                          "Too many slices to sort");
    const auto stream = at::cuda::getCurrentCUDAStream();

    if (descending) {
      bitonicSortKVInPlace<A, -1, block_x, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          GTOp<K, true>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      bitonicSortKVInPlace<A, -1, block_x, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          LTOp<K, true>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
};

#if HAS_WARP_MERGE_SORT()

// For small sorts (n <= 128) we use warpMergeSortKVInPlace which
// sorts one slice per warp and potentially multiple slices in the
// same block for improved occupancy with large batch sizes.
template <int sort_size>
struct WarpMergeSort {

  template <int A, typename K, typename V, typename IndexType>
  void sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {
    constexpr int max_block_y = 16;
    const int block_x = at::cuda::warp_size();

    TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);

    // Scale batch size down if the grid would be too small
    const auto min_grid = minimum_grid_for_occupancy(
        warpMergeSortKVInPlace<
            A, -1, sort_size, max_block_y,
            K, V, LTOp<K, true>, IndexType>,
        block_x * max_block_y);
    const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
    const int block_y = std::min(IndexType(max_block_y), max_batch);
    dim3 block(block_x, block_y);

    dim3 grid;
    const int grid_count = (keySlices + block_y - 1) / block_y;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
                          "Too many slices to sort");
    const auto stream = at::cuda::getCurrentCUDAStream();

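    // Slices shorter than sort_size are padded with invalid_key, which is
    // chosen so that it compares as "last" under the comparator in use and
    // therefore stays at the end of each sorted slice.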
    if (descending) {
      const K invalid_key = at::numeric_limits<K>::lower_bound();
      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          GTOp<K, true>(),
          invalid_key);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      const K invalid_key = []{
        // NAN is sorted after inf
        if constexpr (has_nan<K>()) {
          return K(NAN);
        }
        return at::numeric_limits<K>::upper_bound();
      }();
      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          LTOp<K, true>(),
          invalid_key);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
};

#endif // HAS_WARP_MERGE_SORT()

// For medium sizes (128 < n <= 4096) use radixSortKVInplace.
struct MediumRadixSort {

  template <int A, typename K, typename V, typename IndexType>
  void sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {

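    // HANDLE_CASE dispatches to a fixed_size_sort instantiation for the given
    // power-of-two sort size and items-per-thread count, forwarding the slice
    // arguments captured above.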
#define HANDLE_CASE(SIZE, ITEMS_PER_THREAD)         \
    fixed_size_sort<A, SIZE, ITEMS_PER_THREAD>(     \
        keyInfo,                                    \
        keySlices,                                  \
        keySliceSize,                               \
        keySliceStride,                             \
        valueInfo,                                  \
        valueSliceStride,                           \
        descending)

    int64_t ceilPowerOf2 = nextHighestPowerOf2(keySliceSize);
    TORCH_INTERNAL_ASSERT(ceilPowerOf2 <= 4096);
    switch (ceilPowerOf2) {
      case 4096:
        HANDLE_CASE(4096, 32);
        break;
      case 2048:
        HANDLE_CASE(2048, 32);
        break;
      case 1024:
      case 512:
      case 256:
        HANDLE_CASE(1024, 32);
        break;
      case 128:
      case 64:
#if !HAS_WARP_MERGE_SORT()
        HANDLE_CASE(128, 4);
        break;
#endif
      case 32:
      case 16:
      case 8:
      case 4:
      case 2:
#if HAS_WARP_MERGE_SORT()
        TORCH_INTERNAL_ASSERT(
            false, "Expected size <= 128 to be handled by a different algorithm");
#else
        HANDLE_CASE(32, 2);
#endif
        break;
      case 1:
        /* Nothing to do, data already sorted */
        break;
      default:
        TORCH_INTERNAL_ASSERT(false);
    }
#undef HANDLE_CASE

  }

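  // Launches radixSortKVInPlace with one block per slice; each block runs
  // sort_size / items_per_thread threads, with every thread handling
  // items_per_thread elements of the slice.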
  template <int A, int sort_size, int items_per_thread,
            typename K, typename V, typename IndexType>
  void fixed_size_sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {
    static_assert(sort_size % items_per_thread == 0, "");
    constexpr int block = sort_size / items_per_thread;
    dim3 grid;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(keySlices, grid),
                          "Too many slices to sort");

    const auto stream = at::cuda::getCurrentCUDAStream();
    radixSortKVInPlace<A, -1, block, items_per_thread>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          descending);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
};

template <typename Sorter>
void sortCommon(Sorter sorter, const TensorBase &key, const TensorBase &value,
                int dim, bool descending) {
  TORCH_CHECK(key.sizes() == value.sizes(),
              "Key tensor must have same size as value tensor");
  int dims = value.dim();
  TORCH_CHECK(dims <= MAX_DIMS, "value tensor has too many dimensions");
  // if key and value tensors have the same size, we do not need to check both

  ptrdiff_t inElements = key.numel();

  if (inElements == 0) {
    return;
  }

  int64_t keySliceSize = key.size(dim);
  ptrdiff_t keySlices = inElements / keySliceSize;

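// HANDLE_SORT_CASE invokes the sorter with the collapsed tensor infos. The A
// argument is the dimension-count hint used for index computation: -2 means
// the tensor is contiguous, -1 means use the runtime dimension count, and a
// positive value specializes for exactly that many dimensions.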
#define HANDLE_SORT_CASE(TYPE, A)                   \
  sorter.template sort<A>(                          \
      keyInfo,                                      \
      (TYPE) keySlices,                             \
      (TYPE) keySliceSize,                          \
      (TYPE) keyInfo.strides[collapseKeyDim],       \
      valueInfo,                                    \
      (TYPE) valueInfo.strides[collapseValueDim],   \
      descending)

  // The constructed key/value tensor info is used to select the slice
  // we are sorting on a per-block basis
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, key.scalar_type(), "sortKeyValueInplace", [&] {
    if (at::cuda::detail::canUse32BitIndexMath(key)) {
      at::cuda::detail::TensorInfo<scalar_t, unsigned int> keyInfo =
        at::cuda::detail::getTensorInfo<scalar_t, unsigned int>(key);
      at::cuda::detail::TensorInfo<int64_t, unsigned int> valueInfo =
        at::cuda::detail::getTensorInfo<int64_t, unsigned int>(value);

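      // Collapse the sort dimension: setting its size to 1 lets collapseDims
      // fold the remaining dimensions together, then the original stride is
      // restored so kernels can still step through elements within each slice.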
      auto strideKey = keyInfo.strides[dim];
      keyInfo.sizes[dim] = 1;
      int collapseKeyDim = keyInfo.collapseDims(dim);
      keyInfo.strides[collapseKeyDim] = strideKey;
      auto strideValue = valueInfo.strides[dim];
      valueInfo.sizes[dim] = 1;
      int collapseValueDim = valueInfo.collapseDims(dim);
      valueInfo.strides[collapseValueDim] = strideValue;

      if (keyInfo.isContiguous()) {
        HANDLE_SORT_CASE(unsigned int, -2);
      } else {
        switch (keyInfo.dims) {
          case 2:
            HANDLE_SORT_CASE(unsigned int, 2);
            break;
          default:
            HANDLE_SORT_CASE(unsigned int, -1);
            break;
        }
      }

    } else {
      at::cuda::detail::TensorInfo<scalar_t, uint64_t> keyInfo =
        at::cuda::detail::getTensorInfo<scalar_t, uint64_t>(key);
      at::cuda::detail::TensorInfo<int64_t, uint64_t> valueInfo =
        at::cuda::detail::getTensorInfo<int64_t, uint64_t>(value);

      auto strideKey = keyInfo.strides[dim];
      keyInfo.sizes[dim] = 1;
      int collapseKeyDim = keyInfo.collapseDims(dim);
      keyInfo.strides[collapseKeyDim] = strideKey;
      auto strideValue = valueInfo.strides[dim];
      valueInfo.sizes[dim] = 1;
      int collapseValueDim = valueInfo.collapseDims(dim);
      valueInfo.strides[collapseValueDim] = strideValue;

      // int64_t case is rare, just instantiate the generic version
      HANDLE_SORT_CASE(uint64_t, -1);
    }
  });
#undef HANDLE_SORT_CASE
}

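// Dispatch on slice length: bitonic sort for tiny sorts (<= 32), warp merge
// sort for <= 128 when available, and block radix sort otherwise. The bitonic
// path is unstable, so it is skipped whenever a stable sort is requested.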
void sortKeyValueInplace(
    const TensorBase& key,
    const TensorBase& value,
    int dim,
    bool descending,
    bool stable) {
  const auto sort_size = key.size(dim);
  if (sort_size <= 1) {
    return; // Already sorted
  } else if (!stable && sort_size <= 32) {
    // NOTE: Bitonic sort is unstable
    sortCommon(SmallBitonicSort{}, key, value, dim, descending);
#if HAS_WARP_MERGE_SORT()
  } else if (sort_size <= 128) {
    sortCommon(WarpMergeSort<128>{}, key, value, dim, descending);
#endif
  } else {
    sortCommon(MediumRadixSort{}, key, value, dim, descending);
  }
}

}  // namespace at::native