#pragma once

#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/thread_constants.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/OpMathType.h>
#include <c10/macros/Macros.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <functional>
#include <iosfwd>
#include <type_traits>
#include <utility>
#include <thrust/pair.h>

#include <ATen/native/cuda/jit_utils.h>

namespace at { namespace native {

using at::detail::Array;

static inline int64_t div_up(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// returns the largest power of 2 that does not exceed n, i.e. 2^floor(log2(n));
// e.g. last_pow2(20) == 16, last_pow2(16) == 16, last_pow2(1) == 1
static inline int last_pow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

// reduces the fraction numerator/denominator to lowest terms, in place
C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
  // get GCD of num and denom using Euclid's algorithm.
  // Can replace this with std::gcd if we ever support C++17.
  size_t a = denominator;
  size_t b = numerator;
  while (b != 0) {
    a %= b;
    // swap(a, b)
    size_t tmp = a;
    a = b;
    b = tmp;
  }

  // a is now the GCD
  numerator /= a;
  denominator /= a;
}
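
// For example, reduce_fraction(numerator = 6, denominator = 4) leaves
// numerator == 3 and denominator == 2. This is used below to convert byte
// offsets between buffers whose element sizes differ.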

// template for changing MAX_NUM_THREADS based on op dtype
template <typename T>
struct mnt_wrapper {
  static constexpr int MAX_NUM_THREADS = 512;
};

template <>
struct mnt_wrapper <c10::complex<double>>{
  static constexpr int MAX_NUM_THREADS = 256;
};

constexpr int max_reduce_threads(c10::ScalarType type) {
  return type == kComplexDouble ? 256 : 512;
}

struct ReduceConfig {
  static constexpr int BLOCK_X = 0;
  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;

  static constexpr int input_vec_size = 4;

  ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
    : element_size_bytes(element_size_bytes)
    , num_inputs(num_inputs)
    , num_outputs(num_outputs) {}
  int element_size_bytes;
  int num_inputs;
  int num_outputs;
  int step_input = 1;
  int step_output = 1;
  int ctas_per_output = 1;
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

  int block_width;
  int block_height;
  int num_threads;

  bool vectorize_input = false;
  int output_vec_size = 1;

  template <typename T>
  void set_block_dimension(int64_t dim0, int64_t dim1) {
    const int max_num_threads = mnt_wrapper<T>::MAX_NUM_THREADS / output_vec_size;
    int dim0_pow2 = dim0 < max_num_threads ? static_cast<int>(last_pow2(dim0)) : max_num_threads;
    int dim1_pow2 = dim1 < max_num_threads ? static_cast<int>(last_pow2(dim1)) : max_num_threads;
    block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
    block_height = std::min(dim1_pow2, int(max_num_threads / block_width));
    block_width = std::min(dim0_pow2, int(max_num_threads / block_height));
    num_threads = block_width * block_height;
  }

  int split_input(int parallelism) {
    int step = step_input;
    step_input *= parallelism;
    return step;
  }

  int split_output(int parallelism) {
    int step = step_output;
    step_output *= parallelism;
    return step;
  }

  dim3 block() const {
    return dim3(block_width, block_height);
  }

  dim3 grid() const {
    return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output);
  }

  C10_HOST_DEVICE bool should_block_x_reduce() const {
    return input_mult[BLOCK_X] != 0;
  }

  C10_HOST_DEVICE bool should_block_y_reduce() const {
    return input_mult[BLOCK_Y] != 0;
  }

  C10_HOST_DEVICE bool should_global_reduce() const {
    return input_mult[CTA] != 0;
  }

  C10_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
      (!should_block_x_reduce() || threadIdx.x == 0) &&
      (!should_block_y_reduce() || threadIdx.y == 0);
  }

  C10_DEVICE bool should_reduce_tail() const {
    return (!should_block_y_reduce() || threadIdx.y == 0) &&
      (!should_global_reduce() || blockIdx.y == 0);
  }

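  // Linear index of this thread along the reduction dimension. Each entry of
  // input_mult is either 0 (that hardware index does not advance through the
  // inputs) or the stride, in input elements, contributed by threadIdx.x,
  // threadIdx.y and blockIdx.y respectively.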
  C10_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
    return (lane * input_mult[BLOCK_X] +
            warp * input_mult[BLOCK_Y] +
            cta2 * input_mult[CTA]);
  }

  template <int output_vec_size>
  C10_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
    return (lane * output_mult[BLOCK_X] +
            warp * output_mult[BLOCK_Y] +
            cta1 * step_output) * output_vec_size;
  }

  C10_DEVICE int shared_memory_offset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }

  C10_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
    if (!should_block_x_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }

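  // Shared memory is only needed when reducing across threadIdx.y, or across
  // threadIdx.x when the block is wider than a warp; a block_x reduction that
  // fits in a single warp is done entirely with warp shuffles.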
  int shared_memory_size() const {
    if (!should_block_y_reduce() &&
        (!should_block_x_reduce() ||
         block_width <= at::cuda::warp_size())) {
      return 0;
    }
    return element_size_bytes * num_threads * output_vec_size;
  }

  int64_t global_memory_size() const {
    if (!should_global_reduce()) {
      return 0;
    }
    auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
    if (!should_block_x_reduce()) {
      size *= block().x * output_vec_size;
    }
    return size;
  }

  int semaphore_size() const {
    if (!should_global_reduce()) {
      return 0;
    }
    return sizeof(int) * grid().x;
  }

  int values_per_thread() const {
    return div_up(num_inputs, step_input);
  }
};

std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);

template <int nt, int output_vec_size, typename R>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void reduce_kernel(R reduction) {
  reduction.template run<output_vec_size>();
}

template <typename index_t>
static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
  int num_reduce_dims = iter.num_reduce_dims();
  int num_output_dims = iter.ndim() - num_reduce_dims;
  int input_index = iter.ntensors() - 1;
  int output_index = 0;
  std::array<const int64_t*, 2> strides = {
    iter.strides(output_index).data() + num_reduce_dims,
    iter.strides(input_index).data() + num_reduce_dims,
  };
  auto shape = iter.shape().data() + num_reduce_dims;
  return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data());
}

template <typename index_t>
static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) {
  int num_reduce_dims = iter.num_reduce_dims();
  int input_index = iter.ntensors() - 1;
  std::array<const int64_t*, 1> strides = {
    iter.strides(input_index).data(),
  };
  return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data());
}

template <typename out_scalar_t, typename func_t>
struct func_wrapper_t {
  using arg_t = typename binary_function_traits<func_t>::arg1_t;
  using scalar_t = typename binary_function_traits<func_t>::arg2_t;

  func_t combine;
  static inline __device__ out_scalar_t project(arg_t arg) {
    return (out_scalar_t) arg;
  }
  static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
    return WARP_SHFL_DOWN(arg, offset);
  }

  static __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
    return acc;
  }

  func_wrapper_t(const func_t& op) : combine(op) {
  }

  // wrap a normal reduction that ignores the index
  __device__ arg_t reduce(arg_t acc, scalar_t val, int64_t idx) const {
    return combine(acc, val);
  }
};

template <typename scalar_t, typename func_t>
func_wrapper_t<scalar_t, func_t> func_wrapper(const func_t& op) {
  return func_wrapper_t<scalar_t, func_t> { op };
}
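
// func_wrapper adapts a plain binary combine function to the ops_t interface
// that ReduceOp below expects (reduce/combine/project/warp_shfl_down/
// translate_idx), for reductions that ignore the element index. Illustrative
// example only: a float sum can be expressed as
//   func_wrapper<float>([] GPU_LAMBDA (float a, float b) -> float { return a + b; })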

template <typename scalar_t, typename out_scalar_t=scalar_t>
struct ReduceJitOp {
  // ReduceJitOp is almost identical to ReduceOp, but it doesn't carry the ops
  // functor that specifies the reduction operations.
  // Maybe we can find a way to unify ReduceOp and ReduceJitOp.
  using InputCalculator = OffsetCalculator<1, uint32_t>;
  using OutputCalculator = OffsetCalculator<2, uint32_t>;
  // TODO: for now arg_t is always opmath_t of the input; later we'll need to change it
  using arg_t = at::opmath_type<scalar_t>;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;
  // TODO: ReduceJitOp will probably need to be changed for reductions that
  // need a full functor, not just a wrapper
  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; // it accepts at most two destinations
  // acc_buf is used for accumulation among sub tensor iterators when
  // accumulation on the output is not permissible
  void* acc_buf;
  // cta_buf is used for accumulation between blocks during global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;

  ReduceJitOp(
      ReduceConfig config,
      InputCalculator input_calc,
      OutputCalculator output_calc,
      const void* src,
      char* dst0,
      std::optional<char*> dst1,
      void* acc_buf,
      void* cta_buf,
      int* semaphores,
      arg_t ident,
      int noutputs,
      int64_t base_idx)
      : ident(ident),
        config(config),
        input_calc(input_calc),
        output_calc(output_calc),
        src(src),
        acc_buf(acc_buf),
        cta_buf(cta_buf),
        semaphores(semaphores),
        base_idx(base_idx),
        noutputs(noutputs) {
    dst[0] = dst0;
    if (dst1.has_value()) {
      dst[1] = dst1.value();
    }
  }
};

template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t, int vt0=4>
struct ReduceOp {
  using traits = function_traits<decltype(&ops_t::reduce)>;
  using arg_t = typename std::decay<typename traits::template arg<0>::type>::type;

  using InputCalculator = OffsetCalculator<1, index_t>;
  using OutputCalculator = OffsetCalculator<2, index_t>;

  static constexpr bool can_accumulate_in_output =
    std::is_convertible<arg_t, out_scalar_t>::value
    && std::is_convertible<out_scalar_t, arg_t>::value;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;

  ops_t ops;
  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; // it accepts at most two destinations
  // acc_buf is used for accumulation among sub tensor iterators when
  // accumulation on the output is not permissible
  void* acc_buf;
  // cta_buf is used for accumulation between blocks during global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;

  ReduceOp(
      ops_t ops,
      ReduceConfig config,
      InputCalculator input_calc,
      OutputCalculator output_calc,
      const void* src,
      char* dst0,
      std::optional<char*> dst1,
      void* acc_buf,
      void* cta_buf,
      int* semaphores,
      arg_t ident,
      int noutputs,
      int64_t base_idx)
      : ops(ops),
        ident(ident),
        config(config),
        input_calc(input_calc),
        output_calc(output_calc),
        src(src),
        acc_buf(acc_buf),
        cta_buf(cta_buf),
        semaphores(semaphores),
        base_idx(base_idx),
        noutputs(noutputs) {
    dst[0] = dst0;
    if (dst1.has_value()) {
      dst[1] = dst1.value();
    }
  }

  template <int output_vec_size>
  C10_DEVICE void run() const {
    extern __shared__ char shared_memory[];
    index_t output_idx = config.output_idx<output_vec_size>();
    index_t input_idx = config.input_idx();
    auto base_offsets1 = output_calc.get(output_idx)[1];

    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
    arg_vec_t value;

    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
      value = thread_reduce<output_vec_size>(input_slice);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<output_vec_size>(value, shared_memory);
    }
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<output_vec_size>(value, shared_memory);
    }

    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    arg_vec_t* acc = nullptr;
    if (acc_buf != nullptr) {
      size_t numerator = sizeof(arg_t);
      size_t denominator = sizeof(out_scalar_t);
      reduce_fraction(numerator, denominator);
      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
    }

    if (config.should_global_reduce()) {
      value = global_reduce<output_vec_size>(value, acc, shared_memory);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = ops.translate_idx(value[i], base_idx);
        }
      }

      if (acc == nullptr) {
        if (accumulate) {
          value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
        }
        if (final_output) {
          set_results_to_output<output_vec_size>(value, base_offsets);
        } else {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
          }
        }
      } else {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine((*acc)[i], value[i]);
          }
        }
        if (final_output) {
          set_results_to_output<output_vec_size>(value, base_offsets);
        } else {
          *acc = value;
        }
      }
    }
  }

  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
    if (config.vectorize_input) {
      CUDA_KERNEL_ASSERT(output_vec_size == 1);
      // the vectorized implementation first reduces the unaligned head of
      // input_slice so that the main loop works on aligned memory.
      return {input_vectorized_thread_reduce_impl(data)};
    } else {
      index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
      if (is_contiguous) {
        return thread_reduce_impl<output_vec_size>(data, [](index_t idx) { return idx; });
      } else if (input_calc.dims == 1) {
        return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return idx * element_stride; });
      } else {
        return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
      }
    }
  }

  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
    index_t end = config.num_inputs;

    // Handle the head of the input slice where data is not aligned
    arg_t value = ident;
    constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, input_vec_size>);
    constexpr int align_elements = align_bytes / sizeof(scalar_t);
    int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t);
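    // `shift` is how many scalar elements `data` sits past the previous aligned
    // boundary. Illustration only: for float data with input_vec_size == 4
    // (align_bytes == 16), a `data` pointer that is 8-byte aligned gives
    // shift == 2, so threads 2 and 3 reduce the two unaligned head elements
    // below and the vectorized loop then starts from an aligned address.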
    if (shift > 0) {
      data -= shift;
      end += shift;
      if (threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()) {
        value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift);
      }
      end -= align_elements;
      data += align_elements;
      shift = align_elements - shift;
    }

    // Do the vectorized reduction
    using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;

    index_t idx = config.input_idx();
    const index_t stride = config.step_input;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_t value_list[input_vec_size];
    value_list[0] = value;

    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[i] = ident;
    }

    while (idx * input_vec_size + input_vec_size - 1 < end) {
      const auto values_vec = memory::load_vector<input_vec_size>(data, idx);
      #pragma unroll
      for (index_t i = 0; i < input_vec_size; i++) {
        value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i);
      }
      idx += stride;
    }

    // tail
    index_t tail_start = end - end % input_vec_size;
    if (config.should_reduce_tail()) {
      int idx = tail_start + threadIdx.x;
      if (idx < end) {
        const auto value = c10::load(data + idx);
        value_list[0] = ops.reduce(value_list[0], value, idx + shift);
      }
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[0] = ops.combine(value_list[0], value_list[i]);
    }
    return value_list[0];
  }

  template <int output_vec_size, typename offset_calc_t>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
    index_t idx = config.input_idx();
    const index_t end = config.num_inputs;
    const index_t stride = config.step_input;

    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
    using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_vec_t value_list[vt0];

    #pragma unroll
    for (int i = 0; i < vt0; i++) {
      #pragma unroll
      for (int j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ident;
      }
    }

    load_t values[vt0];

    while (idx + (vt0 - 1) * stride < end) {
      #pragma unroll
      for (index_t i = 0; i < vt0; i++) {
        const auto offset = calc(idx + i * stride) / output_vec_size;
        values[i] = memory::load_vector<output_vec_size>(data_, offset);
      }
      #pragma unroll
      for (index_t i = 0; i < vt0; i++) {
        #pragma unroll
        for (index_t j = 0; j < output_vec_size; j++) {
          value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride);
        }
      }
      idx += stride * vt0;
    }

    // tail
    int idx_ = idx;
    #pragma unroll
    for (index_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      const auto offset = calc(idx) / output_vec_size;
      values[i] = memory::load_vector<output_vec_size>(data_, offset);
      idx += stride;
    }
    idx = idx_;
    #pragma unroll
    for (index_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      #pragma unroll
      for (index_t j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx);
      }
      idx += stride;
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < vt0; i++) {
      #pragma unroll
      for (index_t j = 0; j < output_vec_size; j++) {
        value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]);
      }
    }
    return value_list[0];
  }

  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_x_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
    int dim_x = blockDim.x;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    if (dim_x > warpSize) {
      int address_base = threadIdx.x + threadIdx.y * blockDim.x;
      shared[address_base] = value;
      for (int offset = dim_x / 2; offset >= warpSize; offset >>= 1) {
        __syncthreads();
        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
          args_vec_t other = shared[address_base + offset];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine(value[i], other[i]);
          }
          shared[address_base] = value;
        }
      }
      dim_x = warpSize;
    }

    __syncthreads();

    for (int offset = 1; offset < dim_x; offset <<= 1) {
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = ops.warp_shfl_down(value[i], offset);
        value[i] = ops.combine(value[i], other);
      }
    }
    return value;
  }

  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_y_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        args_vec_t other = shared[config.shared_memory_offset(offset)];
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = ops.combine(value[i], other[i]);
        }
        shared[config.shared_memory_offset(0)] = value;
      }
    }
    return value;
  }

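  // Atomically counts the CTAs that have finished staging their partial result
  // for this output column; only the last CTA to arrive (across blockIdx.y)
  // returns true, and it then combines the staged results in global_reduce.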
  C10_DEVICE bool mark_block_finished() const {
    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }

    __syncthreads();

    return is_last_block_done_shared;
  }

  template <int output_vec_size, bool can_acc>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
    at::detail::Array<out_scalar_t*, output_vec_size> out,
    at::detail::Array<arg_t, output_vec_size> value,
    typename std::enable_if<can_acc>::type* = nullptr
  ) const {
    at::detail::Array<arg_t, output_vec_size> ret;
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      ret[i] = ops.combine(*(out[i]), value[i]);
    }
    return ret;
  }

  template <bool can_acc>
  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value,
    typename std::enable_if<can_acc>::type* = nullptr
  ) const {
    CUDA_KERNEL_ASSERT(!final_output);
    return (out_scalar_t)value;
  }

  // This function should never be called --
  // it's the version of `accumulate_in_output`
  // when accumulation in the output is not possible.
  template <int output_vec_size, bool can_acc>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
    at::detail::Array<out_scalar_t*, output_vec_size>,
    at::detail::Array<arg_t, output_vec_size>,
    typename std::enable_if<!can_acc>::type* = nullptr
  ) const {
    CUDA_KERNEL_ASSERT(false);
    return arg_t {};
  }

  // This function should never be called --
  // it's the version of `get_accumulated_output`
  // when accumulation in the output is not possible.
  template <bool can_acc>
  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value,
    typename std::enable_if<!can_acc>::type* = nullptr
  ) const {
    CUDA_KERNEL_ASSERT(false);
    return *out;
  }

  template <class T>
  C10_DEVICE void set_results(const T x, const index_t base_offset) const {
    CUDA_KERNEL_ASSERT(noutputs == 1);
    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
    *res = x;
  }

  // Currently implemented for a maximum of two outputs
  template <class T1, class T2>
  C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
    if (noutputs >= 1) {
      auto res0 = (T1*)((char*)dst[0] + base_offset);
      *res0 = x.first;
    }
    if (noutputs >= 2) {
      // base_offset is computed assuming an element size of sizeof(T1), so we
      // need to rescale it to obtain the correct offset into dst[1]
      auto res1 = (T2*)((char*)dst[1] + base_offset / sizeof(T1) * sizeof(T2));
      *res1 = x.second;
    }
  }

  template <int output_vec_size>
  C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
    CUDA_KERNEL_ASSERT(final_output);
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      set_results(ops.project(value[i]), base_offset[i]);
    }
  }

  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> global_reduce(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;

    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
    index_t output_idx = config.output_idx<output_vec_size>();
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    bool should_store = config.should_store(output_idx);
    if (should_store) {
      index_t offset = config.staging_memory_offset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
    bool is_last_block_done = mark_block_finished();

    if (is_last_block_done) {
      __threadfence(); // complete the acquire pattern after atomic
      value = ident;
      if (config.should_block_x_reduce()) {
        index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        index_t step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          index_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine(value[i], next[i]);
          }
        }
      } else {
        index_t input_offset = threadIdx.y;
        index_t step = blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          index_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine(value[i], next[i]);
          }
        }
      }
      value = block_y_reduce(value, shared_memory);
      if (config.should_block_x_reduce()) {
        value = block_x_reduce<output_vec_size>(value, shared_memory);
      }
      if (should_store) {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.translate_idx(value[i], base_idx);
          }
        }

        if (acc == nullptr) {
          if (accumulate) {
            value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
            }
          }
        } else {
          if (accumulate) {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              value[i] = ops.combine((*acc)[i], value[i]);
            }
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            *acc = value;
          }
        }
      }
    }

    return value;
  }
};

template <int max_threads, typename R>
static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
  dim3 block = config.block();
  dim3 grid = config.grid();

  auto stream = at::cuda::getCurrentCUDAStream();
  int shared_memory = config.shared_memory_size();

  switch (config.output_vec_size) {
    case 4:
      reduce_kernel<max_threads / 4, 4, R><<<grid, block, shared_memory, stream>>>(reduction);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 2:
      reduce_kernel<max_threads / 2, 2, R><<<grid, block, shared_memory, stream>>>(reduction);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    default:
      reduce_kernel<max_threads / 1, 1, R><<<grid, block, shared_memory, stream>>>(reduction);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

inline void launch_jitted_reduce_kernel(
    std::mutex &jiterator_mutex,
    std::array<at::cuda::jit::NvrtcFunction, 3> &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int vt0, const ReduceConfig& config, void *reduction) {
  dim3 block = config.block();
  dim3 grid = config.grid();

  int shared_memory = config.shared_memory_size();
  at::cuda::jit::NvrtcFunction* fn_ptr;
  switch (config.output_vec_size) {
    case 4:
      fn_ptr = &fn_cache[0];
      break;
    case 2:
      fn_ptr = &fn_cache[1];
      break;
    default:
      fn_ptr = &fn_cache[2];
  }
  if (!fn_ptr->function) {
    int max_threads_codegen =
        max_reduce_threads(desc.f_inputs_type) / config.output_vec_size;
    auto code = at::cuda::jit::generate_reduction_code(
        desc, vt0, true, false, config.output_vec_size, max_threads_codegen);

    *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_" + desc.name);
  }
  constexpr int kernel_args = 1;
  void* args[kernel_args];
  args[0] = reduction;
  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory);
}


class AccumulationBuffer {
 public:
  AccumulationBuffer() {}

  AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) {
    out_ptr_ = (char*)out_ptr;
    if (out_t_size >= acc_t_size) {
      // reusing output buffer for accumulation.
      acc_ptr_ = (char*)out_ptr;
      numerator_ = 1;
      denominator_ = 1;
    } else {
      auto& allocator = *c10::cuda::CUDACachingAllocator::get();
      buffer_ = allocator.allocate(size);
      acc_ptr_ = (char*)buffer_.get();
      numerator_ = acc_t_size;
      denominator_ = out_t_size;
      reduce_fraction(numerator_, denominator_);
    }
  }

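  // Maps a pointer into the output tensor to the matching slot in the
  // accumulation buffer. Byte offsets are rescaled by the reduced fraction
  // acc_t_size / out_t_size; e.g. accumulating in float (4 bytes) for a half
  // (2 bytes) output scales the offset by 2.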
  char* get_acc_slice(char* out_ptr) {
    if (acc_ptr_ == nullptr) {
      return nullptr;
    }
    return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_);
  }

 private:
  char* acc_ptr_ = nullptr;
  char* out_ptr_ = nullptr;
  size_t numerator_;
  size_t denominator_;
  at::DataPtr buffer_;
};

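// Chooses the vector width (4, 2, or 1) used when "vectorizing along the
// output": the width must evenly divide the input's base address (in
// elements), the size of the fastest-moving non-reduced dimension, and the
// input's element stride in every other dimension.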
template <typename scalar_t>
int get_output_vec_size(const TensorIterator &iter) {
  int vec_size = 4;
  auto update_vec_size = [&vec_size](uint64_t n) {
    while (n % vec_size != 0) {
      vec_size /= 2;
    }
  };

  uint64_t base_address = reinterpret_cast<uint64_t>(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t);
  update_vec_size(base_address);

  const int output_index = iter.num_reduce_dims();
  update_vec_size(iter.shape()[output_index]);

  int j = 0;
  for (auto i : iter.strides(iter.noutputs())) {
    if (j != output_index) {
      update_vec_size(i / sizeof(scalar_t));
    }
    j++;
  }
  return vec_size;
}

template <typename arg_t, typename scalar_t, int vt0>
ReduceConfig setReduceConfig(const TensorIterator& iter) {
  // Start by assuming that each thread handles a single output and all
  // the inputs for that output.
  int64_t num_outputs = iter.num_output_elements();
  int64_t inputs_per_output = iter.numel() / num_outputs;
  int input_index = iter.ntensors() - 1;

  auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);

  int64_t dim0;
  int64_t dim1;
  int64_t fastest_moving_stride;
  bool reduction_on_fastest_striding_dimension;

  if (iter.ndim() > 0) {
    // Adjust the block size to map block width to the fastest changing
    // dimension of the input tensor. This gives the best possible memory
    // access pattern, given that for a non-contiguous tensor with gaps in
    // between we cannot achieve perfect memory coalescing.
    reduction_on_fastest_striding_dimension =
        (iter.num_reduce_dims() == iter.ndim()) ||
        (iter.strides(/*arg=*/input_index)[0] <
         iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
    // Note that dim0 & dim1 do NOT determine the launch configuration here!
    // dim0 & dim1 are upper bounds on the block dimensions. The actual launch
    // config and reduction scheme are determined by the values assigned to
    // `config.input_mult` and `config.output_mult` below.
    // We try to max out dim1 so that we have enough threads per CTA to deliver
    // performance for larger problem sizes.
    if (reduction_on_fastest_striding_dimension) {
      // Map block.x to the fastest reducing dimension. This implies:
      // 1. block_x_reduce is required.
      // 2. block.y now maxes out to num_outputs.
      dim0 = inputs_per_output;
      dim1 = num_outputs;
      fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
    } else {
      // Map block.x to the fastest non-reducing dimension. This implies:
      // 1. block_x_reduce is turned off.
      // 2. block.y now maxes out to inputs_per_output.
      dim0 = num_outputs;
      dim1 = inputs_per_output;
      fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
    }
  } else {
    reduction_on_fastest_striding_dimension = true;
    fastest_moving_stride = sizeof(scalar_t);
    dim0 = 1;
    dim1 = 1;
  }
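
  // For example (illustration only): summing a contiguous [1024, 4096] float
  // tensor over its last dimension reduces along the fastest-moving dimension,
  // so dim0 = inputs_per_output = 4096 and dim1 = num_outputs = 1024; summing
  // the same tensor over dim 0 instead gives dim0 = num_outputs = 4096 and
  // dim1 = inputs_per_output = 1024, with block_x_reduce turned off.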

  // We use vectorization to get better memory access; there are two cases,
  // which we call "vectorize along input" and "vectorize along output". Note
  // that "input/output" here does not mean we vectorize both load and store
  // instructions; we only ever vectorize load instructions.
  //
  // Case 1: "vectorize along input"
  // This case happens when we are reducing along the fastest moving dimension.
  // In this case, threads with the same threadIdx.y work cooperatively on the
  // same reduction and produce results for the same output, so the values in
  // each loaded vector always correspond to the same output.
  //
  // Case 2: "vectorize along output"
  // This case happens when the fastest moving dimension is not the dimension
  // being reduced. In this case, threads with different threadIdx.x are
  // independent and produce results for different outputs, so the values in
  // each loaded vector always correspond to different outputs.
  if (fastest_moving_stride == sizeof(scalar_t)) {
    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= ReduceConfig::input_vec_size) {
      // Case 1: "vectorize along input"
      // Note that if vt0 < ReduceConfig::input_vec_size, the register pressure
      // could be high; in that case we should avoid vectorization.
      config.vectorize_input = true;
      dim0 /= config.input_vec_size;
    } else if (!reduction_on_fastest_striding_dimension) {
      // Case 2: "vectorize along output"
      config.output_vec_size = get_output_vec_size<scalar_t>(iter);
      dim0 /= config.output_vec_size;
    }
  }
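
  // For example (illustration only): the row-wise sum of the contiguous
  // [1024, 4096] float tensor above hits case 1 (vectorize_input = true and
  // dim0 becomes 4096 / 4 = 1024), while the column-wise sum hits case 2 and
  // asks get_output_vec_size for the widest legal output_vec_size.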

  // Adjust block_width and block_height
  config.set_block_dimension<scalar_t>(dim0, dim1);

  int block_width = config.block_width;
  int block_height = config.block_height;

  if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) {
    // Split the input across lanes if the input is contiguous in the reduced
    // dimension. This will require reduction between threads using warp
    // shuffle instructions and shared memory (if block_width > warpSize).
    config.input_mult[0] = config.split_input(block_width);
  } else {
    // Otherwise split the output across lanes in a warp.
    config.output_mult[0] = config.split_output(block_width);
  }

  constexpr int min_values_per_thread = 16;
  constexpr int max_values_per_thread = 256;

  if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) {
    // Divide the input across warps in a thread-block, if that leaves at least
    // 16 elements to be summed by each thread. This will require inter-warp
    // reduction using shared memory.
    config.input_mult[1] = config.split_input(block_height);
  } else {
    // Otherwise, each warp handles a separate output.
    config.output_mult[1] = config.split_output(block_height);
  }

  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / config.num_threads;
  const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  const int target_grid_size = num_mp * blocks_per_sm;
  int grid = config.grid().x;
  if (config.input_mult[1] != 0 && config.values_per_thread() >= max_values_per_thread && grid <= target_grid_size) {
    // Divide the input across thread-blocks if the amount of work per thread
    // is large enough and the size of the output is small enough. This will
    // require a reduction using global memory.
    // If we decide to split the input across blocks, then as long as we can
    // get enough blocks (`target_grid_size`) to keep the SMs balanced, we
    // should still keep the number of values per thread large for best
    // performance.
    int ctas_per_output1 = div_up(target_grid_size, grid);
    int ctas_per_output2 = div_up(config.values_per_thread(), min_values_per_thread);
    int ctas_per_output3 = div_up(config.values_per_thread(), max_values_per_thread);
    // We want the minimum of ctas_per_output1 and ctas_per_output2, so that
    // each thread still has a large number of values to deal with. But we
    // don't want values_per_thread to be larger than max_values_per_thread.
    config.ctas_per_output = std::max(std::min<int>(ctas_per_output1, ctas_per_output2), ctas_per_output3);
    if (config.ctas_per_output > 1) {
      config.input_mult[2] = config.split_input(config.ctas_per_output);
    }
  }
  return config;
}

template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
                              AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);

  using traits = function_traits<decltype(&ops_t::reduce)>;
  using arg_t = typename traits::template arg<0>::type;
  // at::Half/at::ComplexHalf overflows easily as its range is very small.
  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
  // set can_accumulate_in_output to false.
  static constexpr bool is_inp_out_type_half_or_chalf =
      (std::is_same<at::Half, scalar_t>::value &&
       std::is_same<at::Half, out_scalar_t>::value) ||
      (std::is_same<c10::complex<Half>, scalar_t>::value &&
       std::is_same<c10::complex<Half>, out_scalar_t>::value);
  // at::BFloat16 has lower precision and can lead to rounding errors.
  // So when scalar_t and out_scalar_t are at::BFloat16, we
  // set can_accumulate_in_output to false.
  static constexpr bool is_inp_out_type_bfloat16 =
      (std::is_same<at::BFloat16, scalar_t>::value &&
       std::is_same<at::BFloat16, out_scalar_t>::value);
  static constexpr bool can_accumulate_in_output =
      std::is_convertible<arg_t, out_scalar_t>::value &&
      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);

  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
  // acc_buf_ptr is shared across recursive calls: it is created on the first
  // entry and reused by all recursive sub-iterator calls.
  if (acc_buf_ptr == NULL) {
    // acc_buf_ptr holds the buffer used for accumulation among multiple
    // sub_iters when accumulation in the output is not possible.
    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
      int64_t output_memory_size = iter.element_size(0);
      for (int dim = 0; dim < iter.ndim(); dim++) {
        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
      }
      output_memory_size /= iter.element_size(0); // iter.strides is in bytes
      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
                                                 sizeof(out_scalar_t),
                                                 (char*) iter.data_ptr(0),
                                                 output_memory_size * sizeof(arg_t)));
    } else {
      owned_buf_ptr.reset(new AccumulationBuffer());
    }
    acc_buf_ptr = owned_buf_ptr.get();
  }

  if (!can_use_32bit_indexing) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];

      gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
                                                     acc_buf_ptr, sub_iter_base_idx);
    }
    return;
  }

  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
  char* out_data = (char*)iter.data_ptr(0);
  const auto noutputs = iter.noutputs();
  std::optional<char*> out_data_extra;
  if (noutputs > 1) {
    out_data_extra = (char*)iter.data_ptr(1);
  } else {
    out_data_extra = std::nullopt;
  }
  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);

  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
  at::DataPtr buffer;
  at::DataPtr semaphores;
  if (config.should_global_reduce()) {
    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
    buffer = allocator.allocate(config.global_memory_size());
    semaphores = allocator.allocate(config.semaphore_size());

    auto stream = at::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
  }

  AT_ASSERT(can_use_32bit_indexing);
  auto output_calc = make_output_calculator<uint32_t>(iter);
  auto input_calc = make_input_calculator<uint32_t>(iter);
  auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t, vt0>(
      ops,
      config,
      input_calc,
      output_calc,
      in_data,
      out_data,
      out_data_extra,
      acc_data,
      buffer.get(),
      (int*)semaphores.get(),
      ident,
      noutputs,
      base_idx);
  reduce.accumulate = iter.should_accumulate();
  reduce.final_output = iter.is_final_output();

  launch_reduce_kernel<mnt_wrapper<scalar_t>::MAX_NUM_THREADS>(config, reduce);
}
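
// Illustrative usage sketch only (assumes a TensorIterator `iter` already set
// up for a float sum reduction; not a definitive recipe):
//
//   gpu_reduce_kernel<float, float>(
//       iter,
//       func_wrapper<float>([] GPU_LAMBDA (float a, float b) -> float { return a + b; }),
//       /*ident=*/0);
//
// gpu_reduce_kernel handles 64-bit-indexed inputs by recursing over 32-bit
// sub-iterators and, when accumulation in the output is not possible, by
// staging partial results in the shared AccumulationBuffer.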

// TODO: this is ~100 lines of near copy-paste, because we have to use
// different template args for this function; try unifying it with gpu_reduce_kernel
template <char const* name, typename scalar_t, typename out_scalar_t, int vt0=4, typename ident_t=double>
inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0,
                                     AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);

  // TODO: this will be different for more complicated reductions, but for now
  // reductions using func_wrapper all have arg_t = opmath
  using arg_t = at::opmath_type<scalar_t>;
  // at::Half/at::ComplexHalf overflows easily as its range is very small.
  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
  // set can_accumulate_in_output to false.
  static constexpr bool is_inp_out_type_half_or_chalf =
      (std::is_same<at::Half, scalar_t>::value &&
       std::is_same<at::Half, out_scalar_t>::value) ||
      (std::is_same<c10::complex<Half>, scalar_t>::value &&
       std::is_same<c10::complex<Half>, out_scalar_t>::value);
  // at::BFloat16 has lower precision and can lead to rounding errors.
  // So when scalar_t and out_scalar_t are at::BFloat16, we
  // set can_accumulate_in_output to false.
  static constexpr bool is_inp_out_type_bfloat16 =
      (std::is_same<at::BFloat16, scalar_t>::value &&
       std::is_same<at::BFloat16, out_scalar_t>::value);
  static constexpr bool can_accumulate_in_output =
      std::is_convertible<arg_t, out_scalar_t>::value &&
      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);

  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;

  // acc_buf_ptr is shared across recursive calls: it is created on the first
  // entry and reused by all recursive sub-iterator calls.
  if (acc_buf_ptr == NULL) {
    // acc_buf_ptr holds the buffer used for accumulation among multiple
    // sub_iters when accumulation in the output is not possible.
    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
      int64_t output_memory_size = iter.element_size(0);
      for (int dim = 0; dim < iter.ndim(); dim++) {
        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
      }
      output_memory_size /= iter.element_size(0); // iter.strides is in bytes
      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), // TODO
                                                 sizeof(out_scalar_t),
                                                 (char*) iter.data_ptr(0),
                                                 output_memory_size * sizeof(out_scalar_t))); // TODO
    } else {
      owned_buf_ptr.reset(new AccumulationBuffer());
    }
    acc_buf_ptr = owned_buf_ptr.get();
  }

  if (!can_use_32bit_indexing) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];

      jitted_gpu_reduce_kernel<name, scalar_t, out_scalar_t, vt0>(sub_iter, func, ident,
                                                                  acc_buf_ptr, sub_iter_base_idx);
    }
    return;
  }

  // TODO: for now we support a single input; we may be able to relax this constraint
  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
  char* out_data = (char*)iter.data_ptr(0);
  const auto noutputs = iter.noutputs();
  std::optional<char*> out_data_extra;
  if (noutputs > 1) {
    out_data_extra = (char*)iter.data_ptr(1);
  } else {
    out_data_extra = std::nullopt;
  }
  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);

  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);

  at::DataPtr buffer;
  at::DataPtr semaphores;
  if (config.should_global_reduce()) {
    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
    buffer = allocator.allocate(config.global_memory_size());
    semaphores = allocator.allocate(config.semaphore_size());

    auto stream = at::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
  }

  AT_ASSERT(can_use_32bit_indexing);
  auto output_calc = make_output_calculator<uint32_t>(iter);
  auto input_calc = make_input_calculator<uint32_t>(iter);
  auto reduce = ReduceJitOp<scalar_t, out_scalar_t>(
      config,
      input_calc,
      output_calc,
      in_data,
      out_data,
      out_data_extra,
      acc_data,
      buffer.get(),
      (int*)semaphores.get(),
      ident,
      noutputs,
      base_idx);
  reduce.accumulate = iter.should_accumulate();
  reduce.final_output = iter.is_final_output();

  constexpr int nInputs = 1;
  constexpr int nOutputs = 1;
  static auto desc = at::cuda::jit::make_kernel_descriptor<
    out_scalar_t, scalar_t>(name, func, nInputs, nOutputs);

  static std::mutex jiterator_mutex;
  static std::vector<std::array<at::cuda::jit::NvrtcFunction, 3>> fn_cache(c10::cuda::device_count());
  auto &cache = fn_cache[iter.device().index()];

  launch_jitted_reduce_kernel(
      jiterator_mutex, cache, desc, vt0, config, &reduce);
}

}} // namespace at::native