// aten/src/ATen/native/cuda/ReduceNormKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/LinearAlgebra.h>
#include <c10/core/Scalar.h>

namespace at::native {

// This reduction accumulates results as the type `acc_t`. By default, when
// `scalar_t` is complex, `acc_t` is the corresponding real value type.
// Otherwise, `acc_t` and `scalar_t` are the same type.
template <typename scalar_t,
          typename acc_t = typename scalar_value_type<scalar_t>::type,
          typename out_t = typename scalar_value_type<scalar_t>::type>
void norm_kernel_cuda_impl(TensorIterator& iter, double p) {
  if (p == static_cast<double>(0)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormZeroOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(1)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormOneOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(2)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormTwoOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(INFINITY)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMaxOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(-INFINITY)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMinOps<scalar_t, acc_t, out_t>(), std::numeric_limits<acc_t>::infinity());
  } else {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormOps<scalar_t, acc_t, out_t>{acc_t(p)}, 0);
  }
}
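
// Illustration (an editorial sketch, not part of the original file): with the
// template defaults above, complex inputs accumulate in the corresponding
// real type, e.g. scalar_value_type<c10::complex<float>>::type is float,
// while for real inputs scalar_value_type<float>::type is float itself. The
// branches above map p onto the usual special cases of the vector p-norm
//
//   ||x||_p = (sum_i |x_i|^p)^(1/p)
//
// with p == 0 counting nonzero elements, p == INFINITY reducing with
// max|x_i| (identity 0, since absolute values are nonnegative), and
// p == -INFINITY reducing with min|x_i|, which is why that branch passes
// +infinity rather than 0 as the identity element.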
void norm_launch_kernel(TensorIterator& iter, double ord) {
  if (iter.dtype(0) == kHalf) {
    return norm_kernel_cuda_impl<at::Half, float>(iter, ord);
  } else if (iter.input_dtype() == kHalf && iter.dtype(0) == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return norm_kernel_cuda_impl<at::Half, float, float>(iter, ord);
  } else if (iter.dtype(0) == kBFloat16) {
    return norm_kernel_cuda_impl<at::BFloat16, float>(iter, ord);
  } else if (iter.input_dtype() == kBFloat16 && iter.dtype(0) == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, ord);
  }
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] {
    norm_kernel_cuda_impl<scalar_t>(iter, ord);
  });
}
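
// Minimal usage sketch (hypothetical; the reduction TensorIterator is built
// by this file's callers, not here): a kHalf input reduced into a kFloat
// output takes the fused branch above, so the cast to float and the norm
// accumulation happen in a single kernel launch instead of two:
//
//   TensorIterator iter = /* reduction iter: input kHalf, output kFloat */;
//   norm_launch_kernel(iter, /*ord=*/2.0);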

} // namespace at::native