#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/UniqueCub.cuh>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/cub.cuh>

#include <c10/core/DeviceArray.h>
#include <c10/util/Load.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#endif

namespace at::native::internal {

namespace {

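// Marks, for each position of an already-sorted (or "consecutive") run,
// whether the value differs from its predecessor: output[i] is 1 when
// input[i] != input[i - 1] and 0 otherwise (output[0] is always 0). An
// inclusive prefix sum over these flags yields each element's index into
// the array of unique values.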
template <typename InputIteratorT>
__global__ void adjacent_difference_kernel(
    int64_t n,
    InputIteratorT input,
    int* output) {
  CUDA_KERNEL_LOOP(i, n) {
    output[i] = i > 0 ? input[i] != input[i - 1] : 0;
  }
}

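// Scatters input[i] into output[indices[i]]; used below to move inverse
// indices computed in sorted order back to the original element order.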
__global__ void scatter_kernel(
    int64_t n,
    const int64_t* input,
    const int64_t* indices,
    int64_t* output) {
  CUDA_KERNEL_LOOP(i, n) {
    output[indices[i]] = input[i];
  }
}

template <typename scalar_t>
const scalar_t* wrap_input_iterator(const scalar_t* data) {
  return data;
}

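// See NOTE [Loading boolean values] (referenced below): a bool tensor
// stores one byte per element and that byte may hold values other than
// 0/1, so the data is read as uint8_t and normalized to a proper bool
// before it is handed to CUB.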
struct LoadBoolOp {
  __device__ bool operator()(uint8_t x) const {
    return static_cast<bool>(x);
  }
};

auto wrap_input_iterator(const bool* data) {
  // See NOTE [Loading boolean values]
  LoadBoolOp op;
  return NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<bool, LoadBoolOp, const uint8_t*, int>(
      reinterpret_cast<const uint8_t*>(data), op);
}

// A variation of compute_unique (defined in Unique.cu) that doesn't allow
// customizing equal and not_equal (CUB doesn't allow them).
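// For example, given sorted = [1, 1, 2, 3, 3, 3] the unique values are
// [1, 2, 3] with counts [2, 1, 3]; the inverse indices record, for every
// element of the original input, its slot in the unique output
// ([0, 0, 1, 2, 2, 2] in sorted order, scattered back through
// sorted_indices when the input had to be sorted first).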
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> compute_unique(
    const Tensor& sorted,
    const Tensor& sorted_indices,
    const bool return_inverse,
    const bool return_counts,
    const bool consecutive) {
  int64_t num_inp = sorted.numel();
  auto options = sorted.options().dtype(kLong);
  auto data = wrap_input_iterator(sorted.const_data_ptr<scalar_t>());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // inverse indices
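  // Three steps: (1) flag each position where the sorted value changes,
  // (2) an inclusive prefix sum turns the flags into dense ranks, and
  // (3) in the non-consecutive case, scatter the ranks back to the
  // original order via sorted_indices.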
  Tensor inverse_indices;
  if (!return_inverse) {
    inverse_indices = at::empty({0}, options);
  } else {
    inverse_indices = at::empty(sorted.sizes(), options);
    Tensor inv_loc = consecutive ? at::empty({num_inp}, options.dtype(kInt))
                                 : inverse_indices;
    int* inv_loc_ptr = static_cast<int*>(inv_loc.mutable_data_ptr());
    const dim3 block =
        dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), num_inp));
    dim3 grid;
    c10::DeviceIndex curDevice = -1;
    c10::cuda::GetDevice(&curDevice);
    cuda::getApplyGrid(num_inp, grid, curDevice);
    adjacent_difference_kernel<<<grid, block, 0, stream>>>(
        num_inp, data, inv_loc_ptr);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    Tensor inv_loc_out =
        consecutive ? inverse_indices : at::empty({num_inp}, options);
    at::cuda::cub::inclusive_sum_truncating(
        inv_loc_ptr,
        inv_loc_out.mutable_data_ptr<int64_t>(),
        num_inp);

    if (!consecutive) {
      TORCH_INTERNAL_ASSERT(
          sorted_indices.defined(),
          "return_inverse is set to true, but sorted_indices is undefined. Send a bug report!");
      scatter_kernel<<<grid, block, 0, stream>>>(
          num_inp,
          inv_loc_out.const_data_ptr<int64_t>(),
          sorted_indices.const_data_ptr<int64_t>(),
          inverse_indices.mutable_data_ptr<int64_t>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }

  // unique and count
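  // Without counts a plain CUB unique suffices; with counts a run-length
  // encode produces the unique values and their counts in one pass. The
  // number of unique values is written to a device tensor and read back
  // with .item(), which synchronizes with the stream.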
  Tensor data_out = at::empty({num_inp}, sorted.options());
  Tensor counts = at::empty({0}, options);
  Tensor length = at::empty({1}, options);
  int64_t num_out;
  if (!return_counts) {
    cuda::cub::unique(data, data_out.mutable_data_ptr<scalar_t>(), length.mutable_data_ptr<int64_t>(), num_inp);
    num_out = length.item<int64_t>();
  } else {
    counts.resize_(num_inp);
    at::cuda::cub::run_length_encode(
        data,
        data_out.mutable_data_ptr<scalar_t>(),
        counts.mutable_data_ptr<int64_t>(),
        length.mutable_data_ptr<int64_t>(),
        num_inp);
    num_out = length.item<int64_t>();
    counts.resize_(num_out);
  }

  data_out.resize_(num_out);
  return std::tuple<Tensor, Tensor, Tensor>(
      data_out, inverse_indices, counts);
}

} // namespace

// This function and compute_unique above are defined in a separate file
// from Unique.cu because, for now, ATen/cuda/cub.cuh can't be used together
// with thrust in the same compilation unit.

template <typename scalar_t>
struct UniqueCub {
  std::tuple<Tensor, Tensor, Tensor> operator()(
      const Tensor& self,
      const bool consecutive,
      const bool return_inverse,
      const bool return_counts) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    int64_t num_inp = self.numel();
    Tensor sorted;
    if (consecutive) {
      sorted = self;
    } else {
      sorted = at::empty(self.sizes(), self.options());
    }

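    // When inverse indices are requested, the keys are sorted together
    // with an arange of original positions so that the inverse mapping can
    // later be scattered back to the input order; otherwise sorting the
    // keys alone is enough.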
    Tensor sorted_indices;
    if (!return_inverse) {
      if (!consecutive) {
        cuda::cub::radix_sort_keys(
            self.const_data_ptr<scalar_t>(),
            sorted.mutable_data_ptr<scalar_t>(),
            num_inp);
      }
    } else {
      if (!consecutive) {
        auto options = self.options().dtype(kLong);
        Tensor range = at::arange(0, num_inp, options);
        sorted_indices = at::empty({num_inp}, options);
        cuda::cub::radix_sort_pairs(
            self.const_data_ptr<scalar_t>(),
            sorted.mutable_data_ptr<scalar_t>(),
            range.const_data_ptr<int64_t>(),
            sorted_indices.mutable_data_ptr<int64_t>(),
            num_inp);
      }
    }

    return compute_unique<scalar_t>(
        sorted, sorted_indices, return_inverse, return_counts, consecutive);
  }
};

struct MapNumberOfTrueValues {
  __device__ int operator()(uint8_t x) const {
    return static_cast<bool>(x);
  }
};

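// The bool specialization produces at most two unique values. false
// occupies slot 0 and true slot 1 when both are present; if only one of
// them occurs it takes slot 0 (hence true_idx = (num_false > 0)).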
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void unique_bool_write_inverse_indices(
    const int numel,
    const int* num_true_p,
    const bool* self,
    int64_t* inverse_indices_out) {
  constexpr int false_idx = 0;
  const int num_true = *num_true_p;
  const int num_false = numel - num_true;
  const int true_idx = num_false > 0;

  CUDA_KERNEL_LOOP(i, numel) {
    const auto value = c10::load(&self[i]);
    inverse_indices_out[i] = value ? true_idx : false_idx;
  }
}

C10_LAUNCH_BOUNDS_1(1)
__global__ void unique_bool_write_output(
    const int numel,
    const int* num_true_p,
    bool* values_out,
    int64_t* counts_out) {
  constexpr int false_idx = 0;
  const int num_true = *num_true_p;
  const int num_false = numel - num_true;
  const int true_idx = num_false > 0;

  if (blockIdx.x == 0 && threadIdx.x == 0) {
    if (num_false > 0) {
      values_out[false_idx] = false;
      counts_out[false_idx] = num_false;
    }
    if (num_true > 0) {
      values_out[true_idx] = true;
      counts_out[true_idx] = num_true;
    }
  }
}

template <>
struct UniqueCub<bool> {

  std::tuple<Tensor, Tensor, Tensor> operator()(
      const Tensor& self,
      const bool consecutive,
      const bool return_inverse,
      const bool return_counts) {
    auto stream = at::cuda::getCurrentCUDAStream();

    int64_t num_inp = self.numel();

    Tensor output, inverse_indices, counts;
    if (consecutive) {
      Tensor sorted_indices;
      return compute_unique<bool>(
          self, sorted_indices, return_inverse, return_counts, consecutive);
    }

    // Instead of sorting, we use a reduction to find the number of
    // true values and from that we can infer the number of false.
    // If either has a count of zero, we omit it from the output.
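    // For example, self = [true, false, true, true] gives num_true = 3 and
    // num_false = 1, so output = [false, true] and counts = [1, 3]; for an
    // all-true input, output = [true] and counts = [numel].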
    auto allocator = at::cuda::getCUDADeviceAllocator();
    c10::DeviceArray<int> tmp_num_true(*allocator, 1);

    const bool* self_data = self.const_data_ptr<bool>();
    MapNumberOfTrueValues op;
    NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<int, MapNumberOfTrueValues, const uint8_t*, int>
        data_iter(reinterpret_cast<const uint8_t*>(self_data), op);
    at::cuda::cub::reduce(data_iter, tmp_num_true.get(), num_inp,
                          NO_ROCM(at_cuda_detail)::cub::Sum{}, 0);

    auto options = self.options();
    output = at::empty({2}, self.options());
    counts = at::empty({2}, options.dtype(kLong));

    unique_bool_write_output<<<1, 1, 0, stream>>>(
        num_inp,
        tmp_num_true.get(),
        output.mutable_data_ptr<bool>(),
        counts.mutable_data_ptr<int64_t>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    if (return_inverse) {
      using namespace at::cuda::detail;
      inverse_indices = at::empty(self.sizes(), options.dtype(kLong));
      dim3 block = CUDA_NUM_THREADS;
      dim3 grid = GET_BLOCKS(num_inp);
      unique_bool_write_inverse_indices<<<grid, block, 0, stream>>>(
          num_inp,
          tmp_num_true.get(),
          self_data,
          inverse_indices.mutable_data_ptr<int64_t>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Final sync to fix the shape of the output tensors
    int num_true = 0;
    at::cuda::memcpy_and_sync(&num_true, tmp_num_true.get(), sizeof(int),
                              cudaMemcpyDeviceToHost, stream);
    const int num_false = num_inp - num_true;
    const int num_out = ((num_true > 0) + (num_false > 0));
    output.resize_({num_out});
    counts.resize_({num_out});

    return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
  }
};

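// Entry point for the CUB-based unique path, explicitly instantiated below
// for each supported dtype (and, per the note above, used from Unique.cu).
// num_inp is capped at INT_MAX, matching the 32-bit size limit of the CUB
// calls used here.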
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
    const Tensor& self,
    const bool consecutive,
    const bool return_inverse,
    const bool return_counts) {
  auto num_inp = self.numel();
  TORCH_CHECK(
      num_inp <= INT_MAX, "num_inp ", num_inp, " is too big for CUB");
  if (num_inp == 0) {
    Tensor output = at::empty({0}, self.options());
    Tensor inverse_indices = at::empty(self.sizes(), self.options().dtype(kLong));
    Tensor counts = at::empty({0}, self.options().dtype(kLong));
    return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
  }

  auto self_c = self.expect_contiguous();
  return UniqueCub<scalar_t>{}(*self_c, consecutive, return_inverse, return_counts);
}

#define INSTANTIATE_UNIQUE_CUDA_TEMPLATE(TYPE)                              \
  template std::tuple<Tensor, Tensor, Tensor> unique_cuda_template<TYPE>(   \
      const Tensor& self,                                                   \
      const bool consecutive,                                               \
      const bool return_inverse,                                            \
      const bool return_counts)

INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint8_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int8_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(double);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(float);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int32_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int64_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int16_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint32_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint64_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint16_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(bool);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(at::Half);

#undef INSTANTIATE_UNIQUE_CUDA_TEMPLATE

} // namespace at::native::internal