#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/LinearAlgebra.h>
#include <c10/core/Scalar.h>

namespace at::native {

// This reduction accumulates results as the type `acc_t`. By default, when
// `scalar_t` is complex, `acc_t` is the corresponding real number type
// (e.g. `float` for `c10::complex<float>`); otherwise `acc_t` is `scalar_t`.
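// The final result is written out as `out_t`, which likewise defaults to the
// real value type of `scalar_t`, since a norm of a complex tensor is real.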
template <typename scalar_t, typename acc_t=typename scalar_value_type<scalar_t>::type, typename out_t=typename scalar_value_type<scalar_t>::type>
void norm_kernel_cuda_impl(TensorIterator& iter, double p) {
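  // Dispatch on the norm order `p`. The specialized functors (defined in
  // SharedReduceOps.h) avoid the generic pow() path for the common cases:
  // p=0 counts nonzero elements, p=1 sums |x|, p=2 computes sqrt(sum(|x|^2)),
  // and +/-inf take the max/min of |x|. The general case evaluates
  // (sum(|x|^p))^(1/p). The last argument to gpu_reduce_kernel is the
  // identity element the accumulator starts from.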
  if (p == static_cast<double>(0)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormZeroOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(1)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormOneOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(2)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormTwoOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(INFINITY)) {
    gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMaxOps<scalar_t, acc_t, out_t>(), 0);
  } else if (p == static_cast<double>(-INFINITY)) {
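    // A min-reduction must start from the largest representable value, so
    // the identity here is +infinity rather than 0.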
    gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMinOps<scalar_t, acc_t, out_t>(), std::numeric_limits<acc_t>::infinity());
  } else {
    gpu_reduce_kernel<scalar_t, out_t>(iter, NormOps<scalar_t, acc_t, out_t>{acc_t(p)}, 0);
  }
}

void norm_launch_kernel(TensorIterator& iter, double ord) {
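  // Half and BFloat16 inputs accumulate in float: summing many low-precision
  // values directly would lose accuracy and can overflow Half's narrow range.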
  if (iter.dtype(0) == kHalf) {
    return norm_kernel_cuda_impl<at::Half, float>(iter, ord);
  } else if (iter.input_dtype() == kHalf && iter.dtype(0) == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return norm_kernel_cuda_impl<at::Half, float, float>(iter, ord);
  } else if (iter.dtype(0) == kBFloat16) {
    return norm_kernel_cuda_impl<at::BFloat16, float>(iter, ord);
  } else if (iter.input_dtype() == kBFloat16 && iter.dtype(0) == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return norm_kernel_cuda_impl<at::BFloat16, float, float>(iter, ord);
  }
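  // All remaining dtypes (float, double, and their complex counterparts)
  // fall through to the generic dispatch; complex inputs accumulate and
  // produce output in the corresponding real type via the template defaults.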
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] {
    norm_kernel_cuda_impl<scalar_t>(iter, ord);
  });
}

} // namespace at::native