#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/Sort.h>
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/sort_native.h>
#include <ATen/ops/zeros.h>
#endif

#include <limits>

namespace at::native {

std::vector<int64_t> infer_dense_strides_dim_last(const Tensor & self, int64_t dim);

void fillSliceWithIndex(const Tensor& t, int dim) {
  if (t.numel()) {
    auto sizes = DimVector(t.dim(), 1);
    sizes[dim] = t.sizes()[dim];
    auto range = at::arange(t.sizes()[dim], t.options());
    auto rangeview = range.view(sizes);
    t.copy_(rangeview);
  }
}
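// e.g. for a tensor t of shape {2, 3} and dim == 1, the broadcasted copy of
// the arange above turns every row into {0, 1, 2}: each slice along `dim`
// ends up holding its own index sequence.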

// We perform a segmented sort in cub with inputs that have
// more than 1024/2048 elements along the selected dimension.
// Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
void sort_cuda_kernel(
    const TensorBase& self_base,
    const TensorBase& values_base,
    const TensorBase& indices_base,
    int64_t dim,
    bool descending,
    bool stable) {
  // this algorithm is always stable

  // Macro for converting `TensorBase` -> `Tensor` without
  // reference count bumps.
#define TOTENSOR(BASE, VAR)           \
  OptionalTensorRef opt_##BASE(BASE); \
  const Tensor& VAR = *opt_##BASE;
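  // e.g. TOTENSOR(self_base, self) expands to:
  //   OptionalTensorRef opt_self_base(self_base);
  //   const Tensor& self = *opt_self_base;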

  // Converting TensorBase into Tensor.
  // We will need Tensor's methods from this point onwards.
  TOTENSOR(self_base, self);
  TOTENSOR(values_base, values);
  TOTENSOR(indices_base, indices);

  TORCH_CHECK(self.sizes()[dim] <= std::numeric_limits<int>::max(),
    "The dimension being sorted can not have more than INT_MAX elements.");

  const auto self_dtype = self.dtype();
  // FIXME: remove this check once cub sort supports bool
  TORCH_CHECK(self_dtype != ScalarType::Bool,
    "Sort currently does not support bool dtype on CUDA.");
  TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
    "Sort currently does not support complex dtypes on CUDA.");

  // Use the in-place sort (sortKeyValueInplace) for smaller input sizes along `dim`
  if (should_use_small_sort(self, dim)) {
    // from thc: sorted->values, indices->indices, input->self
    fillSliceWithIndex(indices, dim);

    // We sort k/v pairs in-place; copy unsorted input to output
    values.copy_(self);

    // Sort using our in-place k/v kernel that supports arbitrary
    // layout
    sortKeyValueInplace(values, indices, dim, descending, stable);
    return;
  }

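  // The segmented-sort path below expects a dense layout with the sort
  // dimension innermost (stride 1); if `self` is not already laid out that
  // way, sort a dense, dim-last copy of it instead.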
  Tensor self_;
  bool newself = false;
  if (self.is_non_overlapping_and_dense() && self.stride(dim) == 1) {
    self_ = self;
  } else {
    auto new_strides_unsort = infer_dense_strides_dim_last(self, dim);
    self_ = at::empty_strided(self.sizes(), new_strides_unsort, self.options());
    self_.copy_(self);
    newself = true;
  }

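  // Sort directly into `values`/`indices` when their strides match `self_`
  // (and, for `values`, when the output cannot alias the input); otherwise
  // sort into temporaries with matching strides and copy the results back below.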
  c10::MaybeOwned<Tensor> values_tmp, indices_tmp;
  if (values.strides() == self_.strides() && (newself || get_overlap_status(self, values) == MemOverlapStatus::No)) {
    values_tmp = c10::MaybeOwned<Tensor>::borrowed(values);
  } else {
    values_tmp = c10::MaybeOwned<Tensor>::owned(
        at::empty_strided(self_.sizes(), self_.strides(), self_.options()));
  }

  if (indices.strides() != self_.strides()) {
    indices_tmp = c10::MaybeOwned<Tensor>::owned(
        at::empty_strided(self_.sizes(), self_.strides(), self_.options().dtype(kLong)));
  } else {
    indices_tmp = c10::MaybeOwned<Tensor>::borrowed(indices);
  }

  launch_stable_sort_kernel(self_, dim, descending, *values_tmp, *indices_tmp);

  if (!values_tmp->is_same(values)) {
    values.copy_(*values_tmp);
  }
  if (!indices_tmp->is_same(indices)) {
    indices.copy_(*indices_tmp);
  }
}
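// A minimal usage sketch (illustrative): at::sort on a CUDA tensor reaches
// this kernel through sort_stub, registered just below:
//   auto t = at::rand({8, 4096}, at::kCUDA);
//   auto [sorted, idx] = t.sort(/*dim=*/-1, /*descending=*/false);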

// TODO: we should handle this accordingly when we start using REGISTER_HIP_DISPATCH,
// since REGISTER_DISPATCH won't work in this cpp file.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CUDA_DISPATCH(sort_stub, &sort_cuda_kernel);

} // namespace at::native