#pragma once
// Please note that this file is
// used across both CPU and GPU.

#include <type_traits>
#include <complex>
#include <c10/macros/Macros.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/NumericUtils.h>
#if defined(__CUDACC__)
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#elif defined(__HIPCC__)
#include <ATen/hip/DeviceUtils.cuh>
#include <ATen/native/hip/DeviceSqrt.cuh>
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/pair.h>
#else
#include <cmath>
#define device_sqrt std::sqrt
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename scalar_t>
inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
#else
  scalar_t max = at::_isnan(b) ? b : std::max(a, b);
#endif
  return max;
}
template <typename scalar_t>
inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
#else
  scalar_t min = at::_isnan(b) ? b : std::min(a, b);
#endif
  return min;
}
#define MAX(X, Y) max_propagate_nan(X,Y)
#define MIN(X, Y) min_propagate_nan(X,Y)
#else
#include <ATen/native/cpu/zmath.h>
#define MAX(X, Y) max_impl(X,Y)
#define MIN(X, Y) min_impl(X,Y)
#endif

// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_pow c10::cuda::compat::pow
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_pow c10::hip::compat::pow
#else
#define compat_pow std::pow
#endif

namespace at { namespace native {

namespace detail {

#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
#else
template <typename T1, typename T2> using pair = std::pair<T1, T2>;
#endif

} // namespace detail

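// Each *Ops struct in this file is a reduction functor exposing a common
// interface that the CPU and CUDA/HIP reduction kernels (e.g. Reduce.cuh)
// drive:
//   reduce(acc, data, idx)    folds one input element into an accumulator
//   combine(a, b)             merges two partial accumulators
//   project(acc)              converts the accumulator into the final result
//   translate_idx(acc, base)  offsets any stored indices when a reduction is
//                             split into chunks
//   warp_shfl_down(acc, off)  (CUDA/HIP only) shuffles the accumulator down
//                             within a warp for intra-warp reductions
// Roughly, a serial driver would use an ops instance like this (illustrative
// sketch only; the names below are placeholders, not part of this header):
//   acc_t acc = identity;
//   for (int64_t i = 0; i < n; i++) acc = ops.reduce(acc, in[i], i);
//   acc = ops.combine(acc, ops.translate_idx(other_chunk_acc, chunk_base));
//   out = ops.project(acc);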
77 template <typename scalar_t, typename index_t>
78 struct WelfordData {
79   scalar_t mean;
80   scalar_t m2;
81   index_t n;
82   scalar_t nf;
83 
WelfordDataWelfordData84   C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
85 
WelfordDataWelfordData86   C10_HOST_DEVICE WelfordData(
87       scalar_t mean,
88       scalar_t m2,
89       index_t n,
90       scalar_t nf)
91       : mean(mean), m2(m2), n(n), nf(nf) {}
92 };
93 
94 
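// WelfordOps implements Welford's online algorithm for mean and variance.
// reduce() folds a new sample x into (mean, m2, n) via
//   delta = x - mean;  mean += delta / (n + 1);  m2 += delta * (x - mean);
// and combine() merges two partial results with the parallel update
// (Chan et al.):
//   delta = mean_b - mean_a;  n = n_a + n_b;
//   mean = mean_a + delta * n_b / n;
//   m2   = m2_a + m2_b + delta * delta * n_a * n_b / n;
// project() then returns (var, mean) with var = m2 / (nf - correction)
// (the divisor is clamped at zero; correction = 1 gives the usual
// Bessel-corrected sample variance), optionally square-rooted when
// take_sqrt is set.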
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
struct WelfordOps {
  acc_scalar_t correction;
  bool take_sqrt;
 public:
  using acc_t = WelfordData<acc_scalar_t, index_t>;
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    // We accumulate n in index_t to avoid cumulative rounding error, but still
    // need nf for use in combine where int32 may overflow.
    index_t new_n = acc.n + 1;
    acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
    acc_scalar_t delta = data - acc.mean;
    acc_scalar_t new_mean = acc.mean + delta / new_nf;
    acc_scalar_t new_delta = data - new_mean;
    return {
      new_mean,
      acc.m2 + delta * new_delta,
      new_n,
      new_nf,
    };
  }
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    if (a.nf == 0) {
      return b;
    }
    if (b.nf == 0) {
      return a;
    }
    acc_scalar_t delta = b.mean - a.mean;
    acc_scalar_t new_count = a.nf + b.nf;
    acc_scalar_t nb_over_n = b.nf / new_count;
    return {
      a.mean + delta * nb_over_n,
      a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
      // acc.n is set to -1 since the combined count may not be representable
      // in index_t; nf carries the true count, and -1 marks n as invalid to
      // avoid confusion.
      -1,
      new_count
    };
  }
  inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
    const auto mean = static_cast<scalar_t>(acc.mean);
    const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
    const auto var = acc.m2 / divisor;
    res_t results(take_sqrt ? device_sqrt(var) : var, mean);
    return results;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.mean, offset)
      , WARP_SHFL_DOWN(acc.m2, offset)
      , WARP_SHFL_DOWN(acc.n, offset)
      , WARP_SHFL_DOWN(acc.nf, offset)
    };
  }
#endif
  C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
      : correction(correction), take_sqrt(take_sqrt) {}
};

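// MeanOps computes the mean as a plain sum followed by a multiplication with
// a precomputed factor in project(); the caller typically passes
// factor = 1 / (number of elements being reduced).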
template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
struct MeanOps {
  factor_t factor;

  inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
    return combine(a, static_cast<acc_t>(b));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a * factor;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif

  MeanOps(factor_t factor): factor(factor) {
  }
};

// This accumulator template is used to calculate the minimum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMinOps {

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MIN(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MIN(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the maximum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMaxOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MAX(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MAX(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the norm of the absolute value
// of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
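// reduce() accumulates sum_i |x_i|^p and project() returns its p-th root,
// i.e. the p-norm (sum_i |x_i|^p)^(1/p) with p = norm_.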
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOps {
  acc_t norm_;

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif

  NormOps(acc_t norm_): norm_(norm_) {
  }
};

// This accumulator template is used to calculate the order zero norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
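// The zero-"norm" counts the number of nonzero elements, so reduce() simply
// adds 1 for every nonzero input.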
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormZeroOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }


#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the order one norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOneOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + static_cast<acc_t>(std::abs(data));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};


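// AbsSwitch<acc_t> is a tag type used to pick an abs_if_complex overload for
// the accumulator type acc_t: real inputs are passed through (cast to acc_t),
// while std::complex / c10::complex inputs are mapped to their magnitude.
// NormTwoOps below uses this so that real inputs can be squared directly
// (taking the absolute value first would be redundant there) while complex
// inputs still contribute |x|^2.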
template<typename acc_t>
struct AbsSwitch {};

template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(data);
}

template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

// This accumulator template is used to calculate the order two norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormTwoOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
    return acc + data_ * data_;
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return device_sqrt(a);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

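// NanSumOps sums a set of numbers while treating NaN inputs as zero
// (as used for nansum-style reductions).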
template <typename acc_t, typename data_t>
struct NanSumOps {
  inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
    return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE data_t project(acc_t a) const {
    return data_t{a};
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif
};

namespace detail {

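// LessOrNan / GreaterOrNan are comparators for the (arg)min / (arg)max
// reductions below. They return true when `a` should win: a NaN operand
// always wins (so NaNs propagate), and exact ties are broken in favor of the
// lower index.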
template <typename scalar_t>
struct LessOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else min(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a < b);
  }
};

template <typename scalar_t>
struct GreaterOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else max(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a > b);
  }
};

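// MinMaxReductionOps carries a (value, index) pair through the reduction; the
// comparator comp_t decides which pair survives. project() returns the whole
// pair, while ArgReductionOps below narrows the result to just the index
// (for argmin/argmax).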
template <typename comp_t>
struct MinMaxReductionOps {
  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  using index_t = int64_t;
  using arg_t = detail::pair<scalar_t, index_t>;

  static C10_DEVICE arg_t project(arg_t arg) {
    return arg;
  }

  static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
    return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
  }

  static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
    return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
  }

  static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
    return {a.first, a.second + base_idx};
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
    return arg_t(WARP_SHFL_DOWN(arg.first, offset),
                 WARP_SHFL_DOWN(arg.second, offset));
  }
#endif
};

template <typename comp_t>
struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
  using typename MinMaxReductionOps<comp_t>::scalar_t;
  using typename MinMaxReductionOps<comp_t>::index_t;
  using typename MinMaxReductionOps<comp_t>::arg_t;

  static C10_DEVICE index_t project(arg_t arg) {
    return arg.second;
  }
};

} // namespace detail

template <typename scalar_t>
struct ArgMaxOps :
  public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
};

template <typename scalar_t>
struct ArgMinOps :
  public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
};

template <typename scalar_t>
struct MinOps :
  public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
};

template <typename scalar_t>
struct MaxOps :
  public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
};

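// MinMaxOps computes the minimum and maximum in a single pass, carrying them
// as a (min, max) pair; a NaN input propagates to both components.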
template <typename scalar_t, typename acc_scalar_t, typename index_t>
struct MinMaxOps {
  using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    return combine(acc, {data, data});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
    auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;

    return {min_val, max_val};
  }

  inline C10_DEVICE acc_t project(acc_t acc) const {
    return acc;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
    };
  }
#endif
};

}} // namespace at::native

#undef MAX
#undef MIN