xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/SortingKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
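// TORCH_ASSERT_NO_OPERATORS keeps this kernel translation unit independent of
// the full ATen operator headers; only TensorBase and the dispatch machinery
// included below are needed here.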
#define TORCH_ASSERT_NO_OPERATORS

#include <limits>

#include <ATen/native/Sorting.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <ATen/native/TopKImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>
#ifdef USE_FBGEMM
#include <fbgemm/Utils.h>
#endif

namespace at::native {

namespace {

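// Applies `f` to every 1-D slice of `values`/`indices` along `dim`.  The
// TensorIterator below is built with `dim` squashed out of the static shape,
// so each element it visits is the start of one slice; `f` receives raw
// pointers to that slice together with the per-dim strides and the slice
// length.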
template <typename func_t>
void _dim_apply(
    const TensorBase &values,
    const TensorBase &indices,
    int64_t dim,
    const std::string& method_name,
    const func_t& f) {
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(values.sizes(), /*squash_dims=*/dim)
    .add_output(values)
    .add_output(indices)
    .build();

  auto values_dim_stride = values.stride(dim);
  auto indices_dim_stride = indices.stride(dim);
  auto dim_size = values.size(dim);

  AT_DISPATCH_V2(
    iter.dtype(), "sorting_kernel_method_name", AT_WRAP([&] {
      auto loop = [&](char** data, const int64_t* strides, int64_t n) {
        auto* values_data_bytes = data[0];
        auto* indices_data_bytes = data[1];

        if (values_data_bytes == nullptr || indices_data_bytes == nullptr) {
          return;
        }

        for (const auto i C10_UNUSED : c10::irange(n)) {
          f(
            reinterpret_cast<scalar_t*>(values_data_bytes),
            values_dim_stride,
            reinterpret_cast<int64_t*>(indices_data_bytes),
            indices_dim_stride,
            dim_size
          );

          values_data_bytes += strides[0];
          indices_data_bytes += strides[1];
        }
      };

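      // Scale the grain size so that each parallel task still covers roughly
      // GRAIN_SIZE elements in total (GRAIN_SIZE / dim_size slices).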
      int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
      iter.for_each(loop, /*grain_size=*/grain_size);
    }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
  );
}

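// Comparators for the (value, index) pairs produced by the composite accessor.
// NaNs are treated as greater than every other value, so they end up last in
// ascending sorts and first in descending sorts.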
template <typename scalar_t>
struct KeyValueCompAsc {
  template <typename LHS, typename RHS>
  constexpr bool operator()(LHS lhs, RHS rhs) const {
    return (!_isnan<scalar_t>(get<0>(lhs)) && _isnan<scalar_t>(get<0>(rhs)))
      || (get<0>(lhs) < get<0>(rhs));
  }
};

template <typename scalar_t>
struct KeyValueCompDesc {
  template <typename LHS, typename RHS>
  constexpr bool operator()(LHS lhs, RHS rhs) const {
    return (_isnan<scalar_t>(get<0>(lhs)) && !_isnan<scalar_t>(get<0>(rhs)))
      || (get<0>(lhs) > get<0>(rhs));
  }
};

#ifdef USE_FBGEMM
static bool can_use_radix_sort(const TensorBase& values, const bool descending) {
  // radix_sort can be used only for 1D data
  if (values.dim() != 1) return false;
  // radix_sort sorts in ascending order
  if (descending) return false;
  // radix_sort works for integer values
  if (!at::isIntegralType(values.scalar_type(), /*includeBool=*/false)) return false;
  // performance improvements are visible for bigger tensor sizes, when radix_sort
  // is accelerated with OpenMP
  if (values.numel() < at::internal::GRAIN_SIZE || !fbgemm::is_radix_sort_accelerated_with_openmp()) return false;
  // TODO(DamianSzwichtenberg): radix_sort is a stable sorting algorithm,
  // should we check here, whether stable is set to true?

  return true;
}

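// Ascending 1-D sort of integer keys together with their int64 indices using
// fbgemm's parallel radix sort.  radix_sort_parallel may leave the result in
// either the original buffers or the temporaries; if it lands in the
// temporaries, the sorted data is copied back into the tensors in parallel.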
static void parallel_sort1d_kernel(
    const TensorBase& values,
    const TensorBase& indices) {
  AT_DISPATCH_INTEGRAL_TYPES(values.scalar_type(), "parallel_sort1d_kernel", [&] {
    const auto elements = values.numel();
    auto* const keys = values.data_ptr<scalar_t>();
    auto* const vals = indices.data_ptr<int64_t>();
    std::vector<scalar_t> tmp_keys(elements);
    std::vector<int64_t> tmp_vals(elements);
    const scalar_t* sorted_keys = nullptr;
    const int64_t* sorted_vals = nullptr;
    std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
        keys,
        vals,
        tmp_keys.data(),
        tmp_vals.data(),
        elements,
        std::numeric_limits<scalar_t>::max(),
        values.scalar_type() != ScalarType::Byte);

    const bool sorted_in_place = keys == sorted_keys;
    if (!sorted_in_place) {
      const auto num_threads = at::get_num_threads();
      at::parallel_for(0, elements, elements / num_threads, [&](int64_t begin, int64_t end) {
        const auto job_size = end - begin;
        vec::map([](vec::Vectorized<scalar_t> x) -> vec::Vectorized<scalar_t> { return x; }, keys + begin, sorted_keys + begin, job_size);
        vec::map([](vec::Vectorized<int64_t> x) -> vec::Vectorized<int64_t> { return x; }, vals + begin, sorted_vals + begin, job_size);
      });
    }
  });
}
#endif

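// Sorts one slice: the value and index accessors are packed into a
// CompositeRandomAccessorCPU so that std::sort/std::stable_sort permutes keys
// and indices in lockstep, with the comparator chosen by sort direction.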
template <typename scalar_t, typename value_accessor_t, typename indices_accessor_t>
static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
            const indices_accessor_t& indices_accessor,
            int64_t dim_size, bool descending, bool stable) {
  auto composite_accessor = CompositeRandomAccessorCPU<
    value_accessor_t, indices_accessor_t
  >(value_accessor, indices_accessor);
  if (descending) {
    if (stable) {
      std::stable_sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompDesc<scalar_t>());
    } else {
      std::sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompDesc<scalar_t>());
    }
  } else {
    if (stable) {
      std::stable_sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompAsc<scalar_t>());
    } else {
      std::sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompAsc<scalar_t>());
    }
  }
}

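// CPU entry point registered as sort_stub: fills `indices` with 0..n-1 along
// `dim`, takes the fbgemm radix-sort fast path for large 1-D ascending integer
// sorts when available, and otherwise sorts every slice along `dim` in place
// together with its indices.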
static void sort_kernel(
    const TensorBase& self,
    const TensorBase& values,
    const TensorBase& indices,
    int64_t dim,
    bool descending,
    bool stable) {
  dim = maybe_wrap_dim(dim, values.dim());
  _fill_indices(indices, dim);
  if (self.stride(dim) == 0) {
    // self is expanded along dim (stride 0): all values along dim are equal,
    // so `values` is already sorted and the freshly filled indices are valid.
    // https://github.com/pytorch/pytorch/issues/91420
    return;
  }
#ifdef USE_FBGEMM
  if (can_use_radix_sort(values, descending)) {
    parallel_sort1d_kernel(values, indices);
    return;
  }
#endif
  _dim_apply(
    values, indices, dim,
    "sort_cpu", [&](
      auto* values, int64_t values_dim_stride,
      auto* indices, int64_t indices_dim_stride,
      int64_t dim_size
    ) {
      using scalar_t = std::remove_pointer_t<decltype(values)>;
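      // Use the raw pointers directly when a slice is contiguous along `dim`;
      // otherwise wrap it in a StridedRandomAccessor so the sort steps by the
      // dim stride.  The four branches below cover every combination.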
      if (values_dim_stride == 1 && indices_dim_stride == 1) {
        sort_kernel_impl<
          scalar_t, decltype(values), decltype(indices)
        >(values, indices, dim_size, descending, stable);
      } else if (values_dim_stride == 1 && indices_dim_stride != 1) {
        auto indices_accessor = StridedRandomAccessor<int64_t>(
          indices, indices_dim_stride);
        sort_kernel_impl<
          scalar_t, decltype(values), decltype(indices_accessor)
        >(values, indices_accessor, dim_size, descending, stable);
      } else if (values_dim_stride != 1 && indices_dim_stride == 1) {
        auto values_accessor = StridedRandomAccessor<scalar_t>(
          values, values_dim_stride);
        sort_kernel_impl<
          scalar_t, decltype(values_accessor), decltype(indices)
        >(values_accessor, indices, dim_size, descending, stable);
      } else {
        auto values_accessor = StridedRandomAccessor<scalar_t>(
          values, values_dim_stride);
        auto indices_accessor = StridedRandomAccessor<int64_t>(
          indices, indices_dim_stride);
        sort_kernel_impl<
          scalar_t, decltype(values_accessor), decltype(indices_accessor)
        >(values_accessor, indices_accessor, dim_size, descending, stable);
      }
    }
  );
}

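// CPU entry point registered as topk_stub: for every slice of `self` along
// `dim`, writes the k largest (or smallest) elements and their indices into
// `values` and `indices` via topk_impl_loop.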
static void topk_kernel(
    const TensorBase &values,
    const TensorBase &indices,
    const TensorBase &self,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted) {
  auto sizes = self.sizes();
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(sizes, /*squash_dims=*/dim)
    .add_output(values)
    .add_output(indices)
    .add_const_input(self)
    .build();

  auto mode_values_stride = values.strides()[dim];
  auto mode_indices_stride = indices.strides()[dim];
  auto tmp_values_stride = self.strides()[dim];

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "topk_cpu", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
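      // For BFloat16 inputs the inner loop works in float (the conversion is
      // exact); every other dtype uses its own type.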
      if (self.scalar_type() == ScalarType::BFloat16) {
        return topk_impl_loop<scalar_t, float>(
            mode_values_stride, mode_indices_stride, tmp_values_stride,
            k, sizes[dim], largest, sorted, data, strides, n);
      } else {
        return topk_impl_loop<scalar_t, scalar_t>(
            mode_values_stride, mode_indices_stride, tmp_values_stride,
            k, sizes[dim], largest, sorted, data, strides, n);
      }
    };

    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });
}

} // anonymous namespace

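// Hook these CPU kernels into the sort_stub/topk_stub dispatch tables.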
REGISTER_DISPATCH(sort_stub, &sort_kernel);
REGISTER_DISPATCH(topk_stub, &topk_kernel);

} // namespace at::native