#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/ThrustAllocator.h>

#include <c10/util/Load.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_unique2_native.h>
#include <ATen/ops/_unique_native.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/unique_consecutive_native.h>
#include <ATen/ops/unique_dim_consecutive_native.h>
#include <ATen/ops/unique_dim_native.h>
#endif

#include <tuple>
#include <iterator>
#include <thrust/adjacent_difference.h>
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <thrust/sort.h>
#include <thrust/scan.h>
#include <thrust/scatter.h>

#include <ATen/native/cuda/UniqueCub.cuh>

namespace at::native {

namespace {

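// Given `data` whose equal elements are already grouped consecutively (sorted
// for the non-consecutive variants), collapse it in place to its unique
// elements and return (inverse_indices, counts, num_out). `sorted_indices`
// maps each position in the grouped order back to its position in the
// original input and is only consulted when return_inverse is true.
// `equal`/`not_equal` let callers customize the element comparison (e.g.
// comparing whole rows for unique_dim).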
template <
  typename policy_t, typename scalar_t,
  typename equal_t, typename not_equal_t
>
std::tuple<Tensor, Tensor, int64_t> compute_unique(
  const policy_t &policy,
  scalar_t *data,
  int64_t num_inp,
  const Tensor &sorted_indices,
  const bool return_inverse,
  const bool return_counts,
  TensorOptions options,
  equal_t equal,
  not_equal_t not_equal
) {
  // inverse indices
  Tensor inverse_indices;
  if (!return_inverse || num_inp == 0) {
    inverse_indices = at::empty({0}, options);
  } else {
    TORCH_CHECK(sorted_indices.defined(),
      "return_inverse is set to true, but sorted_indices is undefined. Send a bug report!");
    const int64_t *sorted_indices_ptr = sorted_indices.const_data_ptr<int64_t>();
    Tensor inv_loc = at::empty({num_inp}, options);
    inverse_indices = at::empty({num_inp}, options);
    int64_t* inv_loc_ptr = inv_loc.mutable_data_ptr<int64_t>();
    int64_t* inverse_indices_ptr = inverse_indices.mutable_data_ptr<int64_t>();
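    // adjacent_difference with not_equal writes 1 at every position whose
    // element differs from its predecessor (the first slot is then reset to 0);
    // the inclusive scan turns those flags into a running group id per grouped
    // position, and scatter routes each group id back to the element's original
    // position via sorted_indices, producing the inverse mapping.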
    thrust::adjacent_difference(policy, data, data + num_inp, inv_loc_ptr, not_equal);
    inv_loc[0] = 0;
    thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
    thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
  }

  // unique and count
  Tensor counts = at::empty({0}, options);
  int64_t num_out;
  if (!return_counts) {
    num_out = thrust::unique(policy, data, data + num_inp, equal) - data;
  } else {
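    // unique_by_key keeps, for each run of equal elements, the index of its
    // first occurrence in `range`; appending num_inp as a sentinel makes the
    // difference between consecutive kept indices the run length, which
    // adjacent_difference writes into `counts` (the first start index is 0,
    // so the copied first element is already a valid count).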
    Tensor range = at::arange(0, num_inp + 1, options);
    int64_t *range_ptr = range.mutable_data_ptr<int64_t>();
    num_out = thrust::unique_by_key(policy, data, data + num_inp, range_ptr, equal).first - data;
    range[num_out] = num_inp;
    counts.resize_(num_out);
    int64_t* counts_ptr = counts.mutable_data_ptr<int64_t>();
    thrust::adjacent_difference(policy, range_ptr + 1, range_ptr + num_out + 1, counts_ptr);
  }

  AT_CUDA_CHECK(cudaGetLastError());
  return std::tuple<Tensor, Tensor, int64_t>(inverse_indices, counts, num_out);
}

template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(
  const Tensor& self,
  const int64_t dim,
  const bool consecutive,
  const bool return_inverse,
  const bool return_counts
) {

  /**
    * The idea for implementing this is basically the same as unique.
    * For unique_dim, we compute the unique over an index tensor instead,
    * but override the comparison and equality operators so that they
    * inspect the underlying data for each index. Afterwards, index_select
    * maps the resulting indices back onto the actual data.
    */
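  //
  // Illustrative example (dim = 0): for input rows {{1, 2}, {3, 4}, {1, 2}}
  // the unique rows are {{1, 2}, {3, 4}}, inverse_indices is {0, 1, 0} and
  // counts is {2, 1}.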

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto sizes = self.sizes().vec();
  // check how many zero-sized dimensions exist
  auto num_zero_dims = std::count(sizes.begin(), sizes.end(), 0);

  // the tensor is not well formed, as it has zero-sized dimensions
  if (self.size(dim) == 0){
    TORCH_CHECK(
        num_zero_dims == 1,
        "Number of zero-sized dimensions is more than one, so unique cannot be applied")
    Tensor output = at::empty(sizes, self.options());
    Tensor inverse_indices =
        at::empty({0}, self.options().dtype(kLong));
    Tensor counts = at::empty({0}, self.options().dtype(kLong));

    return std::make_tuple(output, inverse_indices, counts);
  }

  TORCH_CHECK(num_zero_dims == 0,
    "There are zero-sized dimensions, and they aren't selected, so unique cannot be applied");

  int64_t num_inp = self.size(dim);
  auto options = self.options().dtype(kLong);
  Tensor input_flat = self.moveaxis(dim, 0).contiguous().view({num_inp, -1});
  int64_t n = input_flat.size(1);
  const scalar_t *input_flat_ptr = input_flat.const_data_ptr<scalar_t>();

  Tensor indices = at::arange(0, num_inp, options);
  int64_t *indices_data = indices.mutable_data_ptr<int64_t>();
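  // Unless the input is already grouped (consecutive == true), sort the row
  // indices lexicographically by the row contents they refer to, so that
  // equal rows become adjacent.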
  if (!consecutive) {
    thrust::sort(policy, indices_data, indices_data + num_inp,
      [=] __device__ (int64_t a, int64_t b) -> bool {
        for (int64_t i = 0; i < n; ++i) {
          scalar_t lhs = c10::load(&input_flat_ptr[i + a * n]);
          scalar_t rhs = c10::load(&input_flat_ptr[i + b * n]);
          if (lhs < rhs) {
            return true;
          } else if (lhs > rhs) {
            return false;
          }
        }
        return false;
      }
    );
  }

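  // Deduplicate the (now grouped) row indices. The equality and inequality
  // functors compare the full rows element by element, so compute_unique
  // effectively operates on rows rather than on scalar values.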
  auto [inverse_indices, counts, num_out] = compute_unique(
    policy, indices_data, num_inp, indices,
    return_inverse, return_counts, options,
    [=] __device__ (int64_t a, int64_t b) -> bool {
      for (int64_t i = 0; i < n; ++i) {
        scalar_t lhs = c10::load(&input_flat_ptr[i + a * n]);
        scalar_t rhs = c10::load(&input_flat_ptr[i + b * n]);
        if (lhs != rhs) {
          return false;
        }
      }
      return true;
    },
    [=] __device__ (int64_t a, int64_t b) -> int64_t {
      for (int64_t i = 0; i < n; ++i) {
        scalar_t lhs = c10::load(&input_flat_ptr[i + a * n]);
        scalar_t rhs = c10::load(&input_flat_ptr[i + b * n]);
        if (lhs != rhs) {
          return 1;
        }
      }
      return 0;
    }
  );
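  // `indices` now starts with one representative (original row position) per
  // unique group; trim it to num_out and gather those rows from `self`.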
  indices.resize_(num_out);

  return std::tuple<Tensor, Tensor, Tensor>(self.index_select(dim, indices), inverse_indices, counts);
}

} // namespace


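// Public entry points: these are the CUDA backends for the corresponding ATen
// unique ops (see the <ATen/ops/*_native.h> includes above). They dispatch over
// the supported dtypes and delegate either to internal::unique_cuda_template
// (flat unique, CUB-based, see UniqueCub.cuh) or to unique_dim_cuda_template.
//
// Hedged usage sketch, assuming the generated wrappers at::_unique2 and
// at::unique_dim route here for a CUDA tensor `t` of a supported dtype:
//
//   auto [values, inverse, counts] =
//       at::_unique2(t, /*sorted=*/true, /*return_inverse=*/true, /*return_counts=*/true);
//   auto [row_values, row_inverse, row_counts] =
//       at::unique_dim(t, /*dim=*/0, /*sorted=*/true, /*return_inverse=*/true, /*return_counts=*/true);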
std::tuple<Tensor, Tensor>
_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
  return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
    // The current CUDA implementation of unique always sorts due to the
    // lack of a hashtable implementation in thrust
    auto [output, inverse, _] = internal::unique_cuda_template<scalar_t>(self, false, return_inverse, false);
    return std::make_tuple(output, inverse);
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

std::tuple<Tensor, Tensor, Tensor>
_unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
  return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
    // The current CUDA implementation of unique always sorts due to the
    // lack of a hashtable implementation in thrust
    return internal::unique_cuda_template<scalar_t>(self, false, return_inverse, return_counts);
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

std::tuple<Tensor, Tensor, Tensor>
unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
  return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] {
    return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

std::tuple<Tensor, Tensor, Tensor>
unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
  return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] {
    return unique_dim_cuda_template<scalar_t>(self, dim, true, return_inverse, return_counts);
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

std::tuple<Tensor, Tensor, Tensor>
unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const bool return_counts, std::optional<int64_t> dim) {
  if (!dim.has_value()) {
    return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
      // The current CUDA implementation of unique always sorts due to the
      // lack of a hashtable implementation in thrust
      return internal::unique_cuda_template<scalar_t>(self, true, return_inverse, return_counts);
    }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
  }
  return unique_dim_consecutive_cuda(self, dim.value(), return_inverse, return_counts);
}

}  // namespace at::native