#pragma once
// Please note that this file is
// used across both CPU and GPU.

#include <type_traits>
#include <complex>
#include <c10/macros/Macros.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/NumericUtils.h>
#if defined(__CUDACC__)
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#elif defined(__HIPCC__)
#include <ATen/hip/DeviceUtils.cuh>
#include <ATen/native/hip/DeviceSqrt.cuh>
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/pair.h>
#else
#include <cmath>
#define device_sqrt std::sqrt
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename scalar_t>
inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
#else
  scalar_t max = at::_isnan(b) ? b : std::max(a, b);
#endif
  return max;
}
template <typename scalar_t>
inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
#else
  scalar_t min = at::_isnan(b) ? b : std::min(a, b);
#endif
  return min;
}
#define MAX(X, Y) max_propagate_nan(X,Y)
#define MIN(X, Y) min_propagate_nan(X,Y)
#else
#include <ATen/native/cpu/zmath.h>
#define MAX(X, Y) max_impl(X,Y)
#define MIN(X, Y) min_impl(X,Y)
#endif

// ROCm hcc doesn't handle std:: in kernel functions well
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_pow c10::cuda::compat::pow
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_pow c10::hip::compat::pow
#else
#define compat_pow std::pow
#endif

namespace at { namespace native {

namespace detail {

#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
#else
template <typename T1, typename T2> using pair = std::pair<T1, T2>;
#endif

} // namespace detail

template <typename scalar_t, typename index_t>
struct WelfordData {
  scalar_t mean;
  scalar_t m2;
  index_t n;
  scalar_t nf;

  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}

  C10_HOST_DEVICE WelfordData(
      scalar_t mean,
      scalar_t m2,
      index_t n,
      scalar_t nf)
      : mean(mean), m2(m2), n(n), nf(nf) {}
};

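// WelfordOps implements Welford's online algorithm for mean and variance.
// reduce() applies the single-pass update
//     mean' = mean + (x - mean) / n',    m2' = m2 + (x - mean) * (x - mean')
// and combine() merges two partial accumulators with the parallel
// (Chan et al.) formula
//     mean = mean_a + delta * n_b / n,   m2 = m2_a + m2_b + delta^2 * n_a * n_b / n
// where delta = mean_b - mean_a and n = n_a + n_b.
//
// A reduction would drive it roughly like this (illustrative sketch only, not
// an actual call site; the concrete types are chosen by the kernels that
// instantiate these ops):
//
//   WelfordOps<float, float, int32_t, std::pair<float, float>> ops(
//       /*correction=*/1, /*take_sqrt=*/false);
//   WelfordData<float, int32_t> acc;
//   for (int32_t i = 0; i < n; i++) {
//     acc = ops.reduce(acc, x[i], i);
//   }
//   std::pair<float, float> var_and_mean = ops.project(acc);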
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
struct WelfordOps {
  acc_scalar_t correction;
  bool take_sqrt;
 public:
  using acc_t = WelfordData<acc_scalar_t, index_t>;
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    // We accumulate n in index_t to avoid cumulative rounding error, but still
    // need nf for use in combine where int32 may overflow.
    index_t new_n = acc.n + 1;
    acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
    acc_scalar_t delta = data - acc.mean;
    acc_scalar_t new_mean = acc.mean + delta / new_nf;
    acc_scalar_t new_delta = data - new_mean;
    return {
      new_mean,
      acc.m2 + delta * new_delta,
      new_n,
      new_nf,
    };
  }
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    if (a.nf == 0) {
      return b;
    }
    if (b.nf == 0) {
      return a;
    }
    acc_scalar_t delta = b.mean - a.mean;
    acc_scalar_t new_count = a.nf + b.nf;
    acc_scalar_t nb_over_n = b.nf / new_count;
    return {
      a.mean + delta * nb_over_n,
      a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
      // Set acc.n to -1: the combined count may not fit in index_t's range, so
      // n is no longer meaningful after combine(); nf carries the true count.
      -1,
      new_count
    };
  }
  inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
    const auto mean = static_cast<scalar_t>(acc.mean);
    const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
    const auto var = acc.m2 / divisor;
    res_t results(take_sqrt ? device_sqrt(var) : var, mean);
    return results;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.mean, offset)
      , WARP_SHFL_DOWN(acc.m2, offset)
      , WARP_SHFL_DOWN(acc.n, offset)
      , WARP_SHFL_DOWN(acc.nf, offset)
    };
  }
#endif
  C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
    : correction(correction), take_sqrt(take_sqrt) {}
};
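
// Note on WelfordOps::project() above: the divisor is max(nf - correction, 0),
// so correction == 1 gives the Bessel-corrected sample variance and
// correction == 0 the population variance; take_sqrt turns the result into a
// standard deviation. When nf <= correction the division by zero is
// deliberately left unguarded (see __ubsan_ignore_float_divide_by_zero__) and
// yields inf or nan.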

template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
struct MeanOps {
  factor_t factor;

  inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
    return combine(a, static_cast<acc_t>(b));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a * factor;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif

  MeanOps(factor_t factor): factor(factor) {
  }
};
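
// Note: MeanOps above accumulates a plain sum and multiplies by `factor` in
// project(); constructing it with factor = 1 / N therefore makes project()
// return the arithmetic mean of the N reduced elements. (Illustrative note;
// the actual factor is supplied by the kernel that instantiates this op.)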

// This accumulator template is used to calculate the minimum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMinOps {

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MIN(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MIN(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the maximum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMaxOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MAX(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MAX(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the norm of the absolute value
// of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOps {
  acc_t norm_;

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif

  NormOps(acc_t norm_): norm_(norm_) {
  }
};
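
// In other words, NormOps above computes the p-norm (sum_i |x_i|^p)^(1/p) with
// p = norm_. NormZeroOps, NormOneOps and NormTwoOps below are the specialized
// p = 0 (count of nonzeros), p = 1 (sum of absolute values) and p = 2
// (Euclidean norm) variants, which avoid the generic pow calls.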

// This accumulator template is used to calculate the order zero norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormZeroOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

// This accumulator template is used to calculate the order one norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOneOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + static_cast<acc_t>(std::abs(data));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

template<typename acc_t>
struct AbsSwitch {};

template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(data);
}

template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}
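
// AbsSwitch is an empty tag type used only for overload resolution: for real
// inputs abs_if_complex is a plain cast to acc_t, while for std::complex /
// c10::complex inputs it takes the magnitude. This lets NormTwoOps below
// accumulate |x|^2 uniformly for real and complex data.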

// This accumulator template is used to calculate the order two norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormTwoOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
    return acc + data_ * data_;
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return device_sqrt(a);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};

template <typename acc_t, typename data_t>
struct NanSumOps {
  inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
    return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE data_t project(acc_t a) const {
    return data_t{a};
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif
};

namespace detail {

template <typename scalar_t>
struct LessOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else min(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a < b);
  }
};

template <typename scalar_t>
struct GreaterOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else max(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a > b);
  }
};
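
// Both comparators above answer "should (a, idx_a) win over (b, idx_b)?":
// a NaN in `a` always wins so that NaNs propagate, ties on value are broken
// towards the lower index, and otherwise the ordinary < / > comparison
// decides. MinMaxReductionOps below keeps the winning (value, index) pair;
// ArgReductionOps projects out just the index for argmin/argmax.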

template <typename comp_t>
struct MinMaxReductionOps {
  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  using index_t = int64_t;
  using arg_t = detail::pair<scalar_t, index_t>;

  static C10_DEVICE arg_t project(arg_t arg) {
    return arg;
  }

  static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
    return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
  }

  static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
    return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
  }

  static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
    return {a.first, a.second + base_idx};
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
    return arg_t(WARP_SHFL_DOWN(arg.first, offset),
                 WARP_SHFL_DOWN(arg.second, offset));
  }
#endif
};

template <typename comp_t>
struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
  using typename MinMaxReductionOps<comp_t>::scalar_t;
  using typename MinMaxReductionOps<comp_t>::index_t;
  using typename MinMaxReductionOps<comp_t>::arg_t;

  static C10_DEVICE index_t project(arg_t arg) {
    return arg.second;
  }
};

} // namespace detail

template <typename scalar_t>
struct ArgMaxOps :
  public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
};

template <typename scalar_t>
struct ArgMinOps :
  public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
};

template <typename scalar_t>
struct MinOps :
  public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
};

template <typename scalar_t>
struct MaxOps :
  public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
};

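// MinMaxOps tracks a running (min, max) pair over the reduced elements in a
// single pass. NaNs propagate: combine() keeps a NaN operand in preference to
// a finite value on either side.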
template <typename scalar_t, typename acc_scalar_t, typename index_t>
struct MinMaxOps {
  using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    return combine(acc, {data, data});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
    auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;

    return {min_val, max_val};
  }

  inline C10_DEVICE acc_t project(acc_t acc) const {
    return acc;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
    };
  }
#endif
};

}} // namespace at::native

#undef MAX
#undef MIN