#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/SortStable.h>

#include <ATen/Dispatch.h>
#include <ATen/core/Array.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>

#include <c10/core/DeviceArray.h>
#include <limits>

namespace at::native {

namespace {
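// offset_t is the "offsets iterator" handed to cub's segmented sort below:
// offsets[i] = stride * (begin + i). With stride == nsort, {nsort, 0} yields
// the segment begin offsets {0, nsort, 2*nsort, ...} and {nsort, 1} the end
// offsets, so segment boundaries never have to be materialized in memory.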
struct offset_t {
  int stride;
  int begin;
  __device__ int operator[](int i) {
    return stride * (begin + i);
  }
};
// Segmented sort by full sort algorithm:
// Say we are sorting a (2, 3) tensor. We have in flattened form:
// values       0.4 1.2 5.3 6.2 1.3 2.3
// indices        0   1   2   0   1   2
// segment_id     0   0   0   1   1   1

// First we sort by values, globally (descending in this example):
// values       6.2 5.3 2.3 1.3 1.2 0.4
// indices        0   2   2   1   1   0
// segment_id     1   0   1   1   0   0

// Then we stable sort by segment id:
// values       5.3 1.2 0.4 6.2 2.3 1.3
// indices        2   1   0   0   2   1
// segment_id     0   0   0   1   1   1

// This method can only work if the slice we are sorting (`dim`) is
// innermost, and both values and indices are contiguous. We do this
// by re-arranging the input into this form as needed, which will
// unfortunately allocate memory if the request is not in this form.
// Vectorized sort is slower than iterated sort if the number of
// slices is small (since we're sorting twice, instead of invoking a
// smaller sort `numSlices` times), but the cub sort implementation
// here is a catch-all, so the goal is correctness rather than
// efficiency.

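// Gathers the final result after the two radix-sort passes: at flattened
// position segment * nsort + j, i_s_ptr holds the (segment id, original index
// within the segment) pair of the element that ranks j-th in its segment.
// This kernel writes that index to `index` and copies the matching input
// element into `out`.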
template <typename scalar_t>
__global__ void sort_postprocess_kernel(
    const scalar_t* in,
    scalar_t* out,
    int64_t* index,
    const int2* i_s_ptr,
    int nsegments,
    int nsort) {
  CUDA_KERNEL_LOOP(i, nsegments * nsort) {
    int segment = i / nsort;
    int j = i % nsort;

    int offset = segment * nsort;
    const scalar_t* in_ = in + offset;
    scalar_t* out_ = out + offset;
    int64_t* index_ = index + offset;
    const int2* i_s_ptr_ = i_s_ptr + offset;

    int idx = i_s_ptr_[j].y;
    index_[j] = idx;
    out_[j] = in_[idx];
  }
}

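// Fills data[i] = {i / nsort, i % nsort}, i.e. the (segment id, position
// within the segment) pair for every flattened element. The IntDivider is
// precomputed on the host so each thread avoids a full hardware division.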
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void fill_index_and_segment_kernel(
    int2* data,
    int numel,
    at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
  CUDA_KERNEL_LOOP(idx, numel) {
    auto div_mod = nsort_divider.divmod(idx);
    auto segment = static_cast<int>(div_mod.div);
    auto sort = static_cast<int>(div_mod.mod);
    data[idx] = int2{segment, sort};
  }
}

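// Fills data[i] = i % nsort: the identity permutation 0..nsort-1, repeated
// per segment when launched over the whole tensor. These serve as the initial
// indices fed to the cub sorts below.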
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void fill_reverse_indices_kernel(
    int64_t* data,
    int numel,
    at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
  CUDA_KERNEL_LOOP(idx, numel) {
    data[idx] = nsort_divider.mod(idx);
  }
}

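// Large-segment path: used when there is a single segment, or each segment is
// large enough that one sort already occupies the GPU (see the heuristic in
// launch_stable_sort_kernel). Runs one cub radix_sort_pairs per segment,
// reusing a single 0..nsort-1 index buffer as the paired values for every
// call.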
template <typename scalar_t>
inline void segmented_sort_large_segments(
    const int64_t nsegments,
    const int64_t nsort,
    const int64_t n,
    const bool descending,
    const scalar_t* self_ptr,
    scalar_t* values_ptr,
    int64_t* indices_ptr) {
  using namespace at::cuda::detail;
  auto allocator = at::cuda::getCUDADeviceAllocator();
  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 block = CUDA_NUM_THREADS;
  dim3 grid = GET_BLOCKS(nsort);
  c10::DeviceArray<int64_t> indices(*allocator, nsort);
  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
      indices.get(), nsort, nsort_divider);
  const int64_t* initial_indices = indices.get();

  for (auto i : c10::irange(nsegments)) {
    at::cuda::cub::radix_sort_pairs<scalar_t, int64_t>(
        self_ptr, values_ptr, initial_indices, indices_ptr, nsort, descending);
    indices_ptr += nsort;
    self_ptr += nsort;
    values_ptr += nsort;
  }
}

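// Full-sort path, picked by the launcher when the number of segments is small
// (fewer than 128): pair every element with its (segment id, index within
// segment) as an int2, radix-sort the pairs by value, then stably radix-sort
// the packed pairs by segment id alone, and finally gather values/indices
// with sort_postprocess_kernel. See the worked example at the top of this
// file.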
template <typename scalar_t>
inline void segmented_sort_pairs_by_full_sort(
    const int64_t nsegments,
    const int64_t nsort,
    const int64_t n,
    const bool descending,
    const scalar_t* const self_ptr,
    scalar_t* const values_ptr,
    int64_t* const indices_ptr) {
  int64_t segment_bits = std::max<int64_t>(
      1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));

  const auto numel = nsort * nsegments;
  auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
  auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2));
  auto i_s_ptr = static_cast<int2*>(indices_and_segment.get());

  using namespace at::cuda::detail;
  dim3 block = CUDA_NUM_THREADS;
  dim3 grid = GET_BLOCKS(numel);
  auto stream = c10::cuda::getCurrentCUDAStream();
  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
  fill_index_and_segment_kernel<<<grid, block, 0, stream>>>(
      i_s_ptr, numel, nsort_divider);

  auto indices_and_segment2 =
      cuda_allocator->allocate(nsegments * nsort * sizeof(int2));
  auto i_s_ptr2 = static_cast<int2*>(indices_and_segment2.get());

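  // Pass 1: global radix sort of all `n` values. The sorted values themselves
  // are not needed here; we only keep the permuted (segment, index) pairs
  // that travel with them as the sort's paired values.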
  at::cuda::cub::radix_sort_pairs<scalar_t, int2>(
      self_ptr, nullptr, i_s_ptr, i_s_ptr2, n, descending);

  TORCH_INTERNAL_ASSERT(segment_bits <= 32);

  // Pass 2: stable radix sort on the low `segment_bits` bits of each packed
  // pair, i.e. the segment id (int2::x occupies the low half of the int64);
  // stability preserves the value order within each segment from pass 1.
  at::cuda::cub::radix_sort_keys<int64_t>(
      reinterpret_cast<int64_t*>(i_s_ptr2),
      reinterpret_cast<int64_t*>(i_s_ptr),
      n,
      false,
      0,
      segment_bits);

  sort_postprocess_kernel<<<
      (n + 511) / 512,
      512,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
}

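// General path: defer directly to cub's segmented radix sort. The initial
// indices are the per-segment identity (i % nsort), and offset_t supplies the
// segment begin/end offsets without materializing an offsets array.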
template <typename scalar_t>
void segmented_sort_pairs(
    int64_t nsegments,
    int64_t nsort,
    int64_t n,
    bool descending,
    const scalar_t* self_ptr,
    scalar_t* values_ptr,
    int64_t* indices_ptr) {
  const auto numel = nsort * nsegments;
  auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
  auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t));
  int64_t* reverse_indices_ptr = static_cast<int64_t*>(reverse_indices.get());

  using namespace at::cuda::detail;
  dim3 block = CUDA_NUM_THREADS;
  dim3 grid = GET_BLOCKS(numel);
  auto stream = c10::cuda::getCurrentCUDAStream();
  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
      reverse_indices_ptr, numel, nsort_divider);

  at::cuda::cub::segmented_sort_pairs(
      self_ptr,
      values_ptr,
      reverse_indices_ptr,
      indices_ptr,
      n,
      nsegments,
      offset_t{(int)nsort, 0},
      offset_t{(int)nsort, 1},
      descending);
}

} // namespace

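// Entry point (declared in SortStable.h). Expects the sort dimension to be
// innermost and self/values/indices to be contiguous; the caller is expected
// to arrange that (see the comment above). Work is processed in batches of
// whole slices whose element count fits in an int, and each batch is
// dispatched to one of the three strategies above based on the segment count
// and slice length.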
void launch_stable_sort_kernel(
    const TensorBase& self,
    int64_t dim,
    bool descending,
    const TensorBase& values,
    const TensorBase& indices) {
  const auto numel = self.numel();
  if (numel == 0) {
    return;
  }

  int64_t numel_or_intmax =
      std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
  int64_t nsort = self.size(dim);
  int64_t nbatch = (numel_or_intmax / nsort) * nsort;
  TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
  int64_t* indices_ptr = indices.mutable_data_ptr<int64_t>();

  AT_DISPATCH_ALL_TYPES_AND3(
      kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&] {
        const scalar_t* self_ptr = self.const_data_ptr<scalar_t>();
        scalar_t* values_ptr = values.mutable_data_ptr<scalar_t>();
        int64_t remaining = numel;
        while (remaining > 0) {
          int64_t n = std::min(remaining, nbatch);
          int64_t nsegments = n / nsort;

          if (nsegments == 1 ||
              nsort >= 1000000) { // rough heuristic: even a single sort of
                                  // this size occupies the whole GPU
            segmented_sort_large_segments(
                nsegments,
                nsort,
                n,
                descending,
                self_ptr,
                values_ptr,
                indices_ptr);
          } else if (nsegments < 128) {
            segmented_sort_pairs_by_full_sort(
                nsegments,
                nsort,
                n,
                descending,
                self_ptr,
                values_ptr,
                indices_ptr);
          } else {
            segmented_sort_pairs(
                nsegments,
                nsort,
                n,
                descending,
                self_ptr,
                values_ptr,
                indices_ptr);
          }

          remaining -= n;
          self_ptr += n;
          values_ptr += n;
          indices_ptr += n;
        }
      });
}

} // namespace at::native