// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/cub.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <ATen/cuda/cub.h>

#include <cstddef>
#include <type_traits>
#include <iterator>
#include <limits>

#include <ATen/cuda/cub_definitions.cuh>

#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()

#include <cub/cub.cuh>

#else

// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER
#define CUB_NS_PREFIX namespace at_cuda_detail {
#define CUB_NS_POSTFIX }
#define CUB_NS_QUALIFIER ::at_cuda_detail::cub
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER

#endif

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

// Handle the temporary storage and 'twice' calls for the cub API:
// cub device-wide algorithms are called once with a null buffer to query the
// required temp-storage size, then a second time with a real allocation to do
// the work. The scratch buffer comes from the CUDA caching allocator and is
// released when the DataPtr goes out of scope at the end of the do/while.
// cudaGetLastError() surfaces launch failures from the second call.
#define CUB_WRAPPER(func, ...) do {                                       \
  size_t temp_storage_bytes = 0;                                          \
  func(nullptr, temp_storage_bytes, __VA_ARGS__);                         \
  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get();    \
  auto temp_storage = caching_allocator.allocate(temp_storage_bytes);     \
  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);              \
  AT_CUDA_CHECK(cudaGetLastError());                                      \
} while (false)
45 
// ROCm/CUDA portability helpers:
//  - NO_ROCM(x): expands to x only on CUDA builds (empty under ROCm)
//  - ROCM_HIPCUB(x): maps a cub namespace token to ::hipcub under ROCm
#ifdef USE_ROCM
#define NO_ROCM(x)
#define ROCM_HIPCUB(x) ::hipcub
#else
#define NO_ROCM(x) x
#define ROCM_HIPCUB(x) x
#endif
53 
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)

#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif

// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
// Teaches cub's numeric-traits machinery about bfloat16 so radix sort and
// friends can handle it. The limits are spelled as raw bit patterns:
// 0x7F7F is the largest finite bfloat16, 0xFF7F its negation.

template <>
struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
{
    static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
        unsigned short max_word = 0x7F7F;
        return reinterpret_cast<c10::BFloat16&>(max_word);
    }

    static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
        unsigned short lowest_word = 0xFF7F;
        return reinterpret_cast<c10::BFloat16&>(lowest_word);
    }
};

template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
       ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};

#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif

#endif
85 
#if !defined(USE_ROCM)
// Expose the wrapped cub namespace as at::native::cub so kernels in
// at::native can refer to cub:: without knowing about at_cuda_detail.
namespace at::native {
namespace cub = ::at_cuda_detail::cub;
} // namespace at::native
#endif

namespace at::cuda::cub {

namespace detail {

// Maps a c10 scalar type to the corresponding CUDA/HIP built-in type of the
// same bit layout (identity for everything else), so cub can operate on the
// raw representation.
template<typename T>
struct cuda_type {
  using type = T;
};
template<>
struct cuda_type<c10::Half> {
  using type = __half;
};

#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()

template<>
struct cuda_type<c10::BFloat16> {
  using type = __nv_bfloat16;
};

#elif defined(USE_ROCM)

template<>
struct cuda_type<c10::BFloat16> {
  using type = hip_bfloat16;
};

#endif

}  // namespace detail
122 
// Segmented radix sort of (key, value) pairs using cub. If keys_out is
// nullptr, a scratch output buffer is allocated internally (the sort is
// out-of-place, so an output is always needed even when the caller only
// wants the sorted values). cub uses 32-bit counts, hence the INT_MAX checks.
template<typename key_t, typename value_t, typename OffsetIteratorT>
inline void segmented_sort_pairs(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int64_t num_elements, int64_t num_segments,
    OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
    bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
  TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX elements");
  TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX elements");
  using key_t_ = typename detail::cuda_type<key_t>::type;

  auto allocator = c10::cuda::CUDACachingAllocator::get();
  c10::DataPtr keys_out_owner;

  if (keys_out == nullptr) {
    // Caller doesn't need the sorted keys; sort into a temporary buffer
    // that lives until the end of this function.
    keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
    keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
  }

  // Reinterpret c10 key types as the CUDA built-in types cub understands.
  const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
      keys_in_, keys_out_, values_in, values_out,
      num_elements, num_segments, begin_offsets, end_offsets,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
      keys_in_, keys_out_, values_in, values_out,
      num_elements, num_segments, begin_offsets, end_offsets,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  }
}
160 
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
// Drops runs of consecutive duplicate keys, writing the first value of each
// run to values_out and the number of survivors to num_selected. The unique
// keys themselves are written to an internal scratch buffer and discarded.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
  ValuesOutputIteratorT values_out,
  NumSelectedIteratorT num_selected, int64_t num_input_items)
{
  // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
  using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
  auto allocator = c10::cuda::CUDACachingAllocator::get();
  c10::DataPtr keys_out_owner;
  keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT));
  auto keys_out_ = static_cast<KeyT *>(keys_out_owner.get());
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
    keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif
178 
namespace impl {

// Single-thread kernel computing *out = scan_op(*a, *b) in the accumulation
// type. Used by the chunked scans below to combine the previous chunk's last
// output with the next chunk's first input, entirely on the device (no host
// synchronization).
template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
C10_LAUNCH_BOUNDS_1(1)
__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
  // NOTE: out here not the final scan output, but an intermediate of the accumulation type.
  using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}

#if !CUB_SUPPORTS_FUTURE_VALUE()
// Random-access iterator that yields *first at index 0 and iter[i - 1] for
// i > 0, i.e. it prepends a device-resident carry value to an input range.
// This emulates cub::FutureValue for cub versions that lack it.
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
  using iterator_category = std::random_access_iterator_tag;
  using difference_type   = std::ptrdiff_t;
  using value_type        = ValueT;
  using pointer           = ValueT*;
  using reference         = ValueT&;

  InputIteratorT iter;
  ValueT *first;   // device pointer to the prepended carry value
  difference_type offset = 0;

  __device__ ValueT operator[](difference_type i) {
    i += offset;
    if (i == 0) {
      return *first;
    } else {
      return ValueT(iter[i - 1]);
    }
  }
  __device__ chained_iterator operator+(difference_type i) {
    return chained_iterator{iter, first, i};
  }
  __device__ ValueT operator*() {
    return (*this)[0];
  }
};
#endif

// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
}
223 
// Non-synchronizing inclusive scan.
// Even though cub is supposed to support tensors with int_max elements, in
// reality it doesn't, so the input is processed in chunks of at most
// max_cub_size elements, with the running total carried between chunks on
// the device.
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
#if defined(USE_ROCM)
  // For ROCm, use hipCUB chained iterators
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      num_items,
      at::cuda::getCurrentCUDAStream());
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  // non synchronizing cub call
  // even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
  // so split at int_max/2
  int size_cub = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      size_cub,
      at::cuda::getCurrentCUDAStream());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr first_elem = allocator->allocate(sizeof(input_t));
    auto first_elem_ptr = reinterpret_cast<input_t *>(first_elem.get());

    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
    // Combine the previous chunk's last output with this chunk's first input
    // to form the carry-in for this chunk, without syncing with the host.
    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        output + i - 1,
        input + i,
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    // No FutureValue support: substitute the carry-in for element 0 of this
    // chunk via a transform iterator, then run a plain inclusive scan.
    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
    using tuple = typename ArgIndexInputIterator::value_type;
    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  {
      if (x.key == 0) {
        return *first_elem_ptr;
      } else {
        return x.value;
      }
    };
    auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
      ArgIndexInputIterator(input + i), input_iter_transform);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    // With FutureValue: an exclusive scan over the input shifted by one,
    // seeded with the device-resident carry-in, equals the inclusive scan
    // of this chunk.
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i + 1,
        output + i,
        scan_op,
        ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}
293 
// Non-synchronizing exclusive scan, chunked like inclusive_scan above. The
// carry-in for chunk k is scan_op(output[last of chunk k-1], input[last of
// chunk k-1]) — the inclusive total of everything before the chunk.
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM)
  // For ROCm, use hipCUB chained iterators
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
      input,
      output,
      scan_op,
      init_value,
      num_items,
      at::cuda::getCurrentCUDAStream());
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  // non synchronizing cub call
  // even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
  // so split at int_max/2
  int size_cub = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
      input,
      output,
      scan_op,
      init_value,
      size_cub,
      at::cuda::getCurrentCUDAStream());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT));
    auto first_elem_ptr = reinterpret_cast<InitValueT *>(first_elem.get());

    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
    // Carry-in = scan_op(previous chunk's last exclusive output,
    // previous chunk's last input), computed on the device.
    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        output + i - 1,
        input + i - 1,
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    // No FutureValue support: prepend the carry-in to this chunk's input via
    // chained_iterator, then an inclusive scan yields the exclusive outputs.
    auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
      input + i, first_elem_ptr};
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i,
        output + i,
        scan_op,
        ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}
352 
#if CUB_SUPPORTS_SCAN_BY_KEY()

// Inclusive prefix sum over values that restarts whenever the key changes
// (consecutive keys compared with cub::Equality). 32-bit size limit as with
// the other cub wrappers.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveSumByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
      keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}

// Inclusive scan with a custom operator over values, restarting whenever the
// key changes.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
  // NOTE: message previously said "InclusiveSumByKey" (copy/paste error).
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveScanByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
      keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}

#endif
372 
// Compacts runs of consecutive equal elements down to a single element,
// writing the survivors to output and their count to num_selected_out.
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,
            NumSelectedIteratorT num_selected_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
              "cub unique does not support more than INT_MAX elements");
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
              input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}
381 
// Run-length encodes the input: for each run of consecutive equal elements,
// writes the element to output and the run length to counts_out; the number
// of runs goes to length_out.
template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
          typename LengthOutputIteratorT>
void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
                       LengthOutputIteratorT length_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
              "cub run_length_encode does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
      input, output, counts_out, length_out, num_items,
      at::cuda::getCurrentCUDAStream());
}
393 
// Device-wide reduction of num_items elements with a custom binary operator
// and initial value; the single result is written through output.
template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
              "cub reduce does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
      input, output, num_items, op, init,
      at::cuda::getCurrentCUDAStream());
}

}  // namespace at::cuda::cub