#define TORCH_ASSERT_NO_OPERATORS

#include <limits>

#include <ATen/native/Sorting.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <ATen/native/TopKImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>
#ifdef USE_FBGEMM
#include <fbgemm/Utils.h>
#endif

namespace at::native {

namespace {

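// _dim_apply: runs `f` once per slice along `dim` of `values` and `indices`.
// The TensorIterator below squashes `dim`, so the inner loop hands `f` a raw
// pointer to the start of one values slice and one indices slice, together
// with each tensor's stride along `dim` and the slice length.
//
// Illustrative sketch only (mirrors how sort_kernel below uses it): the functor
// is expected to look roughly like
//   [&](scalar_t* vals, int64_t vals_stride,
//       int64_t* inds, int64_t inds_stride, int64_t dim_size) { /* sort slice */ };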
template <typename func_t>
void _dim_apply(
    const TensorBase &values,
    const TensorBase &indices,
    int64_t dim,
    const std::string& method_name,
    const func_t& f) {
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(values.sizes(), /*squash_dims=*/dim)
    .add_output(values)
    .add_output(indices)
    .build();

  auto values_dim_stride = values.stride(dim);
  auto indices_dim_stride = indices.stride(dim);
  auto dim_size = values.size(dim);

  AT_DISPATCH_V2(
    iter.dtype(), "sorting_kernel_method_name", AT_WRAP([&] {
      auto loop = [&](char** data, const int64_t* strides, int64_t n) {
        auto* values_data_bytes = data[0];
        auto* indices_data_bytes = data[1];

        if (values_data_bytes == nullptr || indices_data_bytes == nullptr) {
          return;
        }

        for (const auto i C10_UNUSED : c10::irange(n)) {
          f(
            reinterpret_cast<scalar_t*>(values_data_bytes),
            values_dim_stride,
            reinterpret_cast<int64_t*>(indices_data_bytes),
            indices_dim_stride,
            dim_size
          );

          values_data_bytes += strides[0];
          indices_data_bytes += strides[1];
        }
      };

      int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
      iter.for_each(loop, /*grain_size=*/grain_size);
    }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
  );
}

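// NaN-aware comparators for the (key, value) references yielded by
// CompositeRandomAccessorCPU: a NaN key compares greater than any non-NaN key,
// so ascending sorts place NaNs last and descending sorts place NaNs first.
// Illustrative sketch only: with keys {2.f, NAN, 1.f}, KeyValueCompAsc orders
// them as {1.f, 2.f, NAN}.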
template <typename scalar_t>
struct KeyValueCompAsc {
  template <typename LHS, typename RHS>
  constexpr bool operator()(LHS lhs, RHS rhs) const {
    return (!_isnan<scalar_t>(get<0>(lhs)) && _isnan<scalar_t>(get<0>(rhs)))
      || (get<0>(lhs) < get<0>(rhs));
  }
};

template <typename scalar_t>
struct KeyValueCompDesc {
  template <typename LHS, typename RHS>
  constexpr bool operator()(LHS lhs, RHS rhs) const {
    return (_isnan<scalar_t>(get<0>(lhs)) && !_isnan<scalar_t>(get<0>(rhs)))
      || (get<0>(lhs) > get<0>(rhs));
  }
};

#ifdef USE_FBGEMM
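// Gate for the FBGEMM radix-sort fast path: only 1D integral (non-bool) tensors
// sorted in ascending order qualify, and only when the tensor has at least
// GRAIN_SIZE elements and fbgemm reports OpenMP-accelerated radix_sort.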
static bool can_use_radix_sort(const TensorBase& values, const bool descending) {
  // radix_sort can be used only for 1D data
  if (values.dim() != 1) return false;
  // radix_sort sorts in ascending order
  if (descending) return false;
  // radix_sort works for integer values
  if (!at::isIntegralType(values.scalar_type(), /*includeBool=*/false)) return false;
  // performance improvements are visible for bigger tensor sizes, when radix_sort
  // is accelerated with OpenMP
  if (values.numel() < at::internal::GRAIN_SIZE || !fbgemm::is_radix_sort_accelerated_with_openmp()) return false;
  // TODO(DamianSzwichtenberg): radix_sort is a stable sorting algorithm,
  // should we check here whether stable is set to true?

  return true;
}

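// 1D fast path: sorts integer keys (and their int64 indices) with
// fbgemm::radix_sort_parallel. The sorted output may land either in the
// original buffers or in the temporaries; in the latter case it is copied back
// in parallel, using vec::map as a vectorized copy.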
static void parallel_sort1d_kernel(
    const TensorBase& values,
    const TensorBase& indices) {
  AT_DISPATCH_INTEGRAL_TYPES(values.scalar_type(), "parallel_sort1d_kernel", [&] {
    const auto elements = values.numel();
    auto* const keys = values.data_ptr<scalar_t>();
    auto* const vals = indices.data_ptr<int64_t>();
    std::vector<scalar_t> tmp_keys(elements);
    std::vector<int64_t> tmp_vals(elements);
    const scalar_t* sorted_keys = nullptr;
    const int64_t* sorted_vals = nullptr;
    std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
        keys,
        vals,
        tmp_keys.data(),
        tmp_vals.data(),
        elements,
        std::numeric_limits<scalar_t>::max(),
        values.scalar_type() != ScalarType::Byte);

    const bool sorted_in_place = keys == sorted_keys;
    if (!sorted_in_place) {
      const auto num_threads = at::get_num_threads();
      at::parallel_for(0, elements, elements / num_threads, [&](int64_t begin, int64_t end) {
        const auto job_size = end - begin;
        vec::map([](vec::Vectorized<scalar_t> x) -> vec::Vectorized<scalar_t> { return x; }, keys + begin, sorted_keys + begin, job_size);
        vec::map([](vec::Vectorized<int64_t> x) -> vec::Vectorized<int64_t> { return x; }, vals + begin, sorted_vals + begin, job_size);
      });
    }
  });
}
#endif

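// Zips the value and index accessors into a single random-access range via
// CompositeRandomAccessorCPU and sorts it with std::sort / std::stable_sort,
// using the NaN-aware comparators defined above.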
template <typename scalar_t, typename value_accessor_t, typename indices_accessor_t>
static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
    const indices_accessor_t& indices_accessor,
    int64_t dim_size, bool descending, bool stable) {
  auto composite_accessor = CompositeRandomAccessorCPU<
    value_accessor_t, indices_accessor_t
  >(value_accessor, indices_accessor);
  if (descending) {
    if (stable) {
      std::stable_sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompDesc<scalar_t>());
    } else {
      std::sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompDesc<scalar_t>());
    }
  } else {
    if (stable) {
      std::stable_sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompAsc<scalar_t>());
    } else {
      std::sort(composite_accessor, composite_accessor + dim_size,
        KeyValueCompAsc<scalar_t>());
    }
  }
}

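// Dispatch target for sort_stub. Sorts `values` in place along `dim` and writes
// the corresponding permutation into `indices` (pre-filled with 0..size-1 by
// _fill_indices); `self` is only consulted for the zero-stride early-out. When
// FBGEMM is available and the input qualifies, the 1D radix-sort fast path is
// taken; otherwise each slice is sorted via _dim_apply, choosing plain pointers
// or StridedRandomAccessor depending on the strides along `dim`.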
static void sort_kernel(
    const TensorBase& self,
    const TensorBase& values,
    const TensorBase& indices,
    int64_t dim,
    bool descending,
    bool stable) {
  dim = maybe_wrap_dim(dim, values.dim());
  _fill_indices(indices, dim);
  if (self.stride(dim) == 0) {
    // stride 0 along `dim` means every element of a slice is the same value,
    // so there is nothing to sort
    // https://github.com/pytorch/pytorch/issues/91420
    return;
  }
#ifdef USE_FBGEMM
  if (can_use_radix_sort(values, descending)) {
    parallel_sort1d_kernel(values, indices);
    return;
  }
#endif
  _dim_apply(
    values, indices, dim,
    "sort_cpu", [&](
      auto* values, int64_t values_dim_stride,
      auto* indices, int64_t indices_dim_stride,
      int64_t dim_size
    ) {
      using scalar_t = std::remove_pointer_t<decltype(values)>;
      if (values_dim_stride == 1 && indices_dim_stride == 1) {
        sort_kernel_impl<
          scalar_t, decltype(values), decltype(indices)
        >(values, indices, dim_size, descending, stable);
      } else if (values_dim_stride == 1 && indices_dim_stride != 1) {
        auto indices_accessor = StridedRandomAccessor<int64_t>(
          indices, indices_dim_stride);
        sort_kernel_impl<
          scalar_t, decltype(values), decltype(indices_accessor)
        >(values, indices_accessor, dim_size, descending, stable);
      } else if (values_dim_stride != 1 && indices_dim_stride == 1) {
        auto values_accessor = StridedRandomAccessor<scalar_t>(
          values, values_dim_stride);
        sort_kernel_impl<
          scalar_t, decltype(values_accessor), decltype(indices)
        >(values_accessor, indices, dim_size, descending, stable);
      } else {
        auto values_accessor = StridedRandomAccessor<scalar_t>(
          values, values_dim_stride);
        auto indices_accessor = StridedRandomAccessor<int64_t>(
          indices, indices_dim_stride);
        sort_kernel_impl<
          scalar_t, decltype(values_accessor), decltype(indices_accessor)
        >(values_accessor, indices_accessor, dim_size, descending, stable);
      }
    }
  );
}

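// Dispatch target for topk_stub. Iterates over all slices of `self` along `dim`
// and delegates each slice to topk_impl_loop, which writes the top-k elements
// and their indices into the matching slices of `values` and `indices`.
// BFloat16 inputs are routed through a float accumulation type (the second
// template argument), presumably to avoid BFloat16 comparison/precision quirks.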
static void topk_kernel(
    const TensorBase &values,
    const TensorBase &indices,
    const TensorBase &self,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted) {
  auto sizes = self.sizes();
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(sizes, /*squash_dims=*/dim)
    .add_output(values)
    .add_output(indices)
    .add_const_input(self)
    .build();

  auto mode_values_stride = values.strides()[dim];
  auto mode_indices_stride = indices.strides()[dim];
  auto tmp_values_stride = self.strides()[dim];

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "topk_cpu", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      if (self.scalar_type() == ScalarType::BFloat16) {
        return topk_impl_loop<scalar_t, float>(
            mode_values_stride, mode_indices_stride, tmp_values_stride,
            k, sizes[dim], largest, sorted, data, strides, n);
      } else {
        return topk_impl_loop<scalar_t, scalar_t>(
            mode_values_stride, mode_indices_stride, tmp_values_stride,
            k, sizes[dim], largest, sorted, data, strides, n);
      }
    };

    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });
}

} // anonymous namespace

REGISTER_DISPATCH(sort_stub, &sort_kernel);
REGISTER_DISPATCH(topk_stub, &topk_kernel);

} // namespace at::native