#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/Sort.h>
#include <ATen/core/TensorBase.h>
#include <ATen/core/Array.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>

#include <limits>
#include <c10/core/DeviceArray.h>

namespace at::native {

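// Query the CUDA occupancy API for the smallest grid size that keeps the
// device fully occupied when `kernel` is launched with blocks of at most
// max_block_size threads. The sorters below use this to cap how many slices
// are batched into a single block, so small launches still produce enough
// blocks to fill the GPU.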
template <typename T>
static int minimum_grid_for_occupancy(T kernel, int max_block_size) {
  int minGridSize = 0;
  int blockSize;
  C10_CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(
      &minGridSize,
      &blockSize,
      kernel,
      /*dynamicSMemSize=*/0,
      max_block_size));
  return minGridSize;
}

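// Whether T has a NaN representation. Used below when padding ascending warp
// merge sorts: NaN compares after +inf, so padded lanes stay at the end of
// the slice.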
template <typename T>
constexpr bool has_nan() {
  if constexpr (std::numeric_limits<T>::is_specialized) {
    return std::numeric_limits<T>::has_quiet_NaN;
  } else if constexpr (
      c10::is_complex<T>::value ||
      std::is_same_v<T, c10::BFloat16> ||
      std::is_same_v<T, c10::Half>) {
    return true;
  } else {
    // Fallback so the function never falls off the end for other types.
    return false;
  }
}

// For very small unstable sorts (n <= 32), use bitonicSortKVInPlace
// which can sort multiple arrays within the same block of threads,
// improving occupancy.
struct SmallBitonicSort {
  template <int A, typename K, typename V, typename IndexType>
  void sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {
    constexpr int sort_size = 32;
    constexpr int max_block_y = 16;
    constexpr int items_per_thread = 2;
    static_assert(sort_size % items_per_thread == 0, "");
    constexpr int block_x = sort_size / items_per_thread;
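    // Thread layout: each of the block_x threads holds items_per_thread keys,
    // so one row of block_x threads covers a full sort_size slice, and up to
    // max_block_y rows (slices) share a single block.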

    TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);

    // Scale batch size down if the grid would be too small
    const auto min_grid = minimum_grid_for_occupancy(
        bitonicSortKVInPlace<
            A, -1, block_x, max_block_y,
            K, V, LTOp<K, true>, IndexType>,
        block_x * max_block_y);
    const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
    const int block_y = std::min(IndexType(max_block_y), max_batch);
    dim3 block(block_x, block_y);

    dim3 grid;
    const int grid_count = (keySlices + block_y - 1) / block_y;
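    // getGridFromTiles spreads grid_count over up to three grid dimensions
    // and reports failure if there are more slices than a launch can express.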
    TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
                          "Too many slices to sort");
    const auto stream = at::cuda::getCurrentCUDAStream();

    if (descending) {
      bitonicSortKVInPlace<A, -1, block_x, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          GTOp<K, true>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      bitonicSortKVInPlace<A, -1, block_x, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          LTOp<K, true>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
};

#if HAS_WARP_MERGE_SORT()

// For small sorts (n <= 128) we use warpMergeSortKVInPlace which
// sorts one slice per warp and potentially multiple slices in the
// same block for improved occupancy with large batch sizes.
template <int sort_size>
struct WarpMergeSort {

  template <int A, typename K, typename V, typename IndexType>
  void sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {
    constexpr int max_block_y = 16;
    const int block_x = at::cuda::warp_size();
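    // One warp per slice: block_x is the warp size, and up to max_block_y
    // warps (one slice each) are packed into every block.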

    TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);

    // Scale batch size down if the grid would be too small
    const auto min_grid = minimum_grid_for_occupancy(
        warpMergeSortKVInPlace<
            A, -1, sort_size, max_block_y,
            K, V, LTOp<K, true>, IndexType>,
        block_x * max_block_y);
    const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
    const int block_y = std::min(IndexType(max_block_y), max_batch);
    dim3 block(block_x, block_y);

    dim3 grid;
    const int grid_count = (keySlices + block_y - 1) / block_y;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
                          "Too many slices to sort");
    const auto stream = at::cuda::getCurrentCUDAStream();

    if (descending) {
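      // Padding keys must compare after every real key; for a descending sort
      // that is the smallest representable value of K.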
      const K invalid_key = at::numeric_limits<K>::lower_bound();
      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          GTOp<K, true>(),
          invalid_key);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      const K invalid_key = []{
        // NAN is sorted after inf
        if constexpr (has_nan<K>()) {
          return K(NAN);
        }
        return at::numeric_limits<K>::upper_bound();
      }();
      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
        <<<grid, block, 0, stream>>>(
          keyInfo,
          keySlices,
          keySliceSize,
          keySliceStride,
          valueInfo,
          valueSliceStride,
          LTOp<K, true>(),
          invalid_key);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
};

#endif // HAS_WARP_MERGE_SORT()

// For medium sizes (128 < n <= 4096) use radixSortKVInplace.
struct MediumRadixSort {

  template <int A, typename K, typename V, typename IndexType>
  void sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {

#define HANDLE_CASE(SIZE, ITEMS_PER_THREAD) \
  fixed_size_sort<A, SIZE, ITEMS_PER_THREAD>( \
      keyInfo, \
      keySlices, \
      keySliceSize, \
      keySliceStride, \
      valueInfo, \
      valueSliceStride, \
      descending)

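    // Round the slice length up to a power of two and dispatch to a matching
    // fixed-size kernel instantiation. Sizes of 128 or below only reach this
    // switch when warp merge sort is unavailable.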
    int64_t ceilPowerOf2 = nextHighestPowerOf2(keySliceSize);
    TORCH_INTERNAL_ASSERT(ceilPowerOf2 <= 4096);
    switch (ceilPowerOf2) {
      case 4096:
        HANDLE_CASE(4096, 32);
        break;
      case 2048:
        HANDLE_CASE(2048, 32);
        break;
      case 1024:
      case 512:
      case 256:
        HANDLE_CASE(1024, 32);
        break;
      case 128:
      case 64:
#if !HAS_WARP_MERGE_SORT()
        HANDLE_CASE(128, 4);
        break;
#endif
      case 32:
      case 16:
      case 8:
      case 4:
      case 2:
#if HAS_WARP_MERGE_SORT()
        TORCH_INTERNAL_ASSERT(
            false, "Expected size <= 128 to be handled by a different algorithm");
#else
        HANDLE_CASE(32, 2);
#endif
        break;
      case 1:
        /* Nothing to do, data already sorted */
        break;
      default:
        TORCH_INTERNAL_ASSERT(false);
    }
#undef HANDLE_CASE

  }

  template <int A, int sort_size, int items_per_thread,
            typename K, typename V, typename IndexType>
  void fixed_size_sort(
      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
      IndexType keySlices,
      IndexType keySliceSize,
      IndexType keySliceStride,
      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
      IndexType valueSliceStride,
      bool descending) {
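    // One block sorts one slice: `block` threads each handle items_per_thread
    // keys, covering all sort_size elements of the slice.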
    static_assert(sort_size % items_per_thread == 0, "");
    constexpr int block = sort_size / items_per_thread;
    dim3 grid;
    TORCH_INTERNAL_ASSERT(getGridFromTiles(keySlices, grid),
                          "Too many slices to sort");

    const auto stream = at::cuda::getCurrentCUDAStream();
    radixSortKVInPlace<A, -1, block, items_per_thread>
      <<<grid, block, 0, stream>>>(
        keyInfo,
        keySlices,
        keySliceSize,
        keySliceStride,
        valueInfo,
        valueSliceStride,
        descending);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
};


template <typename Sorter>
void sortCommon(Sorter sorter, const TensorBase &key, const TensorBase &value,
                int dim, bool descending) {
  TORCH_CHECK(key.sizes() == value.sizes(),
              "Key tensor must have same size as value tensor");
  int dims = value.dim();
  TORCH_CHECK(dims <= MAX_DIMS, "value tensor has too many dimensions");
  // key and value have the same sizes (checked above), so checking one
  // tensor's dimensionality is enough

  ptrdiff_t inElements = key.numel();

  if (inElements == 0) {
    return;
  }

  int64_t keySliceSize = key.size(dim);
  ptrdiff_t keySlices = inElements / keySliceSize;

#define HANDLE_SORT_CASE(TYPE, A) \
  sorter.template sort<A>( \
      keyInfo, \
      (TYPE) keySlices, \
      (TYPE) keySliceSize, \
      (TYPE) keyInfo.strides[collapseKeyDim], \
      valueInfo, \
      (TYPE) valueInfo.strides[collapseValueDim], \
      descending)

  // The constructed key/value tensor info is used to select the slice
  // we are sorting on a per-block basis
  AT_DISPATCH_ALL_TYPES_AND3(
      at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
      key.scalar_type(), "sortKeyValueInplace", [&] {
    if (at::cuda::detail::canUse32BitIndexMath(key)) {
      at::cuda::detail::TensorInfo<scalar_t, unsigned int> keyInfo =
          at::cuda::detail::getTensorInfo<scalar_t, unsigned int>(key);
      at::cuda::detail::TensorInfo<int64_t, unsigned int> valueInfo =
          at::cuda::detail::getTensorInfo<int64_t, unsigned int>(value);

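      // Treat the sort dimension as size 1 so the TensorInfo indexes over
      // slices rather than elements, collapse the remaining dims, and restore
      // the saved stride so kernels can still step along the sort dimension.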
      auto strideKey = keyInfo.strides[dim];
      keyInfo.sizes[dim] = 1;
      int collapseKeyDim = keyInfo.collapseDims(dim);
      keyInfo.strides[collapseKeyDim] = strideKey;
      auto strideValue = valueInfo.strides[dim];
      valueInfo.sizes[dim] = 1;
      int collapseValueDim = valueInfo.collapseDims(dim);
      valueInfo.strides[collapseValueDim] = strideValue;

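      // The A template argument selects the IndexToOffset specialization:
      // -2 assumes a contiguous layout, a positive value unrolls that many
      // dims statically, and -1 is the fully general path.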
      if (keyInfo.isContiguous()) {
        HANDLE_SORT_CASE(unsigned int, -2);
      } else {
        switch (keyInfo.dims) {
          case 2:
            HANDLE_SORT_CASE(unsigned int, 2);
            break;
          default:
            HANDLE_SORT_CASE(unsigned int, -1);
            break;
        }
      }

    } else {
      at::cuda::detail::TensorInfo<scalar_t, uint64_t> keyInfo =
          at::cuda::detail::getTensorInfo<scalar_t, uint64_t>(key);
      at::cuda::detail::TensorInfo<int64_t, uint64_t> valueInfo =
          at::cuda::detail::getTensorInfo<int64_t, uint64_t>(value);

      auto strideKey = keyInfo.strides[dim];
      keyInfo.sizes[dim] = 1;
      int collapseKeyDim = keyInfo.collapseDims(dim);
      keyInfo.strides[collapseKeyDim] = strideKey;
      auto strideValue = valueInfo.strides[dim];
      valueInfo.sizes[dim] = 1;
      int collapseValueDim = valueInfo.collapseDims(dim);
      valueInfo.strides[collapseValueDim] = strideValue;

      // int64_t case is rare, just instantiate the generic version
      HANDLE_SORT_CASE(uint64_t, -1);
    }
  });
#undef HANDLE_SORT_CASE
}

void sortKeyValueInplace(
    const TensorBase& key,
    const TensorBase& value,
    int dim,
    bool descending,
    bool stable) {
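  // Dispatch on slice length: bitonic sort for tiny slices, warp merge sort
  // up to 128 elements where supported, and a block-wide radix sort otherwise.
  // Only the bitonic path is unstable, so it is skipped when `stable` is set.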
  const auto sort_size = key.size(dim);
  if (sort_size <= 1) {
    return; // Already sorted
  } else if (!stable && sort_size <= 32) {
    // NOTE: Bitonic sort is unstable
    sortCommon(SmallBitonicSort{}, key, value, dim, descending);
#if HAS_WARP_MERGE_SORT()
  } else if (sort_size <= 128) {
    sortCommon(WarpMergeSort<128>{}, key, value, dim, descending);
#endif
  } else {
    sortCommon(MediumRadixSort{}, key, value, dim, descending);
  }
}

} // namespace at::native