1 #pragma once
2 #include <ATen/cuda/cub.h>
3
4 #include <cstddef>
5 #include <type_traits>
6 #include <iterator>
7 #include <limits>
8
9 #include <ATen/cuda/cub_definitions.cuh>
10
11 #if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
12
13 #include <cub/cub.cuh>
14
15 #else
16
17 // include cub in a safe manner, see:
18 // https://github.com/pytorch/pytorch/pull/55292
19 #undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
20 #undef CUB_NS_PREFIX
21 #undef CUB_NS_QUALIFIER
22 #define CUB_NS_PREFIX namespace at_cuda_detail {
23 #define CUB_NS_POSTFIX }
24 #define CUB_NS_QUALIFIER ::at_cuda_detail::cub
25 #include <cub/cub.cuh>
26 #undef CUB_NS_POSTFIX
27 #undef CUB_NS_PREFIX
28 #undef CUB_NS_QUALIFIER
29
30 #endif
31
32 #include <ATen/cuda/Exceptions.h>
33 #include <c10/cuda/CUDACachingAllocator.h>
34 #include <c10/cuda/CUDAStream.h>
35
// Handles the temporary-storage dance required by the cub device API: every
// cub device-wide algorithm must be called twice — first with a null storage
// pointer to query `temp_storage_bytes`, then again with a real buffer of
// that size to do the actual work.  The scratch buffer comes from the CUDA
// caching allocator and is returned to it when `temp_storage` goes out of
// scope at the end of the do/while.
// NOTE: __VA_ARGS__ is expanded twice, so macro arguments must be safe to
// evaluate twice (call sites pass plain values/iterators).
#define CUB_WRAPPER(func, ...) do {                                 \
  size_t temp_storage_bytes = 0;                                    \
  func(nullptr, temp_storage_bytes, __VA_ARGS__);                   \
  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
  auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);        \
  AT_CUDA_CHECK(cudaGetLastError());                                \
} while (false)
45
// Portability helpers for CUDA vs. ROCm builds:
//  - NO_ROCM(x):     expands to `x` on CUDA, to nothing on ROCm (used to
//                    drop the at_cuda_detail namespace qualifier on ROCm).
//  - ROCM_HIPCUB(x): expands to `::hipcub` on ROCm, to `x` on CUDA.
#ifdef USE_ROCM
#define NO_ROCM(x)
#define ROCM_HIPCUB(x) ::hipcub
#else
#define NO_ROCM(x) x
#define ROCM_HIPCUB(x) x
#endif
53
54 #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
55
56 #if !defined(USE_ROCM)
57 namespace at_cuda_detail {
58 #endif
59
60 // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
61
// Specialize cub's FpLimits for c10::BFloat16 so radix sort / reductions
// know the extreme finite values of the type.
template <>
struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
{
    // Bit pattern of the largest finite bfloat16 (sign 0, exponent 0xFE,
    // mantissa 0x7F), reinterpreted in place as a BFloat16.
    static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
        unsigned short max_word = 0x7F7F;
        return reinterpret_cast<c10::BFloat16&>(max_word);
    }

    // Bit pattern of the most negative finite bfloat16 (same as Max() with
    // the sign bit set).
    static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
        unsigned short lowest_word = 0xFF7F;
        return reinterpret_cast<c10::BFloat16&>(lowest_word);
    }
};
75
// Register BFloat16 with cub's type-traits machinery as a signed,
// non-integral floating-point type whose raw bit representation is
// `unsigned short`.
template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
       ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
79
80 #if !defined(USE_ROCM)
81 } // namespace at_cuda_detail
82 #endif
83
84 #endif
85
#if !defined(USE_ROCM)
// Let at::native kernels write plain `cub::...` and transparently hit the
// namespace-wrapped copy of cub (::at_cuda_detail::cub).
namespace at::native {
namespace cub = ::at_cuda_detail::cub;
} // namespace at::native
#endif
91
92 namespace at::cuda::cub {
93
namespace detail {

// Maps a c10 scalar type to the builtin CUDA/HIP type cub understands;
// identity for every type cub already handles natively.
template<typename T>
struct cuda_type {
  using type = T;
};
template<>
struct cuda_type<c10::Half> {
  using type = __half;
};

#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()

// cub has native __nv_bfloat16 support on this toolkit; use it.
template<>
struct cuda_type<c10::BFloat16> {
  using type = __nv_bfloat16;
};

#elif defined(USE_ROCM)

// hipcub's bfloat16 type.
template<>
struct cuda_type<c10::BFloat16> {
  using type = hip_bfloat16;
};

#endif

}  // namespace detail
122
// Segmented radix sort of (key, value) pairs via cub/hipcub
// DeviceSegmentedRadixSort, asynchronous on the current CUDA stream.
//
//   keys_in/values_in : device pointers to num_elements entries.
//   keys_out          : may be nullptr; a scratch buffer is then allocated
//                       and the sorted keys are discarded.
//   values_out        : device pointer receiving the sorted values.
//   begin/end_offsets : iterators over the num_segments segment boundaries.
//   descending        : sort order.
//   begin_bit/end_bit : restrict the radix passes to a bit subrange of key_t.
//
// Both num_elements and num_segments must fit in int (cub limitation).
template<typename key_t, typename value_t, typename OffsetIteratorT>
inline void segmented_sort_pairs(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int64_t num_elements, int64_t num_segments,
    OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
    bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
  TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX elements");
  TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX segments");
  using key_t_ = typename detail::cuda_type<key_t>::type;

  auto allocator = c10::cuda::CUDACachingAllocator::get();
  c10::DataPtr keys_out_owner;

  // cub requires a key-output buffer even when the caller does not want the
  // sorted keys, so allocate a throwaway one.
  if (keys_out == nullptr) {
    keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
    keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
  }

  // Reinterpret keys as the builtin type cub knows (e.g. c10::Half -> __half).
  const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
      keys_in_, keys_out_, values_in, values_out,
      num_elements, num_segments, begin_offsets, end_offsets,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
      keys_in_, keys_out_, values_in, values_out,
      num_elements, num_segments, begin_offsets, end_offsets,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  }
}
160
161 #if CUB_SUPPORTS_UNIQUE_BY_KEY()
// Runs cub DeviceSelect::UniqueByKey on the current CUDA stream: for each
// run of consecutive equal keys, the value of the run's first element is
// kept in values_out and the total number of survivors is written to
// num_selected.  The surviving keys themselves are written to a scratch
// buffer and discarded.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
  ValuesOutputIteratorT values_out,
  NumSelectedIteratorT num_selected, int64_t num_input_items)
{
  // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
  using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
  // cub insists on a key-output buffer even though we throw it away.
  c10::DataPtr discarded_keys =
      c10::cuda::CUDACachingAllocator::get()->allocate(num_input_items * sizeof(KeyT));
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
              keys_in, values_in,
              static_cast<KeyT *>(discarded_keys.get()), values_out,
              num_selected, num_input_items,
              c10::cuda::getCurrentCUDAStream());
}
177 #endif
178
namespace impl {

// One-thread kernel computing `*out = scan_op(*a, *b)`.  Used to carry the
// running total across chunk boundaries when a large scan is split into
// pieces (see inclusive_scan / exclusive_scan below); always launched
// <<<1, 1>>> at the call sites.
template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
C10_LAUNCH_BOUNDS_1(1)
__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
  // NOTE: out here not the final scan output, but an intermediate of the accumulation type.
  using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}

#if !CUB_SUPPORTS_FUTURE_VALUE()
// Device-side random-access iterator yielding *first at position 0 and
// iter[i - 1] for i > 0.  Emulates cub::FutureValue on older cub: it lets a
// chunked scan inject the previous chunk's carry (*first) ahead of this
// chunk's real inputs without a host synchronization.
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
  using iterator_category = std::random_access_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = ValueT;
  using pointer = ValueT*;
  using reference = ValueT&;

  InputIteratorT iter;
  ValueT *first;  // device pointer to the carry-in value, read at index 0
  difference_type offset = 0;

  __device__ ValueT operator[](difference_type i) {
    i += offset;
    if (i == 0) {
      return *first;
    } else {
      return ValueT(iter[i - 1]);
    }
  }
  // NOTE(review): returns offset = i, not offset + i — correct as long as
  // `+` is only ever applied to the base (offset 0) iterator, which is how
  // cub uses it; confirm before reusing elsewhere.
  __device__ chained_iterator operator+(difference_type i) {
    return chained_iterator{iter, first, i};
  }
  __device__ ValueT operator*() {
    return (*this)[0];
  }
};
#endif

// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
}  // namespace impl
223
// Non-synchronizing inclusive scan of [input, input + num_items) into
// `output` using `scan_op`.
//
// Even though cub is supposed to support tensors with int_max elements, in
// reality it doesn't, so the work is split into chunks of at most
// `max_cub_size` (int_max/2) elements.  The carry between chunks —
// scan_op(output[i-1], input[i]), i.e. the inclusive value at position i —
// is computed on-device by a 1-thread kernel and injected into the next
// chunk either through a shifted TransformInputIterator (old cub) or
// through cub::FutureValue, so the host never synchronizes.
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
#if defined(USE_ROCM)
  // For ROCm, hipcub handles the full range in one call.
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      num_items,
      at::cuda::getCurrentCUDAStream());
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  // First chunk: a plain inclusive scan.
  int size_cub = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      size_cub,
      at::cuda::getCurrentCUDAStream());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
  // Remaining chunks: each needs the carry-in from the chunk before it.
  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr first_elem = allocator->allocate(sizeof(input_t));
    auto first_elem_ptr = reinterpret_cast<input_t *>(first_elem.get());

    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
    // carry = scan_op(output[i-1], input[i]) = inclusive-scan value at i.
    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        output + i - 1,
        input + i,
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    // Old cub: present the chunk as [carry, input[i+1], input[i+2], ...] by
    // mapping index 0 of an ArgIndexInputIterator to *first_elem_ptr.
    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
    using tuple = typename ArgIndexInputIterator::value_type;
    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
      if (x.key == 0) {
        return *first_elem_ptr;
      } else {
        return x.value;
      }
    };
    auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
      ArgIndexInputIterator(input + i), input_iter_transform);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    // Inclusive scan of the chunk == exclusive scan over inputs shifted by
    // one, with the carry supplied lazily as the FutureValue initial value.
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i + 1,
        output + i,
        scan_op,
        ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}
293
// Non-synchronizing exclusive scan of [input, input + num_items) into
// `output` using `scan_op` and initial value `init_value`, chunked at
// `max_cub_size` for the same reason as inclusive_scan.  The carry for a
// chunk starting at i is scan_op(output[i-1], input[i-1]) — the exclusive
// value at position i — computed on-device by a 1-thread kernel and fed to
// the next chunk via chained_iterator (old cub) or cub::FutureValue.
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM)
  // For ROCm, hipcub handles the full range in one call.
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
      input,
      output,
      scan_op,
      init_value,
      num_items,
      at::cuda::getCurrentCUDAStream());
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  // First chunk: a plain exclusive scan seeded with init_value.
  int size_cub = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
      input,
      output,
      scan_op,
      init_value,
      size_cub,
      at::cuda::getCurrentCUDAStream());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  // Remaining chunks: each needs the carry-in from the chunk before it.
  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT));
    auto first_elem_ptr = reinterpret_cast<InitValueT *>(first_elem.get());

    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
    // carry = scan_op(output[i-1], input[i-1]) = exclusive-scan value at i.
    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        output + i - 1,
        input + i - 1,
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
    // Old cub: exclusive scan == inclusive scan over [carry, input[i],
    // input[i+1], ...], which chained_iterator synthesizes on the fly.
    auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
      input + i, first_elem_ptr};
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
#else
    // Carry supplied lazily as the FutureValue initial value.
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + i,
        output + i,
        scan_op,
        ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
        size_cub,
        at::cuda::getCurrentCUDAStream());
#endif
  }
#endif
}
352
353 #if CUB_SUPPORTS_SCAN_BY_KEY()
354
// Keyed inclusive prefix sum: a running sum over `input` that restarts
// whenever the corresponding key differs from its predecessor (consecutive
// keys compared with cub::Equality).  Runs on the current CUDA stream.
// num_items must fit in int (cub limitation).
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveSumByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
      keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}
362
// Keyed inclusive scan with a custom binary op: like inclusive_sum_by_key
// but combining values with `scan_op` instead of addition.  A new segment
// starts whenever consecutive keys compare unequal (cub::Equality).
// Runs on the current CUDA stream; num_items must fit in int.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
  // Fixed copy-paste: the message previously named InclusiveSumByKey.
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveScanByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
      keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}
370
371 #endif
372
// Copies `input` to `output`, dropping consecutive duplicate elements
// (cub DeviceSelect::Unique), and writes the number of kept elements to
// num_selected_out.  Runs on the current CUDA stream; num_items must fit
// in int (cub limitation).
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,
            NumSelectedIteratorT num_selected_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub unique does not support more than INT_MAX elements");
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
    input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}
381
// Run-length encodes `input`: the value of each run of consecutive equal
// elements goes to `output`, its length to counts_out, and the total number
// of runs to length_out (cub DeviceRunLengthEncode::Encode).  Runs on the
// current CUDA stream; num_items must fit in int (cub limitation).
template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
          typename LengthOutputIteratorT>
void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
                       LengthOutputIteratorT length_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub run_length_encode does not support more than INT_MAX elements");
  CUB_WRAPPER(
    NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
    input, output, counts_out, length_out, num_items,
    at::cuda::getCurrentCUDAStream());
}
393
// Device-wide reduction of [input, input + num_items) with binary op `op`,
// seeded with `init`; the single result is written to `output` (cub
// DeviceReduce::Reduce).  Runs on the current CUDA stream; num_items must
// fit in int (cub limitation).
template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub reduce does not support more than INT_MAX elements");
  CUB_WRAPPER(
    NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
    input, output, num_items, op, init,
    at::cuda::getCurrentCUDAStream());
}
404
405 } // namespace at::cuda::cub
406