xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Reduce.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Array.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <ATen/cuda/DeviceUtils.cuh>
6 #include <ATen/cuda/detail/OffsetCalculator.cuh>
7 #include <ATen/detail/FunctionTraits.h>
8 #include <ATen/native/TensorIterator.h>
9 #include <ATen/native/cuda/thread_constants.h>
10 #include <ATen/native/cuda/MemoryAccess.cuh>
11 #include <ATen/OpMathType.h>
12 #include <c10/macros/Macros.h>
13 #include <c10/cuda/CUDACachingAllocator.h>
14 #include <functional>
15 #include <iosfwd>
16 #include <type_traits>
17 #include <utility>
18 #include <thrust/pair.h>
19 
20 #include <ATen/native/cuda/jit_utils.h>
21 
22 namespace at { namespace native {
23 
24 using at::detail::Array;
25 
26 static inline int64_t div_up(int64_t a, int64_t b) {
27   return (a + b - 1) / b;
28 }
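// A small worked example of the ceiling division above: div_up(10, 4) == 3 and div_up(8, 4) == 2.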
29 
30 // returns 2^floor(log2(n)), i.e. the largest power of 2 not greater than n
31 static inline int last_pow2(int n) {
32   n |= (n >>  1);
33   n |= (n >>  2);
34   n |= (n >>  4);
35   n |= (n >>  8);
36   n |= (n >> 16);
37   return std::max(1, n - (n >> 1));
38 }
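// A minimal sketch illustrating the helper above; the function name below is hypothetical and
// the block is deliberately left un-compiled (it would need <cassert> if enabled).
#if 0
static inline void last_pow2_examples() {
  assert(last_pow2(1) == 1);      // 2^0
  assert(last_pow2(5) == 4);      // 0b101 -> 0b100
  assert(last_pow2(8) == 8);      // exact powers of two map to themselves
  assert(last_pow2(1023) == 512); // 2^floor(log2(1023))
}
#endif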
39 
40 // reduces the fraction numerator/denominator to lowest terms, in place
41 C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
42   // get GCD of num and denom using Euclid's algorithm.
43   // Can replace this with std::gcd if we ever support c++17.
44   size_t a = denominator;
45   size_t b = numerator;
46   while (b != 0) {
47       a %= b;
48       // swap(a,b)
49       size_t tmp = a;
50       a = b;
51       b = tmp;
52   }
53 
54   // a is now the GCD
55   numerator /= a;
56   denominator /= a;
57 }
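// A minimal sketch of how reduce_fraction is used further below: the accumulator/output
// element-size ratio is reduced to lowest terms so that byte offsets into the output can be
// rescaled into the accumulator buffer. The function name and sizes here are assumptions.
#if 0
static inline void reduce_fraction_example() {
  size_t numerator = 4;   // e.g. a 4-byte float accumulator
  size_t denominator = 2; // e.g. a 2-byte Half output
  reduce_fraction(numerator, denominator);
  // numerator == 2, denominator == 1, so acc_offset = out_offset * 2 / 1
}
#endif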
58 
59 // template for selecting MAX_NUM_THREADS based on the op dtype
60 template <typename T>
61 struct mnt_wrapper {
62   static constexpr int MAX_NUM_THREADS = 512;
63 };
64 
65 template <>
66 struct mnt_wrapper <c10::complex<double>>{
67   static constexpr int MAX_NUM_THREADS = 256;
68 };
69 
70 constexpr int max_reduce_threads(c10::ScalarType type) {
71   return type == kComplexDouble ? 256 : 512;
72 }
73 
74 struct ReduceConfig {
75   static constexpr int BLOCK_X = 0;
76   static constexpr int BLOCK_Y = 1;
77   static constexpr int CTA = 2;
78 
79   static constexpr int input_vec_size = 4;
80 
81   ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
82     : element_size_bytes(element_size_bytes)
83     , num_inputs(num_inputs)
84     , num_outputs(num_outputs) {}
85   int element_size_bytes;
86   int num_inputs;
87   int num_outputs;
88   int step_input = 1;
89   int step_output = 1;
90   int ctas_per_output = 1;
91   int input_mult[3] = {0, 0, 0};
92   int output_mult[2] = {0, 0};
93 
94   int block_width;
95   int block_height;
96   int num_threads;
97 
98   bool vectorize_input = false;
99   int output_vec_size = 1;
100 
101   template <typename T>
102   void set_block_dimension(int64_t dim0, int64_t dim1) {
103     const int max_num_threads = mnt_wrapper<T>::MAX_NUM_THREADS / output_vec_size;
104     int dim0_pow2 = dim0 < max_num_threads ? static_cast<int>(last_pow2(dim0)) : max_num_threads;
105     int dim1_pow2 = dim1 < max_num_threads ? static_cast<int>(last_pow2(dim1)) : max_num_threads;
106     block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
107     block_height = std::min(dim1_pow2, int(max_num_threads / block_width));
108     block_width = std::min(dim0_pow2, int(max_num_threads / block_height));
109     num_threads = block_width * block_height;
110   }
111 
112   int split_input(int parallelism) {
113     int step = step_input;
114     step_input *= parallelism;
115     return step;
116   }
117 
118   int split_output(int parallelism) {
119     int step = step_output;
120     step_output *= parallelism;
121     return step;
122   }
123 
124   dim3 block() const {
125     return dim3(block_width, block_height);
126   }
127 
128   dim3 grid() const {
129     return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output);
130   }
131 
132   C10_HOST_DEVICE bool should_block_x_reduce() const {
133     return input_mult[BLOCK_X] != 0;
134   }
135 
136   C10_HOST_DEVICE bool should_block_y_reduce() const {
137     return input_mult[BLOCK_Y] != 0;
138   }
139 
140   C10_HOST_DEVICE bool should_global_reduce() const {
141     return input_mult[CTA] != 0;
142   }
143 
144   C10_DEVICE bool should_store(int output_idx) const {
145     return output_idx < num_outputs &&
146       (!should_block_x_reduce() || threadIdx.x == 0) &&
147       (!should_block_y_reduce() || threadIdx.y == 0);
148   }
149 
150   C10_DEVICE bool should_reduce_tail() const {
151     return (!should_block_y_reduce() || threadIdx.y == 0) &&
152       (!should_global_reduce() || blockIdx.y == 0);
153   }
154 
155   C10_HOST_DEVICE int input_idx() const {
156     int lane = threadIdx.x;
157     int warp = threadIdx.y;
158     int cta2 = blockIdx.y;
159     return (lane * input_mult[BLOCK_X] +
160             warp * input_mult[BLOCK_Y] +
161             cta2 * input_mult[CTA]);
162   }
163 
164   template <int output_vec_size>
165   C10_HOST_DEVICE int output_idx() const {
166     int lane = threadIdx.x;
167     int warp = threadIdx.y;
168     int cta1 = blockIdx.x;
169     return (lane * output_mult[BLOCK_X] +
170             warp * output_mult[BLOCK_Y] +
171             cta1 * step_output) * output_vec_size;
172   }
173 
174   C10_DEVICE int shared_memory_offset(int offset) const {
175     return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
176   }
177 
178   C10_DEVICE int staging_memory_offset(int cta2) const {
179     int offset = cta2 + blockIdx.x * gridDim.y;
180     if (!should_block_x_reduce()) {
181       offset = threadIdx.x + offset * blockDim.x;
182     }
183     return offset;
184   }
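  // A worked example of the staging layout above (all values hypothetical): with
  // gridDim.y == 8 CTAs per output, blockIdx.x == 3 and cta2 == 2, offset == 2 + 3 * 8 == 26;
  // if block-x reduction is disabled, each lane then gets its own slot at
  // threadIdx.x + 26 * blockDim.x.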
185 
186   int shared_memory_size() const {
187     if (!should_block_y_reduce() &&
188         (!should_block_x_reduce() ||
189          block_width <= at::cuda::warp_size())) {
190       return 0;
191     }
192     return element_size_bytes * num_threads * output_vec_size;
193   }
194 
195   int64_t global_memory_size() const {
196     if (!should_global_reduce()) {
197       return 0;
198     }
199     auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
200     if (!should_block_x_reduce()) {
201       size *= block().x * output_vec_size;
202     }
203     return size;
204   }
205 
206   int semaphore_size() const {
207     if (!should_global_reduce()) {
208       return 0;
209     }
210     return sizeof(int) * grid().x;
211   }
212 
213   int values_per_thread() const {
214     return div_up(num_inputs, step_input);
215   }
216 };
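// A minimal sketch of how the multipliers above are meant to be read; setReduceConfig()
// further below is the real heuristic that fills them in, and the function name and numbers
// here are hypothetical.
#if 0
static inline void reduce_config_sketch() {
  // 4-byte accumulator, 1024 outputs, 4096 inputs per output.
  ReduceConfig config(/*element_size_bytes=*/4, /*num_outputs=*/1024, /*num_inputs=*/4096);
  config.set_block_dimension<float>(/*dim0=*/4096, /*dim1=*/1024);
  // Splitting the input over block.x makes threadIdx.x walk the reduced dimension
  // (input_mult[BLOCK_X] != 0, so should_block_x_reduce() becomes true) ...
  config.input_mult[ReduceConfig::BLOCK_X] = config.split_input(config.block_width);
  // ... while splitting the output over block.y gives each threadIdx.y a different output.
  config.output_mult[ReduceConfig::BLOCK_Y] = config.split_output(config.block_height);
  // input_idx()/output_idx() then combine these multipliers with the thread/block indices.
}
#endif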
217 
218 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
219 
220 template<int nt, int output_vec_size, typename R>
221 C10_LAUNCH_BOUNDS_2(nt, 4)
222 __global__ void reduce_kernel(R reduction) {
223   reduction.template run<output_vec_size>();
224 }
225 
226 template <typename index_t>
227 static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
228   int num_reduce_dims = iter.num_reduce_dims();
229   int num_output_dims = iter.ndim() - num_reduce_dims;
230   int input_index = iter.ntensors() - 1;
231   int output_index = 0;
232   std::array<const int64_t*, 2> strides = {
233     iter.strides(output_index).data() + num_reduce_dims,
234     iter.strides(input_index).data() + num_reduce_dims,
235   };
236   auto shape = iter.shape().data() + num_reduce_dims;
237   return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data());
238 }
239 
240 template <typename index_t>
241 static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) {
242   int num_reduce_dims = iter.num_reduce_dims();
243   int input_index = iter.ntensors() - 1;
244   std::array<const int64_t*, 1> strides = {
245     iter.strides(input_index).data(),
246   };
247   return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data());
248 }
249 
250 template <typename out_scalar_t, typename func_t>
251 struct func_wrapper_t {
252   using arg_t = typename binary_function_traits<func_t>::arg1_t;
253   using scalar_t = typename binary_function_traits<func_t>::arg2_t;
254 
255   func_t combine;
256   static inline __device__ out_scalar_t project(arg_t arg) {
257     return (out_scalar_t) arg;
258   }
259   static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
260     return WARP_SHFL_DOWN(arg, offset);
261   }
262 
263   static __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
264     return acc;
265   }
266 
267   func_wrapper_t(const func_t& op) : combine(op) {
268   }
269 
270   // wrap a normal reduction that ignores the index
271   __device__ arg_t reduce(arg_t acc, scalar_t val, int64_t idx) const {
272     return combine(acc, val);
273   }
274 };
275 
276 template <typename scalar_t, typename func_t>
277 func_wrapper_t<scalar_t, func_t> func_wrapper(const func_t& op) {
278   return func_wrapper_t<scalar_t, func_t> { op };
279 }
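// A minimal sketch of wrapping a plain binary combine functor into the ops interface that
// ReduceOp expects; the functor and function names below are hypothetical, and real callers
// typically pass a device lambda of the same shape.
#if 0
template <typename acc_t>
struct example_max_combine {
  __device__ acc_t operator()(acc_t a, acc_t b) const { return a > b ? a : b; }
};

template <typename out_t, typename acc_t>
void func_wrapper_sketch() {
  // reduce/combine/project/warp_shfl_down are all derived from the wrapped functor.
  auto ops = func_wrapper<out_t>(example_max_combine<acc_t>{});
  (void)ops; // would be handed to gpu_reduce_kernel<scalar_t, out_t>(iter, ops, ...)
}
#endif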
280 
281 template <typename scalar_t, typename out_scalar_t=scalar_t>
282 struct ReduceJitOp {
283 // ReduceJitOp is almost like ReduceOp, but it doesn't have the ops functor that specifies the reduction operations.
284 // Maybe we can find a way to unify ReduceOp and ReduceJitOp.
285   using InputCalculator = OffsetCalculator<1, uint32_t>;
286   using OutputCalculator = OffsetCalculator<2, uint32_t>;
287   //TODO for now arg_t is always opmath_t of the input, later we'll need to change it
288   using arg_t = at::opmath_type<scalar_t>;
289 
290   static constexpr int input_vec_size = ReduceConfig::input_vec_size;
291   //TODO - ReduceJitOp will probably need to be changed for reductions that need full functor,
292   //not just wrapper
293   arg_t ident;
294   ReduceConfig config;
295   InputCalculator input_calc;
296   OutputCalculator output_calc;
297   const void* src;
298   const char* dst[2]; //it accepts at most two destinations
299   // acc_buf is used for accumulation across sub tensor iterators when accumulating
300   // directly in the output is not permissible
301   void* acc_buf;
302   // cta_buf is used for accumulation between blocks during the global reduction
303   void* cta_buf;
304   int* semaphores;
305   int64_t base_idx;
306   bool accumulate;
307   bool final_output;
308   int noutputs;
309 
310   ReduceJitOp(
311       ReduceConfig config,
312       InputCalculator input_calc,
313       OutputCalculator output_calc,
314       const void* src,
315       char* dst0,
316       std::optional<char*> dst1,
317       void* acc_buf,
318       void* cta_buf,
319       int* semaphores,
320       arg_t ident,
321       int noutputs,
322       int64_t base_idx)
323       : ident(ident),
324         config(config),
325         input_calc(input_calc),
326         output_calc(output_calc),
327         src(src),
328         acc_buf(acc_buf),
329         cta_buf(cta_buf),
330         semaphores(semaphores),
331         base_idx(base_idx),
332         noutputs(noutputs) {
333     dst[0] = dst0;
334     if (dst1.has_value()) {
335       dst[1] = dst1.value();
336     }
337   }
338 };
339 
340 template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t, int vt0=4>
341 struct ReduceOp {
342   using traits = function_traits<decltype(&ops_t::reduce)>;
343   using arg_t = typename std::decay<typename traits::template arg<0>::type>::type;
344 
345   using InputCalculator = OffsetCalculator<1, index_t>;
346   using OutputCalculator = OffsetCalculator<2, index_t>;
347 
348   static constexpr bool can_accumulate_in_output =
349     std::is_convertible<arg_t, out_scalar_t>::value
350     && std::is_convertible<out_scalar_t, arg_t>::value;
351 
352   static constexpr int input_vec_size = ReduceConfig::input_vec_size;
353 
354   ops_t ops;
355   arg_t ident;
356   ReduceConfig config;
357   InputCalculator input_calc;
358   OutputCalculator output_calc;
359   const void* src;
360   const char* dst[2]; //it accepts at most two destinations
361   // acc_buf is used for accumulation across sub tensor iterators when accumulating
362   // directly in the output is not permissible
363   void* acc_buf;
364   // cta_buf is used for accumulation between blocks during the global reduction
365   void* cta_buf;
366   int* semaphores;
367   int64_t base_idx;
368   bool accumulate;
369   bool final_output;
370   int noutputs;
371 
372   ReduceOp(
373       ops_t ops,
374       ReduceConfig config,
375       InputCalculator input_calc,
376       OutputCalculator output_calc,
377       const void* src,
378       char* dst0,
379       std::optional<char*> dst1,
380       void* acc_buf,
381       void* cta_buf,
382       int* semaphores,
383       arg_t ident,
384       int noutputs,
385       int64_t base_idx)
386       : ops(ops),
387         ident(ident),
388         config(config),
389         input_calc(input_calc),
390         output_calc(output_calc),
391         src(src),
392         acc_buf(acc_buf),
393         cta_buf(cta_buf),
394         semaphores(semaphores),
395         base_idx(base_idx),
396         noutputs(noutputs) {
397     dst[0] = dst0;
398     if (dst1.has_value()) {
399       dst[1] = dst1.value();
400     }
401   }
402 
403   template <int output_vec_size>
404   C10_DEVICE void run() const {
405     extern __shared__ char shared_memory[];
406     index_t output_idx = config.output_idx<output_vec_size>();
407     index_t input_idx = config.input_idx();
408     auto base_offsets1 = output_calc.get(output_idx)[1];
409 
410     using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
411     arg_vec_t value;
412 
413     if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
414       const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
415       value = thread_reduce<output_vec_size>(input_slice);
416     }
417 
418     if (config.should_block_y_reduce()) {
419       value = block_y_reduce<output_vec_size>(value, shared_memory);
420     }
421     if (config.should_block_x_reduce()) {
422       value = block_x_reduce<output_vec_size>(value, shared_memory);
423     }
424 
425     using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
426     using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
427     offset_vec_t base_offsets;
428     out_ptr_vec_t out;
429 
430     #pragma unroll
431     for (int i = 0; i < output_vec_size; i++) {
432       base_offsets[i] = output_calc.get(output_idx + i)[0];
433       out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
434     }
435 
436     arg_vec_t* acc = nullptr;
437     if (acc_buf != nullptr) {
438       size_t numerator = sizeof(arg_t);
439       size_t denominator = sizeof(out_scalar_t);
440       reduce_fraction(numerator, denominator);
441       acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
442     }
443 
444     if (config.should_global_reduce()) {
445       value = global_reduce<output_vec_size>(value, acc, shared_memory);
446     } else if (config.should_store(output_idx)) {
447       if (accumulate) {
448         #pragma unroll
449         for (int i = 0; i < output_vec_size; i++) {
450           value[i] = ops.translate_idx(value[i], base_idx);
451         }
452       }
453 
454       if (acc == nullptr) {
455         if (accumulate) {
456           value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
457         }
458         if (final_output) {
459           set_results_to_output<output_vec_size>(value, base_offsets);
460         } else {
461           #pragma unroll
462           for (int i = 0; i < output_vec_size; i++) {
463             *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
464           }
465         }
466       } else {
467         if (accumulate) {
468           #pragma unroll
469           for (int i = 0; i < output_vec_size; i++) {
470             value[i] = ops.combine((*acc)[i], value[i]);
471           }
472         }
473         if (final_output) {
474           set_results_to_output<output_vec_size>(value, base_offsets);
475         } else {
476           *acc = value;
477         }
478       }
479     }
480   }
481 
482   template <int output_vec_size>
483   C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
484     if (config.vectorize_input) {
485       CUDA_KERNEL_ASSERT(output_vec_size == 1);
486       // reduce at the head of input_slice where memory is not aligned,
487       // so that thread_reduce will have aligned memory to work on.
488       return {input_vectorized_thread_reduce_impl(data)};
489     } else {
490       index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
491       bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
492       if (is_contiguous) {
493         return thread_reduce_impl<output_vec_size>(data, [](index_t idx) { return idx; });
494       } else if (input_calc.dims == 1) {
495         return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return idx * element_stride; });
496       } else {
497         return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
498       }
499     }
500   }
501 
502   C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
503     index_t end = config.num_inputs;
504 
505     // Handle the head of the input slice where the data is not aligned
506     arg_t value = ident;
507     constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, input_vec_size>);
508     constexpr int align_elements = align_bytes / sizeof(scalar_t);
509     int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t);
510     if (shift > 0) {
511       data -= shift;
512       end += shift;
513       if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
514         value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift);
515       }
516       end -= align_elements;
517       data += align_elements;
518       shift = align_elements - shift;
519     }
520 
521     // Do the vectorized reduction
522     using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;
523 
524     index_t idx = config.input_idx();
525     const index_t stride = config.step_input;
526 
527     // Multiple accumulators to remove dependency between unrolled loops.
528     arg_t value_list[input_vec_size];
529     value_list[0] = value;
530 
531     #pragma unroll
532     for (int i = 1; i < input_vec_size; i++) {
533       value_list[i] = ident;
534     }
535 
536     while (idx * input_vec_size + input_vec_size - 1 < end) {
537       const auto values_vec = memory::load_vector<input_vec_size>(data, idx);
538       #pragma unroll
539       for (index_t i = 0; i < input_vec_size; i++) {
540         value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i);
541       }
542       idx += stride;
543     }
544 
545     // tail
546     index_t tail_start = end - end % input_vec_size;
547     if (config.should_reduce_tail()) {
548       int idx = tail_start + threadIdx.x;
549       if (idx < end) {
550         const auto value = c10::load(data + idx);
551         value_list[0] = ops.reduce(value_list[0], value, idx + shift);
552       }
553     }
554 
555     // combine accumulators
556     #pragma unroll
557     for (int i = 1; i < input_vec_size; i++) {
558       value_list[0] = ops.combine(value_list[0], value_list[i]);
559     }
560     return value_list[0];
561   }
562 
563   template <int output_vec_size, typename offset_calc_t>
564   C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
565     index_t idx = config.input_idx();
566     const index_t end = config.num_inputs;
567     const index_t stride = config.step_input;
568 
569     using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
570     using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;
571 
572     // Multiple accumulators to remove dependency between unrolled loops.
573     arg_vec_t value_list[vt0];
574 
575     #pragma unroll
576     for (int i = 0; i < vt0; i++) {
577       #pragma unroll
578       for (int j = 0; j < output_vec_size; j++) {
579         value_list[i][j] = ident;
580       }
581     }
582 
583     load_t values[vt0];
584 
585     while (idx + (vt0 - 1) * stride < end) {
586       #pragma unroll
587       for (index_t i = 0; i < vt0; i++) {
588         const auto offset = calc(idx + i * stride) / output_vec_size;
589         values[i] = memory::load_vector<output_vec_size>(data_, offset);
590       }
591       #pragma unroll
592       for (index_t i = 0; i < vt0; i++) {
593         #pragma unroll
594         for (index_t j = 0; j < output_vec_size; j++) {
595           value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride);
596         }
597       }
598       idx += stride * vt0;
599     }
600 
601     // tail
602     int idx_ = idx;
603     #pragma unroll
604     for (index_t i = 0; i < vt0; i++) {
605       if (idx >= end) {
606         break;
607       }
608       const auto offset = calc(idx) / output_vec_size;
609       values[i] = memory::load_vector<output_vec_size>(data_, offset);
610       idx += stride;
611     }
612     idx = idx_;
613     #pragma unroll
614     for (index_t i = 0; i < vt0; i++) {
615       if (idx >= end) {
616         break;
617       }
618       #pragma unroll
619       for (index_t j = 0; j < output_vec_size; j++) {
620         value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx);
621       }
622       idx += stride;
623     }
624 
625     // combine accumulators
626     #pragma unroll
627     for (int i = 1; i < vt0; i++) {
628       #pragma unroll
629       for (index_t j = 0; j < output_vec_size; j++) {
630         value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]);
631       }
632     }
633     return value_list[0];
634   }
635 
636   template <int output_vec_size>
637   C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_x_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
638     using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
639     int dim_x = blockDim.x;
640     args_vec_t* shared = (args_vec_t*)shared_memory;
641     if (dim_x > warpSize) {
642       int address_base = threadIdx.x + threadIdx.y*blockDim.x;
643       shared[address_base] = value;
644       for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
645         __syncthreads();
646         if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
647           args_vec_t other = shared[address_base + offset];
648           #pragma unroll
649           for (int i = 0; i < output_vec_size; i++) {
650             value[i] = ops.combine(value[i], other[i]);
651           }
652           shared[address_base] = value;
653         }
654       }
655       dim_x = warpSize;
656     }
657 
658     __syncthreads();
659 
660     for (int offset = 1; offset < dim_x; offset <<= 1) {
661       #pragma unroll
662       for (int i = 0; i < output_vec_size; i++) {
663         arg_t other = ops.warp_shfl_down(value[i], offset);
664         value[i] = ops.combine(value[i], other);
665       }
666     }
667     return value;
668   }
669 
670   template <int output_vec_size>
671   C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_y_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
672     using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
673     args_vec_t* shared = (args_vec_t*)shared_memory;
674     shared[config.shared_memory_offset(0)] = value;
675     for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
676       __syncthreads();
677       if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
678         args_vec_t other = shared[config.shared_memory_offset(offset)];
679         #pragma unroll
680         for (int i = 0; i < output_vec_size; i++) {
681           value[i] = ops.combine(value[i], other[i]);
682         }
683         shared[config.shared_memory_offset(0)] = value;
684       }
685     }
686     return value;
687   }
688 
689   C10_DEVICE bool mark_block_finished() const {
690     __shared__ bool is_last_block_done_shared;
691 
692     __syncthreads();
693     if (threadIdx.x == 0 && threadIdx.y == 0) {
694       int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
695       is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
696     }
697 
698     __syncthreads();
699 
700     return is_last_block_done_shared;
701   }
702 
703   template <int output_vec_size, bool can_acc>
704   C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
705     at::detail::Array<out_scalar_t*, output_vec_size> out,
706     at::detail::Array<arg_t, output_vec_size> value,
707     typename std::enable_if<can_acc>::type* = nullptr
708   ) const {
709     at::detail::Array<arg_t, output_vec_size> ret;
710     #pragma unroll
711     for (int i = 0; i < output_vec_size; i++) {
712       ret[i] = ops.combine(*(out[i]), value[i]);
713     }
714     return ret;
715   }
716 
717   template <bool can_acc>
718   C10_DEVICE out_scalar_t get_accumulated_output(
719     out_scalar_t* out, arg_t value,
720     typename std::enable_if<can_acc>::type* = nullptr
721   ) const {
722     CUDA_KERNEL_ASSERT(!final_output);
723     return (out_scalar_t)value;
724   }
725 
726   // This function should never be called --
727   // it's the version of `accumulate_in_output`
728   // when accumulation in the output is not possible.
729   template <int output_vec_size, bool can_acc>
730   C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
731     at::detail::Array<out_scalar_t*, output_vec_size>,
732     at::detail::Array<arg_t, output_vec_size>,
733     typename std::enable_if<!can_acc>::type* = nullptr
734   ) const {
735     CUDA_KERNEL_ASSERT(false);
736     return arg_t {};
737   }
738 
739   // This function should never be called --
740   // it's the version of `get_accumulated_output`
741   // when accumulation in the output is not possible.
742   template <bool can_acc>
743   C10_DEVICE out_scalar_t get_accumulated_output(
744     out_scalar_t* out, arg_t value,
745     typename std::enable_if<!can_acc>::type* = nullptr
746   ) const {
747     CUDA_KERNEL_ASSERT(false);
748     return *out;
749   }
750 
751   template<class T>
752   C10_DEVICE void set_results(const T x, const index_t base_offset) const {
753     CUDA_KERNEL_ASSERT(noutputs == 1);
754     auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
755     *res = x;
756   }
757 
758   // Currently implemented for at most two outputs
759   template<class T1, class T2>
760   C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
761     if (noutputs >= 1) {
762       auto res0 = (T1*)((char*)dst[0] + base_offset);
763       *res0 = x.first;
764     }
765     if (noutputs >= 2) {
766       // base_offset is computed assuming an element size of sizeof(T1), so we need to
767       // rescale it to obtain the correct byte offset for the second output
768       auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
769       *res1 = x.second;
770     }
771   }
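  // A worked example of the offset correction above (types and values are hypothetical):
  // with T1 = float (4 bytes) and T2 = int64_t (8 bytes), the element at index 10 has
  // base_offset = 10 * sizeof(T1) = 40 bytes in dst[0]; its slot in dst[1] is at
  // 40 / sizeof(T1) * sizeof(T2) = 10 * 8 = 80 bytes.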
772 
773   template <int output_vec_size>
774   C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
775     CUDA_KERNEL_ASSERT(final_output);
776     #pragma unroll
777     for (int i = 0; i < output_vec_size; i++) {
778       set_results(ops.project(value[i]), base_offset[i]);
779     }
780   }
781 
782   template <int output_vec_size>
783   C10_DEVICE at::detail::Array<arg_t, output_vec_size> global_reduce(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
784     using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
785     using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
786     using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
787 
788     arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
789     index_t output_idx = config.output_idx<output_vec_size>();
790     offset_vec_t base_offsets;
791     out_ptr_vec_t out;
792 
793     #pragma unroll
794     for (int i = 0; i < output_vec_size; i++) {
795       base_offsets[i] = output_calc.get(output_idx + i)[0];
796       out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
797     }
798 
799     bool should_store = config.should_store(output_idx);
800     if (should_store) {
801       index_t offset = config.staging_memory_offset(blockIdx.y);
802       reduce_buffer[offset] = value;
803     }
804 
805     __threadfence(); // make sure writes are globally visible
806     __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
807     bool is_last_block_done = mark_block_finished();
808 
809     if (is_last_block_done) {
810       __threadfence(); // complete the acquire pattern after atomic
811       value = ident;
812       if (config.should_block_x_reduce()) {
813         index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
814         index_t step = blockDim.x * blockDim.y;
815         for (; input_offset < config.ctas_per_output; input_offset += step) {
816           index_t idx = config.staging_memory_offset(input_offset);
817           arg_vec_t next = reduce_buffer[idx];
818           #pragma unroll
819           for (int i = 0; i < output_vec_size; i++) {
820             value[i] = ops.combine(value[i], next[i]);
821           }
822         }
823       } else {
824         index_t input_offset = threadIdx.y;
825         index_t step = blockDim.y;
826         for (; input_offset < config.ctas_per_output; input_offset += step) {
827           index_t idx = config.staging_memory_offset(input_offset);
828           arg_vec_t next = reduce_buffer[idx];
829           #pragma unroll
830           for (int i = 0; i < output_vec_size; i++) {
831             value[i] = ops.combine(value[i], next[i]);
832           }
833         }
834       }
835       value = block_y_reduce(value, shared_memory);
836       if (config.should_block_x_reduce()) {
837         value = block_x_reduce<output_vec_size>(value, shared_memory);
838       }
839       if (should_store) {
840         if (accumulate) {
841           #pragma unroll
842           for (int i = 0; i < output_vec_size; i++) {
843             value[i] = ops.translate_idx(value[i], base_idx);
844           }
845         }
846 
847         if (acc == nullptr) {
848           if (accumulate) {
849             value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
850           }
851           if (final_output) {
852             set_results_to_output<output_vec_size>(value, base_offsets);
853           } else {
854             #pragma unroll
855             for (int i = 0; i < output_vec_size; i++) {
856               *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
857             }
858           }
859         } else {
860           if (accumulate) {
861             #pragma unroll
862             for (int i = 0; i < output_vec_size; i++) {
863               value[i] = ops.combine((*acc)[i], value[i]);
864             }
865           }
866           if (final_output) {
867             set_results_to_output<output_vec_size>(value, base_offsets);
868           } else {
869             *acc = value;
870           }
871         }
872       }
873     }
874 
875     return value;
876   }
877 };
878 
879 template<int max_threads, typename R>
880 static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
881   dim3 block = config.block();
882   dim3 grid = config.grid();
883 
884   auto stream = at::cuda::getCurrentCUDAStream();
885   int shared_memory = config.shared_memory_size();
886 
887   switch(config.output_vec_size) {
888   case 4:
889     reduce_kernel<max_threads / 4, 4, R><<<grid, block, shared_memory, stream>>>(reduction);
890     C10_CUDA_KERNEL_LAUNCH_CHECK();
891     break;
892   case 2:
893     reduce_kernel<max_threads / 2, 2, R><<<grid, block, shared_memory, stream>>>(reduction);
894     C10_CUDA_KERNEL_LAUNCH_CHECK();
895     break;
896   default:
897     reduce_kernel<max_threads / 1, 1, R><<<grid, block, shared_memory, stream>>>(reduction);
898     C10_CUDA_KERNEL_LAUNCH_CHECK();
899   }
900 }
901 
902 inline void launch_jitted_reduce_kernel(
903     std::mutex &jiterator_mutex,
904     std::array<at::cuda::jit::NvrtcFunction, 3> &fn_cache,
905     const at::cuda::jit::KernelDescriptor &desc,
906     int vt0, const ReduceConfig& config, void *reduction) {
907   dim3 block = config.block();
908   dim3 grid = config.grid();
909 
910   int shared_memory = config.shared_memory_size();
911   at::cuda::jit::NvrtcFunction* fn_ptr;
912   switch(config.output_vec_size) {
913   case 4:
914     fn_ptr = &fn_cache[0];
915     break;
916   case 2:
917     fn_ptr = &fn_cache[1];
918     break;
919   default:
920     fn_ptr = &fn_cache[2];
921   }
922   if (!fn_ptr->function) {
923     int max_threads_codegen =
924         max_reduce_threads(desc.f_inputs_type) / config.output_vec_size;
925     auto code = at::cuda::jit::generate_reduction_code(
926         desc, vt0, true, false, config.output_vec_size, max_threads_codegen);
927 
928     *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_" + desc.name);
929   }
930   constexpr int kernel_args = 1;
931   void* args[kernel_args];
932   args[0] = reduction;
933   at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory);
934 }
935 
936 
937 class AccumulationBuffer {
938  public:
939   AccumulationBuffer() {}
940 
941   AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) {
942     out_ptr_ = (char*)out_ptr;
943     if (out_t_size >= acc_t_size) {
944       // reusing output buffer for accumulation.
945       acc_ptr_ = (char*)out_ptr;
946       numerator_ = 1;
947       denominator_ = 1;
948     } else {
949       auto& allocator = *c10::cuda::CUDACachingAllocator::get();
950       buffer_ = allocator.allocate(size);
951       acc_ptr_ = (char*)buffer_.get();
952       numerator_ = acc_t_size;
953       denominator_ = out_t_size;
954       reduce_fraction(numerator_, denominator_);
955     }
956   }
957 
958   char* get_acc_slice(char* out_ptr) {
959     if (acc_ptr_ == nullptr) {
960       return nullptr;
961     }
962     return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_);
963   }
964 
965  private:
966   char* acc_ptr_ = nullptr;
967   char* out_ptr_ = nullptr;
968   size_t numerator_;
969   size_t denominator_;
970   at::DataPtr buffer_;
971 };
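// A minimal sketch of how AccumulationBuffer maps output offsets into the accumulator buffer;
// the function name, types and sizes below are assumptions.
#if 0
inline void accumulation_buffer_sketch(char* out_ptr, int64_t num_out_elems) {
  // Accumulate in float (4 bytes) for a Half output (2 bytes).
  AccumulationBuffer acc(/*acc_t_size=*/4, /*out_t_size=*/2, out_ptr,
                         /*size=*/num_out_elems * 4);
  // An output element 100 * 2 bytes past out_ptr maps to 100 * 4 bytes into the accumulator:
  // (out - out_ptr_) * numerator_ / denominator_ with the size ratio reduced to 2/1.
  char* acc_slice = acc.get_acc_slice(out_ptr + 100 * 2);
  (void)acc_slice;
}
#endif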
972 
973 template <typename scalar_t>
974 int get_output_vec_size(const TensorIterator &iter) {
975   int vec_size = 4;
976   auto update_vec_size = [&vec_size](uint64_t n) {
977     while(n % vec_size != 0) {
978       vec_size /= 2;
979     }
980   };
981 
982   uint64_t base_address = reinterpret_cast<uint64_t>(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t);
983   update_vec_size(base_address);
984 
985   const int output_index = iter.num_reduce_dims();
986   update_vec_size(iter.shape()[output_index]);
987 
988   int j = 0;
989   for(auto i : iter.strides(iter.noutputs())) {
990     if (j != output_index) {
991       update_vec_size(i / sizeof(scalar_t));
992     }
993     j++;
994   }
995   return vec_size;
996 }
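// A worked example for the helper above (shape and strides are hypothetical): for a float
// output whose base address is 16-byte aligned, whose fastest output dimension has extent 10,
// and whose remaining strides are all even multiples of sizeof(float), the loop halves
// vec_size from 4 to 2 (10 % 4 != 0 but 10 % 2 == 0), so two outputs are loaded per vector.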
997 
998 template<typename arg_t, typename scalar_t, int vt0>
999 ReduceConfig setReduceConfig(const TensorIterator& iter){
1000   // Start by assuming that each thread handles a single output and all
1001   // the inputs for that output.
1002   int64_t num_outputs = iter.num_output_elements();
1003   int64_t inputs_per_output = iter.numel() / num_outputs;
1004   int input_index = iter.ntensors() - 1;
1005 
1006   auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);
1007 
1008   int64_t dim0;
1009   int64_t dim1;
1010   int64_t fastest_moving_stride;
1011   bool reduction_on_fastest_striding_dimension;
1012 
1013   if (iter.ndim() > 0) {
1014     // Adjust the block size to map the block width onto the fastest changing dimension of
1015     // the input tensor. This gives the best possible memory access pattern, given that
1016     // for a non-contiguous tensor with gaps between elements we cannot achieve perfect
1017     // memory coalescing.
1018     reduction_on_fastest_striding_dimension =
1019         (iter.num_reduce_dims() == iter.ndim()) ||
1020         (iter.strides(/*arg=*/input_index)[0] <
1021         iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
1022     // Notice that dim0 & dim1 do NOT guarantee any launch configuration here!
1023     // dim0 & dim1 are more like upper bounds on the block dimensions. The
1024     // actual launch config and reduction scheme are determined by the values
1025     // assigned to `config.input_mult` and `config.output_mult`.
1026     // We try to max out dim1 so that we have enough threads per CTA to deliver
1027     // good performance for larger problem sizes.
1028     if (reduction_on_fastest_striding_dimension) {
1029       // Map block.x to the fastest reducing dimension. It implies:
1030       //   1. block_x_reduce is required.
1031       //   2. block.y now max out to num_outputs.
1032       dim0 = inputs_per_output;
1033       dim1 = num_outputs;
1034       fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
1035     } else {
1036       // Map block.x to the fastest non reducing dimension. It implies:
1037       //   1. block_x_reduce is turned off.
1038       //   2. block.y now max out to inputs_per_output.
1039       dim0 = num_outputs;
1040       dim1 = inputs_per_output;
1041       fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
1042     }
1043   } else {
1044     reduction_on_fastest_striding_dimension = true;
1045     fastest_moving_stride = sizeof(scalar_t);
1046     dim0 = 1;
1047     dim1 = 1;
1048   }
1049 
1050   // We use vectorization to get better memory access; there are two cases, which we call
1051   // "vectorize along input" and "vectorize along output". Note that "input/output"
1052   // here does not mean we vectorize both load and store instructions: we only ever
1053   // vectorize load instructions.
1054   //
1055   // Case 1: "vectorize along input"
1056   // This case happens when we are reducing along the fastest moving dimension. Threads
1057   // with the same threadIdx.y work on the same reduction cooperatively and produce results
1058   // for the same output, so the values in each loaded vector always correspond to the same output.
1059   //
1060   // Case 2: "vectorize along output"
1061   // This case happens when the fastest moving dimension is not the reduction dimension.
1062   // Threads with different threadIdx.x are independent and produce results for different outputs,
1063   // so the values in each loaded vector always correspond to different outputs.
1064   if (fastest_moving_stride == sizeof(scalar_t)) {
1065     if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= ReduceConfig::input_vec_size) {
1066       // Case 1: "vectorize along input"
1067       // Note that if vt0 < ReduceConfig::input_vec_size, the register pressure could be high;
1068       // in that case we should avoid vectorization.
1069       config.vectorize_input = true;
1070       dim0 /= config.input_vec_size;
1071     } else if (!reduction_on_fastest_striding_dimension) {
1072       // Case 2: "vectorize along output"
1073       config.output_vec_size = get_output_vec_size<scalar_t>(iter);
1074       dim0 /= config.output_vec_size;
1075     }
1076   }
1077 
1078   // Adjust block_width and block_height
1079   config.set_block_dimension<scalar_t>(dim0, dim1);
1080 
1081   int block_width = config.block_width;
1082   int block_height = config.block_height;
1083 
1084   if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) {
1085     // Split the input across lanes if the input is contiguous in the reduced
1086     // dimension. This will require reduction between threads using warp
1087     // shuffle instructions and shared memory (if block_width > warpSize).
1088     config.input_mult[0] = config.split_input(block_width);
1089   } else {
1090     // Otherwise split the output across lanes in a warp.
1091     config.output_mult[0] = config.split_output(block_width);
1092   }
1093 
1094   constexpr int min_values_per_thread = 16;
1095   constexpr int max_values_per_thread = 256;
1096 
1097   if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) {
1098     // Divide the input across warps in a thread-block, if that leaves at least
1099     // 16 elements to be summed by each thread. This will require inter-warp
1100     // reduction using shared memory.
1101     config.input_mult[1] = config.split_input(block_height);
1102   } else {
1103     // Otherwise, each warp handles a separate output.
1104     config.output_mult[1] = config.split_output(block_height);
1105   }
1106 
1107   const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / config.num_threads;
1108   const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
1109   const int target_grid_size = num_mp * blocks_per_sm;
1110   int grid = config.grid().x;
1111   if (config.input_mult[1] != 0 && config.values_per_thread() >= max_values_per_thread && grid <= target_grid_size) {
1112     // Divide the input across thread-blocks if the amount of work per-thread
1113     // is large enough and the size of the output is small enough. This will
1114     // require a reduction using global memory.
1115     // If we decide to split the input across blocks, then as long as we can launch enough
1116     // blocks (`target_grid_size`) to keep the SMs balanced, we should still
1117     // keep the number of values per thread large for best performance.
1118     int ctas_per_output1 = div_up(target_grid_size, grid);
1119     int ctas_per_output2 = div_up(config.values_per_thread(), min_values_per_thread);
1120     int ctas_per_output3 = div_up(config.values_per_thread(), max_values_per_thread);
1121     // We want the minimum of ctas_per_output1 and ctas_per_output2, so that each thread can have
1122     // a large number of values to deal with. But we don't want values_per_thread to be larger than
1123     // max_values_per_thread
1124     config.ctas_per_output = std::max(std::min<int>(ctas_per_output1, ctas_per_output2), ctas_per_output3);
1125     if (config.ctas_per_output > 1) {
1126       config.input_mult[2] = config.split_input(config.ctas_per_output);
1127     }
1128   }
1129   return config;
1130 };
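// A rough walk-through of the heuristic above under assumed values (warp size 32,
// MAX_NUM_THREADS 512 for float, default vt0 = 4): reducing a contiguous [1024, 32768] float
// tensor over its last dimension gives num_outputs = 1024 and inputs_per_output = 32768.
//   - The reduction runs along the fastest striding dimension, so dim0 = 32768 and dim1 = 1024;
//     since dim0 > 128, num_reduce_dims == 1 and vt0 >= input_vec_size, vectorize_input is set
//     and dim0 becomes 32768 / 4 = 8192.
//   - set_block_dimension picks a 32 x 16 block, and input_mult[0] = split_input(32) makes the
//     lanes of a warp cooperate on a single output.
//   - values_per_thread() is then div_up(32768, 32) = 1024, which is >= block_height * 16, so
//     input_mult[1] = split_input(16) also spreads the reduction over threadIdx.y, leaving 64
//     values per thread.
//   - A grid-level split (input_mult[2], i.e. a global reduction) would only kick in if
//     values_per_thread() were still >= max_values_per_thread and the grid were small enough,
//     which is not the case here.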
1131 
1132 template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
1133 inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
1134                               AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
1135   AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
1136 
1137   using traits = function_traits<decltype(&ops_t::reduce)>;
1138   using arg_t = typename traits::template arg<0>::type;
1139   // at::Half/at::ComplexHalf overflow easily as their range is very small.
1140   // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
1141   // set can_accumulate_in_output to False.
1142   static constexpr bool is_inp_out_type_half_or_chalf =
1143       (std::is_same<at::Half, scalar_t>::value &&
1144        std::is_same<at::Half, out_scalar_t>::value) ||
1145       (std::is_same<c10::complex<Half>, scalar_t>::value &&
1146        std::is_same<c10::complex<Half>, out_scalar_t>::value);
1147   // at::BFloat16 has lower precision and can lead to rounding errors.
1148   // So when scalar_t and out_scalar_t are at::BFloat16, we
1149   // set can_accumulate_in_output to False.
1150   static constexpr bool is_inp_out_type_bfloat16 =
1151       (std::is_same<at::BFloat16, scalar_t>::value &&
1152        std::is_same<at::BFloat16, out_scalar_t>::value);
1153   static constexpr bool can_accumulate_in_output =
1154       std::is_convertible<arg_t, out_scalar_t>::value &&
1155       !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);
1156 
1157   bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
1158   std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
1159   // The acc_buf_ptr is shared between recursive calls. It is created on the first
1160   // call and reused by all recursive function calls.
1161   if (acc_buf_ptr == NULL) {
1162     // acc_buf_ptr holds the buffer used for accumulation across multiple sub_iters
1163     // when accumulation in the output is not possible.
1164     if (!can_accumulate_in_output && !can_use_32bit_indexing) {
1165       int64_t output_memory_size = iter.element_size(0);
1166       for (int dim = 0; dim < iter.ndim(); dim++) {
1167         output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
1168       }
1169       output_memory_size /= iter.element_size(0); //iter.strides is in bytes
1170       owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
1171                                                  sizeof(out_scalar_t),
1172                                                  (char*) iter.data_ptr(0),
1173                                                  output_memory_size * sizeof(arg_t)));
1174     } else {
1175       owned_buf_ptr.reset(new AccumulationBuffer());
1176     }
1177     acc_buf_ptr = owned_buf_ptr.get();
1178   }
1179 
1180   if (!can_use_32bit_indexing) {
1181     for (auto& sub_iter : iter.with_32bit_indexing()) {
1182       int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
1183 
1184       gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
1185           acc_buf_ptr, sub_iter_base_idx);
1186     }
1187     return;
1188   }
1189 
1190   const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
1191   char* out_data = (char*)iter.data_ptr(0);
1192   const auto noutputs = iter.noutputs();
1193   std::optional<char*> out_data_extra;
1194   if (noutputs > 1) {
1195     out_data_extra = (char*)iter.data_ptr(1);
1196   } else {
1197     out_data_extra = std::nullopt;
1198   }
1199   char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
1200 
1201   ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
1202   at::DataPtr buffer;
1203   at::DataPtr semaphores;
1204   if (config.should_global_reduce()) {
1205     auto& allocator = *c10::cuda::CUDACachingAllocator::get();
1206     buffer = allocator.allocate(config.global_memory_size());
1207     semaphores = allocator.allocate(config.semaphore_size());
1208 
1209     auto stream = at::cuda::getCurrentCUDAStream();
1210     AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
1211   }
1212 
1213   AT_ASSERT(can_use_32bit_indexing);
1214   auto output_calc = make_output_calculator<uint32_t>(iter);
1215   auto input_calc = make_input_calculator<uint32_t>(iter);
1216   auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t, vt0>(
1217       ops,
1218       config,
1219       input_calc,
1220       output_calc,
1221       in_data,
1222       out_data,
1223       out_data_extra,
1224       acc_data,
1225       buffer.get(),
1226       (int*)semaphores.get(),
1227       ident,
1228       noutputs,
1229       base_idx);
1230   reduce.accumulate = iter.should_accumulate();
1231   reduce.final_output = iter.is_final_output();
1232 
1233   launch_reduce_kernel<mnt_wrapper<scalar_t>::MAX_NUM_THREADS>(config, reduce);
1234 }
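// A minimal caller sketch, assuming a sum-style reduction; the functor and function names
// below are hypothetical, and real kernels build the TensorIterator and dispatch on dtype
// elsewhere before reaching this point.
#if 0
template <typename acc_t>
struct example_sum_functor {
  __device__ acc_t operator()(acc_t a, acc_t b) const { return a + b; }
};

template <typename scalar_t>
void example_sum_launch(TensorIterator& iter) {
  using acc_t = at::opmath_type<scalar_t>;
  // ident = 0 is the identity of addition; vt0 keeps its default of 4.
  gpu_reduce_kernel<scalar_t, scalar_t>(
      iter, func_wrapper<scalar_t>(example_sum_functor<acc_t>{}), /*ident=*/0);
}
#endif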
1235 
1236 //TODO this is 100 lines of almost-copy-paste, because we have to have different template args for this function
1237 //try unifying with gpu_reduce_kernel
1238 template <char const* name, typename scalar_t, typename out_scalar_t, int vt0=4, typename ident_t=double>
1239 inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0,
1240                               AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
1241   AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
1242 
1243   //TODO - this will be different for more complicated reductions, but for now reductions using
1244   //func_wrapper all have arg_t = opmath
1245   using arg_t = at::opmath_type<scalar_t>;
1246   // at::Half/at::ComplexHalf overflow easily as their range is very small.
1247   // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
1248   // set can_accumulate_in_output to False.
1249   static constexpr bool is_inp_out_type_half_or_chalf =
1250       (std::is_same<at::Half, scalar_t>::value &&
1251        std::is_same<at::Half, out_scalar_t>::value) ||
1252       (std::is_same<c10::complex<Half>, scalar_t>::value &&
1253        std::is_same<c10::complex<Half>, out_scalar_t>::value);
1254   // at::BFloat16 has lower precision and can lead to rounding errors.
1255   // So when scalar_t and out_scalar_t are at::BFloat16, we
1256   // set can_accumulate_in_output to False.
1257   static constexpr bool is_inp_out_type_bfloat16 =
1258       (std::is_same<at::BFloat16, scalar_t>::value &&
1259        std::is_same<at::BFloat16, out_scalar_t>::value);
1260   static constexpr bool can_accumulate_in_output =
1261       std::is_convertible<arg_t, out_scalar_t>::value &&
1262       !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);
1263 
1264   bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
1265   std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
1266 
1267   // The acc_buf_ptr is shared between recursive calls. It is created on the first
1268   // call and reused by all recursive function calls.
1269   if (acc_buf_ptr == NULL) {
1270     // acc_buf_ptr holds the buffer used for accumulation across multiple sub_iters
1271     // when accumulation in the output is not possible.
1272     if (!can_accumulate_in_output && !can_use_32bit_indexing) {
1273       int64_t output_memory_size = iter.element_size(0);
1274       for (int dim = 0; dim < iter.ndim(); dim++) {
1275         output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
1276       }
1277       output_memory_size /= iter.element_size(0); //iter.strides is in bytes
1278       owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), //TODO
1279                                                  sizeof(out_scalar_t),
1280                                                  (char*) iter.data_ptr(0),
1281                                                  output_memory_size * sizeof(out_scalar_t))); //TODO
1282     } else {
1283       owned_buf_ptr.reset(new AccumulationBuffer());
1284     }
1285     acc_buf_ptr = owned_buf_ptr.get();
1286   }
1287 
1288   if (!can_use_32bit_indexing) {
1289     for (auto& sub_iter : iter.with_32bit_indexing()) {
1290       int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
1291 
1292       jitted_gpu_reduce_kernel<name, scalar_t, out_scalar_t, vt0>(sub_iter, func, ident,
1293           acc_buf_ptr, sub_iter_base_idx);
1294     }
1295     return;
1296   }
1297 
1298   //TODO - for now we support a single input, we may be able to relax this constraint
1299   const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
1300   char* out_data = (char*)iter.data_ptr(0);
1301   const auto noutputs = iter.noutputs();
1302   std::optional<char*> out_data_extra;
1303   if (noutputs > 1) {
1304     out_data_extra = (char*)iter.data_ptr(1);
1305   } else {
1306     out_data_extra = std::nullopt;
1307   }
1308   char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
1309 
1310   ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
1311 
1312   at::DataPtr buffer;
1313   at::DataPtr semaphores;
1314   if (config.should_global_reduce()) {
1315     auto& allocator = *c10::cuda::CUDACachingAllocator::get();
1316     buffer = allocator.allocate(config.global_memory_size());
1317     semaphores = allocator.allocate(config.semaphore_size());
1318 
1319     auto stream = at::cuda::getCurrentCUDAStream();
1320     AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
1321   }
1322 
1323   AT_ASSERT(can_use_32bit_indexing);
1324   auto output_calc = make_output_calculator<uint32_t>(iter);
1325   auto input_calc = make_input_calculator<uint32_t>(iter);
1326   auto reduce = ReduceJitOp<scalar_t, out_scalar_t>(
1327       config,
1328       input_calc,
1329       output_calc,
1330       in_data,
1331       out_data,
1332       out_data_extra,
1333       acc_data,
1334       buffer.get(),
1335       (int*)semaphores.get(),
1336       ident,
1337       noutputs,
1338       base_idx);
1339   reduce.accumulate = iter.should_accumulate();
1340   reduce.final_output = iter.is_final_output();
1341 
1342   constexpr int nInputs = 1;
1343   constexpr int nOutputs = 1;
1344   static auto desc = at::cuda::jit::make_kernel_descriptor<
1345     out_scalar_t, scalar_t>(name, func, nInputs, nOutputs);
1346 
1347   static std::mutex jiterator_mutex;
1348   static std::vector<std::array<at::cuda::jit::NvrtcFunction, 3>> fn_cache(c10::cuda::device_count());
1349   auto &cache = fn_cache[iter.device().index()];
1350 
1351   launch_jitted_reduce_kernel(
1352       jiterator_mutex, cache, desc, vt0, config, &reduce);
1353 }
1354 
1355 }} // namespace at::native
1356