#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/SortStable.h>

#include <ATen/Dispatch.h>
#include <ATen/core/Array.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>

#include <c10/core/DeviceArray.h>
#include <limits>

namespace at::native {

namespace {

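// Used as the begin/end segment-offset arguments to
// at::cuda::cub::segmented_sort_pairs below: offset_t{nsort, 0}[i] == i * nsort
// is the begin offset of segment i, and offset_t{nsort, 1}[i] == (i + 1) * nsort
// is its end offset.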
struct offset_t {
  int stride;
  int begin;
  __device__ int operator[](int i) {
    return stride * (begin + i);
  }
};
// Segmented sort by full sort algorithm:
// Say we are sorting a (2, 3) tensor. We have in flattened form:
// values       0.4 1.2 5.3 6.2 1.3 2.3
// indices        0   1   2   0   1   2
// segment_id     0   0   0   1   1   1

// First we sort by values, globally:
// values       6.2 5.3 2.3 1.3 1.2 0.4
// indices        0   2   2   1   1   0
// segment_id     1   0   1   1   0   0

// Then we stable sort by segment id:
// values       5.3 1.2 0.4 6.2 2.3 1.3
// indices        2   1   0   0   2   1
// segment_id     0   0   0   1   1   1

// This method can only work if the slice we are sorting (`dim`) is
// innermost, and both values and indices are contiguous. We do this
// by re-arranging the input into this form as needed, which will
// unfortunately allocate memory if the request is not in this form.
// Vectorized sort is slower than iterated sort if the number of
// slices is small (since we're sorting twice, instead of invoking a
// smaller sort `numSlices` times), but the cub sort
// implementation here is a catch-all, so the goal is correctness
// rather than efficiency.
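//
// The same two passes, sketched host-side for reference only (this snippet is
// not part of the file; `values`, `nsort`, and `order` are illustrative
// names). Pass 1 stably sorts flat positions by value, pass 2 stably sorts
// them by segment id, so order[k] ends up being the flat input position of
// the k-th output element:
//
//   std::vector<int64_t> order(values.size());
//   std::iota(order.begin(), order.end(), 0);
//   std::stable_sort(order.begin(), order.end(), [&](int64_t a, int64_t b) {
//     return values[a] > values[b]; // pass 1: global stable sort by value
//   });
//   std::stable_sort(order.begin(), order.end(), [&](int64_t a, int64_t b) {
//     return a / nsort < b / nsort; // pass 2: stable sort by segment id
//   });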

// Given the (segment, original position) pairs produced by the two sorts in
// segmented_sort_pairs_by_full_sort, gather the sorted values and write the
// within-segment index for each output slot.
template <typename scalar_t>
__global__ void sort_postprocess_kernel(
    const scalar_t* in,
    scalar_t* out,
    int64_t* index,
    const int2* i_s_ptr,
    int nsegments,
    int nsort) {
  CUDA_KERNEL_LOOP(i, nsegments * nsort) {
    int segment = i / nsort;
    int j = i % nsort;

    int offset = segment * nsort;
    const scalar_t* in_ = in + offset;
    scalar_t* out_ = out + offset;
    int64_t* index_ = index + offset;
    const int2* i_s_ptr_ = i_s_ptr + offset;

    int idx = i_s_ptr_[j].y;
    index_[j] = idx;
    out_[j] = in_[idx];
  }
}

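// Writes, for every flat position idx, the pair
// {segment = idx / nsort, position within segment = idx % nsort}. These pairs
// are carried through the sorts in segmented_sort_pairs_by_full_sort.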
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void fill_index_and_segment_kernel(
    int2* data,
    int numel,
    at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
  CUDA_KERNEL_LOOP(idx, numel) {
    auto div_mod = nsort_divider.divmod(idx);
    auto segment = static_cast<int>(div_mod.div);
    auto sort = static_cast<int>(div_mod.mod);
    data[idx] = int2{segment, sort};
  }
}

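// Fills data with idx % nsort, i.e. the identity permutation 0..nsort-1
// repeated once per segment; used as the initial indices payload.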
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void fill_reverse_indices_kernel(
    int64_t* data,
    int numel,
    at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
  CUDA_KERNEL_LOOP(idx, numel) {
    data[idx] = nsort_divider.mod(idx);
  }
}

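// Sorts each segment with its own cub radix_sort_pairs call. Only worthwhile
// when a single segment is large enough to occupy the GPU on its own (see the
// heuristic in launch_stable_sort_kernel below).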
template <typename scalar_t>
inline void segmented_sort_large_segments(
    const int64_t nsegments,
    const int64_t nsort,
    const int64_t n,
    const bool descending,
    const scalar_t* self_ptr,
    scalar_t* values_ptr,
    int64_t* indices_ptr) {
  using namespace at::cuda::detail;
  auto allocator = at::cuda::getCUDADeviceAllocator();
  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 block = CUDA_NUM_THREADS;
  dim3 grid = GET_BLOCKS(nsort);
  c10::DeviceArray<int64_t> indices(*allocator, nsort);
  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
      indices.get(), nsort, nsort_divider);
  const int64_t* initial_indices = indices.get();

  for (auto i : c10::irange(nsegments)) {
    at::cuda::cub::radix_sort_pairs<scalar_t, int64_t>(
        self_ptr, values_ptr, initial_indices, indices_ptr, nsort, descending);
    indices_ptr += nsort;
    self_ptr += nsort;
    values_ptr += nsort;
  }
}

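// The "segmented sort by full sort" scheme described at the top of this file:
// one global radix sort of all values with (segment, index) pairs as the
// payload, followed by a stable radix sort on the segment id, then a gather
// in sort_postprocess_kernel to produce the output values and indices.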
template <typename scalar_t>
inline void segmented_sort_pairs_by_full_sort(
    const int64_t nsegments,
    const int64_t nsort,
    const int64_t n,
    const bool descending,
    const scalar_t* const self_ptr,
    scalar_t* const values_ptr,
    int64_t* const indices_ptr) {
  int64_t segment_bits = std::max<int64_t>(
      1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));

  const auto numel = nsort * nsegments;
  auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
  auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2));
  auto i_s_ptr = static_cast<int2*>(indices_and_segment.get());

  using namespace at::cuda::detail;
  dim3 block = CUDA_NUM_THREADS;
  dim3 grid = GET_BLOCKS(numel);
  auto stream = c10::cuda::getCurrentCUDAStream();
  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
  fill_index_and_segment_kernel<<<grid, block, 0, stream>>>(
      i_s_ptr, numel, nsort_divider);

  auto indices_and_segment2 =
      cuda_allocator->allocate(nsegments * nsort * sizeof(int2));
  auto i_s_ptr2 = static_cast<int2*>(indices_and_segment2.get());

  at::cuda::cub::radix_sort_pairs<scalar_t, int2>(
      self_ptr, nullptr, i_s_ptr, i_s_ptr2, n, descending);

  TORCH_INTERNAL_ASSERT(segment_bits <= 32);

  // Stable sort on the segment id, which lives in the lower 32 bits of each
  // (segment, index) pair reinterpreted as an int64_t key; only segment_bits
  // bits are needed.
  at::cuda::cub::radix_sort_keys<int64_t>(
      reinterpret_cast<int64_t*>(i_s_ptr2),
      reinterpret_cast<int64_t*>(i_s_ptr),
      n,
      false,
      0,
      segment_bits);

  sort_postprocess_kernel<<<
      (n + 511) / 512,
      512,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
}

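// Straightforward path: defer to cub's segmented sort, with the per-segment
// identity permutation from fill_reverse_indices_kernel as the indices
// payload and offset_t providing the segment boundaries.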
template <typename scalar_t>
void segmented_sort_pairs(
    int64_t nsegments,
    int64_t nsort,
    int64_t n,
    bool descending,
    const scalar_t* self_ptr,
    scalar_t* values_ptr,
    int64_t* indices_ptr) {
  const auto numel = nsort * nsegments;
  auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
  auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t));
  int64_t* reverse_indices_ptr = static_cast<int64_t*>(reverse_indices.get());

  using namespace at::cuda::detail;
  dim3 block = CUDA_NUM_THREADS;
  dim3 grid = GET_BLOCKS(numel);
  auto stream = c10::cuda::getCurrentCUDAStream();
  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
      reverse_indices_ptr, numel, nsort_divider);

  at::cuda::cub::segmented_sort_pairs(
      self_ptr,
      values_ptr,
      reverse_indices_ptr,
      indices_ptr,
      n,
      nsegments,
      offset_t{(int)nsort, 0},
      offset_t{(int)nsort, 1},
      descending);
}

} // namespace

void launch_stable_sort_kernel(
    const TensorBase& self,
    int64_t dim,
    bool descending,
    const TensorBase& values,
    const TensorBase& indices) {
  const auto numel = self.numel();
  if (numel == 0) {
    return;
  }

  int64_t numel_or_intmax =
      std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
  int64_t nsort = self.size(dim);
  int64_t nbatch = (numel_or_intmax / nsort) * nsort;
  TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
  int64_t* indices_ptr = indices.mutable_data_ptr<int64_t>();

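  // Sort in batches of at most nbatch elements (a whole number of segments,
  // capped so each batch stays within the 32-bit index arithmetic used by the
  // kernels and cub calls above), choosing a strategy per batch based on the
  // number and length of the segments.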
  AT_DISPATCH_ALL_TYPES_AND3(
      kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&] {
        const scalar_t* self_ptr = self.const_data_ptr<scalar_t>();
        scalar_t* values_ptr = values.mutable_data_ptr<scalar_t>();
        int64_t remaining = numel;
        while (remaining > 0) {
          int64_t n = std::min(remaining, nbatch);
          int64_t nsegments = n / nsort;

          if (nsegments == 1 ||
              nsort >= 1000000) { // rough heuristic: even a single sort of
                                  // this length occupies the GPU
            segmented_sort_large_segments(
                nsegments,
                nsort,
                n,
                descending,
                self_ptr,
                values_ptr,
                indices_ptr);
          } else if (nsegments < 128) {
            segmented_sort_pairs_by_full_sort(
                nsegments,
                nsort,
                n,
                descending,
                self_ptr,
                values_ptr,
                indices_ptr);
          } else {
            segmented_sort_pairs(
                nsegments,
                nsort,
                n,
                descending,
                self_ptr,
                values_ptr,
                indices_ptr);
          }

          remaining -= n;
          self_ptr += n;
          values_ptr += n;
          indices_ptr += n;
        }
      });
}

} // namespace at::native