#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/Sort.h>
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/sort_native.h>
#include <ATen/ops/zeros.h>
#endif

#include <limits>

namespace at::native {

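// Forward declaration (defined elsewhere in the CUDA sort sources): returns
// dense, non-overlapping strides for `self` in which `dim` becomes the
// innermost, stride-1 dimension.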
std::vector<int64_t> infer_dense_strides_dim_last(const Tensor & self, int64_t dim);

void fillSliceWithIndex(const Tensor& t, int dim) {
  if (t.numel()) {
    auto sizes = DimVector(t.dim(), 1);
    sizes[dim] = t.sizes()[dim];
    auto range = at::arange(t.sizes()[dim], t.options());
    auto rangeview = range.view(sizes);
    t.copy_(rangeview);
  }
}
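// Illustration: fillSliceWithIndex on a 2x3 tensor with dim == 1 broadcasts
// the index slice [0, 1, 2] into every row, so the tensor becomes
// [[0, 1, 2], [0, 1, 2]].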

// We perform a segmented sort in cub with inputs that have
// more than 1024/2048 elements along the selected dimension.
// Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
void sort_cuda_kernel(
    const TensorBase& self_base,
    const TensorBase& values_base,
    const TensorBase& indices_base,
    int64_t dim,
    bool descending,
    bool stable) {
  // this algorithm is always stable

  // Macro for converting `TensorBase` -> `Tensor` without
  // reference count bumps.
#define TOTENSOR(BASE, VAR)           \
  OptionalTensorRef opt_##BASE(BASE); \
  const Tensor& VAR = *opt_##BASE;
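  // For example, TOTENSOR(self_base, self) expands roughly to:
  //   OptionalTensorRef opt_self_base(self_base);
  //   const Tensor& self = *opt_self_base;
  // i.e. `self` views the same TensorImpl without touching its refcount.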

  // Converting TensorBase into Tensor.
  // We will need Tensor's methods from this point onwards.
  TOTENSOR(self_base, self);
  TOTENSOR(values_base, values);
  TOTENSOR(indices_base, indices);

  TORCH_CHECK(self.sizes()[dim] <= std::numeric_limits<int>::max(),
    "The dimension being sorted can not have more than INT_MAX elements.");

  const auto self_dtype = self.dtype();
  // FIXME: remove this check once cub sort supports bool
  TORCH_CHECK(self_dtype != ScalarType::Bool,
    "Sort currently does not support bool dtype on CUDA.");
  TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
    "Sort currently does not support complex dtypes on CUDA.");

  // use the in-place bitonic sort (sortKeyValueInplace) for smaller input sizes
  if (should_use_small_sort(self, dim)) {
    // from thc: sorted->values, indices->indices, input->self
    fillSliceWithIndex(indices, dim);

    // We sort k/v pairs in-place; copy unsorted input to output
    values.copy_(self);

    // Sort using our in-place k/v kernel that supports arbitrary
    // layout
    sortKeyValueInplace(values, indices, dim, descending, stable);
    return;
  }

  Tensor self_;
  bool newself = false;
  if (self.is_non_overlapping_and_dense() && self.stride(dim) == 1) {
    self_ = self;
  } else {
    auto new_strides_unsort = infer_dense_strides_dim_last(self, dim);
    self_ = at::empty_strided(self.sizes(), new_strides_unsort, self.options());
    self_.copy_(self);
    newself = true;
  }

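  // Use the caller's `values`/`indices` as the sort outputs only when their
  // strides already match `self_` (and, for `values`, when it cannot overlap
  // the input); otherwise sort into scratch tensors and copy back afterwards.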
  c10::MaybeOwned<Tensor> values_tmp, indices_tmp;
  if (values.strides() == self_.strides() && (newself || get_overlap_status(self, values) == MemOverlapStatus::No)) {
    values_tmp = c10::MaybeOwned<Tensor>::borrowed(values);
  } else {
    values_tmp = c10::MaybeOwned<Tensor>::owned(
        at::empty_strided(self_.sizes(), self_.strides(), self_.options()));
  }

  if (indices.strides() != self_.strides()) {
    indices_tmp = c10::MaybeOwned<Tensor>::owned(
        at::empty_strided(self_.sizes(), self_.strides(), self_.options().dtype(kLong)));
  } else {
    indices_tmp = c10::MaybeOwned<Tensor>::borrowed(indices);
  }

  launch_stable_sort_kernel(self_, dim, descending, *values_tmp, *indices_tmp);

  if (!values_tmp->is_same(values)) {
    values.copy_(*values_tmp);
  }
  if (!indices_tmp->is_same(indices)) {
    indices.copy_(*indices_tmp);
  }
}

// TODO: we should handle this accordingly when we start using REGISTER_HIP_DISPATCH,
// since REGISTER_DISPATCH won't work in this cpp file.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CUDA_DISPATCH(sort_stub, &sort_cuda_kernel);
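
// Illustrative only (not part of this file): a user-level call that reaches
// sort_cuda_kernel through the sort_stub dispatch registered above, assuming
// a CUDA build:
//
//   at::Tensor x = at::rand({8, 4096}, at::kCUDA);
//   auto [values, indices] = x.sort(/*dim=*/1, /*descending=*/false);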

}  // namespace at::native