// aten/src/ATen/native/cuda/UniqueCub.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/UniqueCub.cuh>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/cub.cuh>

#include <c10/core/DeviceArray.h>
#include <c10/util/Load.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#endif

namespace at::native::internal {

namespace {

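// Writes 1 at every position where the input differs from its left
// neighbor, and 0 elsewhere. A quick sketch on sorted input:
//   input:  [1, 1, 2, 3, 3]
//   output: [0, 0, 1, 1, 0]
// An inclusive sum over the output then gives each element's dense
// inverse index: [0, 0, 1, 2, 2].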
template <typename InputIteratorT>
__global__ void adjacent_difference_kernel(
    int64_t n,
    InputIteratorT input,
    int* output) {
  CUDA_KERNEL_LOOP(i, n) {
    output[i] = i > 0 ? input[i] != input[i - 1] : 0;
  }
}

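// Scatters input[i] to output[indices[i]]. Used below to permute inverse
// indices computed in sorted order back to the original element order via
// the sort permutation.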
__global__ void scatter_kernel(
    int64_t n,
    const int64_t* input,
    const int64_t* indices,
    int64_t* output) {
  CUDA_KERNEL_LOOP(i, n) {
    output[indices[i]] = input[i];
  }
}

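// For most dtypes CUB can consume the raw device pointer directly; bool is
// special-cased below with a transforming iterator.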
template <typename scalar_t>
const scalar_t * wrap_input_iterator(const scalar_t *data) {
  return data;
}

struct LoadBoolOp {
  __device__ bool operator()(uint8_t x) const {
    return static_cast<bool>(x);
  }
};

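// A bool tensor's storage may hold byte values other than 0/1; loading such
// a byte directly as bool would be undefined behavior, hence the detour
// through uint8_t (see the NOTE referenced below).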
auto wrap_input_iterator(const bool *data) {
  // See NOTE [Loading boolean values]
  LoadBoolOp op;
  return NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<bool, LoadBoolOp, const uint8_t*, int>(
      reinterpret_cast<const uint8_t*>(data), op);
}

// A variation of compute_unique (defined in Unique.cu) that doesn't allow
// customizing equal and not_equal (CUB doesn't allow them).
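// The pipeline, sketched:
//   1. adjacent_difference_kernel marks run boundaries in the sorted data.
//   2. An inclusive sum turns the boundary flags into inverse indices;
//      for a non-consecutive call, scatter_kernel permutes them back to
//      the original element order.
//   3. cub::unique (or run_length_encode when counts are requested)
//      extracts the unique values.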
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> compute_unique(
    const Tensor& sorted,
    const Tensor& sorted_indices,
    const bool return_inverse,
    const bool return_counts,
    const bool consecutive) {
  int64_t num_inp = sorted.numel();
  auto options = sorted.options().dtype(kLong);
  auto data = wrap_input_iterator(sorted.const_data_ptr<scalar_t>());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // inverse indices
  Tensor inverse_indices;
  if (!return_inverse) {
    inverse_indices = at::empty({0}, options);
  } else {
    inverse_indices = at::empty(sorted.sizes(), options);
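    // When not consecutive, inverse_indices doubles as int32 scratch for
    // the boundary flags below; scatter_kernel overwrites it with the
    // final int64 inverse indices afterwards.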
    Tensor inv_loc = consecutive ? at::empty({num_inp}, options.dtype(kInt))
                                 : inverse_indices;
    int* inv_loc_ptr = static_cast<int*>(inv_loc.mutable_data_ptr());
    const dim3 block =
        dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), num_inp));
    dim3 grid;
    c10::DeviceIndex curDevice = -1;
    c10::cuda::GetDevice(&curDevice);
    cuda::getApplyGrid(num_inp, grid, curDevice);
    adjacent_difference_kernel<<<grid, block, 0, stream>>>(
        num_inp, data, inv_loc_ptr);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    Tensor inv_loc_out =
        consecutive ? inverse_indices : at::empty({num_inp}, options);
    at::cuda::cub::inclusive_sum_truncating(
        inv_loc_ptr,
        inv_loc_out.mutable_data_ptr<int64_t>(),
        num_inp);

    if (!consecutive) {
      TORCH_INTERNAL_ASSERT(
          sorted_indices.defined(),
          "return_inverse is set to true, but sorted_indices is undefined. Send a bug report!");
      scatter_kernel<<<grid, block, 0, stream>>>(
          num_inp,
          inv_loc_out.const_data_ptr<int64_t>(),
          sorted_indices.const_data_ptr<int64_t>(),
          inverse_indices.mutable_data_ptr<int64_t>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }

  // unique and count
  Tensor data_out = at::empty({num_inp}, sorted.options());
  Tensor counts = at::empty({0}, options);
  Tensor length = at::empty({1}, options);
  int64_t num_out;
  if (!return_counts) {
    cuda::cub::unique(
        data,
        data_out.mutable_data_ptr<scalar_t>(),
        length.mutable_data_ptr<int64_t>(),
        num_inp);
    num_out = length.item<int64_t>();
  } else {
    counts.resize_(num_inp);
    at::cuda::cub::run_length_encode(
        data,
        data_out.mutable_data_ptr<scalar_t>(),
        counts.mutable_data_ptr<int64_t>(),
        length.mutable_data_ptr<int64_t>(),
        num_inp);
    num_out = length.item<int64_t>();
    counts.resize_(num_out);
  }

  data_out.resize_(num_out);
  return std::tuple<Tensor, Tensor, Tensor>(
      data_out, inverse_indices, counts);
}

} // namespace

// This function (and compute_unique above) is defined in a separate file
// from Unique.cu because, for now, ATen/cuda/cub.cuh can't be used
// together with thrust in the same compilation unit.

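// Generic path: radix-sort the input (keys only, or key/index pairs when
// the inverse mapping is requested), then deduplicate with compute_unique.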
template <typename scalar_t>
struct UniqueCub {
  std::tuple<Tensor, Tensor, Tensor> operator() (
      const Tensor& self,
      const bool consecutive,
      const bool return_inverse,
      const bool return_counts) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    int64_t num_inp = self.numel();
    Tensor sorted;
    if (consecutive) {
      sorted = self;
    } else {
      sorted = at::empty(self.sizes(), self.options());
    }

    Tensor sorted_indices;
    if (!return_inverse) {
      if (!consecutive) {
        cuda::cub::radix_sort_keys(
          self.const_data_ptr<scalar_t>(),
          sorted.mutable_data_ptr<scalar_t>(),
          num_inp);
      }
    } else {
      if (!consecutive) {
        auto options = self.options().dtype(kLong);
        Tensor range = at::arange(0, num_inp, options);
        sorted_indices = at::empty({num_inp}, options);
        cuda::cub::radix_sort_pairs(
            self.const_data_ptr<scalar_t>(),
            sorted.mutable_data_ptr<scalar_t>(),
            range.const_data_ptr<int64_t>(),
            sorted_indices.mutable_data_ptr<int64_t>(),
            num_inp);
      }
    }

    return compute_unique<scalar_t>(
        sorted, sorted_indices, return_inverse, return_counts, consecutive);
  }
};

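// Maps each raw byte of a bool tensor to 0 or 1 so that a cub sum
// reduction can count the true values.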
struct MapNumberOfTrueValues {
  __device__ int operator()(uint8_t x) const {
    return static_cast<bool>(x);
  }
};

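// Bool tensors have at most two unique values. By convention here, false
// (when present) takes output index 0, and true takes index 1 if any
// false exists, otherwise index 0.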
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void unique_bool_write_inverse_indices(
    const int numel,
    const int *num_true_p,
    const bool *self,
    int64_t *inverse_indices_out) {
  constexpr int false_idx = 0;
  const int num_true = *num_true_p;
  const int num_false = numel - num_true;
  const int true_idx = num_false > 0;

  CUDA_KERNEL_LOOP(i, numel) {
    const auto value = c10::load(&self[i]);
    inverse_indices_out[i] = value ? true_idx : false_idx;
  }
}

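// Launched with a single thread: once the device-side count of true
// values is known, it writes the (at most two) unique values and their
// counts; the extra slots are trimmed after a sync below.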
C10_LAUNCH_BOUNDS_1(1)
__global__ void unique_bool_write_output(
    const int numel,
    const int *num_true_p,
    bool *values_out,
    int64_t *counts_out) {
  constexpr int false_idx = 0;
  const int num_true = *num_true_p;
  const int num_false = numel - num_true;
  const int true_idx = num_false > 0;

  if (blockIdx.x == 0 && threadIdx.x == 0) {
    if (num_false > 0) {
      values_out[false_idx] = false;
      counts_out[false_idx] = num_false;
    }
    if (num_true > 0) {
      values_out[true_idx] = true;
      counts_out[true_idx] = num_true;
    }
  }
}

template <>
struct UniqueCub<bool> {
  std::tuple<Tensor, Tensor, Tensor> operator() (
      const Tensor& self,
      const bool consecutive,
      const bool return_inverse,
      const bool return_counts) {
    auto stream = at::cuda::getCurrentCUDAStream();

    int64_t num_inp = self.numel();

    Tensor output, inverse_indices, counts;
    if (consecutive) {
      Tensor sorted_indices;
      return compute_unique<bool>(
          self, sorted_indices, return_inverse, return_counts, consecutive);
    }

    // Instead of sorting, we use a reduction to find the number of
    // true values and from that we can infer the number of false.
    // If either has a count of zero, we omit it from the output.
    auto allocator = at::cuda::getCUDADeviceAllocator();
    c10::DeviceArray<int> tmp_num_true(*allocator, 1);

    const bool* self_data = self.const_data_ptr<bool>();
    MapNumberOfTrueValues op;
    NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<int, MapNumberOfTrueValues, const uint8_t*, int>
        data_iter(reinterpret_cast<const uint8_t*>(self_data), op);
    at::cuda::cub::reduce(data_iter, tmp_num_true.get(), num_inp,
                          NO_ROCM(at_cuda_detail)::cub::Sum{}, 0);

    auto options = self.options();
    output = at::empty({2}, options);
    counts = at::empty({2}, options.dtype(kLong));

    unique_bool_write_output<<<1, 1, 0, stream>>>(
        num_inp,
        tmp_num_true.get(),
        output.mutable_data_ptr<bool>(),
        counts.mutable_data_ptr<int64_t>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    if (return_inverse) {
      using namespace at::cuda::detail;
      inverse_indices = at::empty(self.sizes(), options.dtype(kLong));
      dim3 block = CUDA_NUM_THREADS;
      dim3 grid = GET_BLOCKS(num_inp);
      unique_bool_write_inverse_indices<<<grid, block, 0, stream>>>(
          num_inp,
          tmp_num_true.get(),
          self_data,
          inverse_indices.mutable_data_ptr<int64_t>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Final sync to fix the output tensors' shapes
    int num_true = 0;
    at::cuda::memcpy_and_sync(&num_true, tmp_num_true.get(), sizeof(int),
                              cudaMemcpyDeviceToHost, stream);
    const int num_false = num_inp - num_true;
    const int num_out = ((num_true > 0) + (num_false > 0));
    output.resize_({num_out});
    counts.resize_({num_out});

    return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
  }
};

template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
    const Tensor& self,
    const bool consecutive,
    const bool return_inverse,
    const bool return_counts) {
  auto num_inp = self.numel();
  TORCH_CHECK(
      num_inp <= INT_MAX, "num_inp ", num_inp, " is too big for CUB");
  if (num_inp == 0) {
    Tensor output = at::empty({0}, self.options());
    Tensor inverse_indices = at::empty(self.sizes(), self.options().dtype(kLong));
    Tensor counts = at::empty({0}, self.options().dtype(kLong));
    return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
  }

  auto self_c = self.expect_contiguous();
  return UniqueCub<scalar_t>{}(*self_c, consecutive, return_inverse, return_counts);
}
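// A minimal usage sketch (hypothetical call site; `t` is an assumed
// contiguous 1-D float CUDA tensor):
//   auto [values, inverse, counts] = unique_cuda_template<float>(
//       t, /*consecutive=*/false, /*return_inverse=*/true,
//       /*return_counts=*/true);
// `values` holds the sorted unique elements, `inverse` maps each input
// element to its position in `values`, and `counts` holds the number of
// occurrences of each unique value.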

#define INSTANTIATE_UNIQUE_CUDA_TEMPLATE(TYPE)                            \
  template std::tuple<Tensor, Tensor, Tensor> unique_cuda_template<TYPE>( \
      const Tensor& self,                                                 \
      const bool consecutive,                                             \
      const bool return_inverse,                                          \
      const bool return_counts)

INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint8_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int8_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(double);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(float);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int32_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int64_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int16_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint32_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint64_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint16_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(bool);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(at::Half);

#undef INSTANTIATE_UNIQUE_CUDA_TEMPLATE

} // namespace at::native::internal