// aten/src/ATen/native/cpu/PowKernel.cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <cmath>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/Pow.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cpu/Loops.h>

#include <c10/core/Scalar.h>

namespace at::native {

inline namespace CPU_CAPABILITY {

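// Element-wise pow with both base and exponent supplied as tensors.
// Floating-point and complex dtypes take the vectorized std::pow / Vec::pow
// path; integral dtypes use native::powi.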
static void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  if (isFloatingType(dtype) || isComplexType(dtype)) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, dtype, "pow", [&]() {
      using Vec = Vectorized<scalar_t>;
      cpu_kernel_vec(iter,
        [=](scalar_t base, scalar_t exp) -> scalar_t {
          return std::pow(base, exp);
        },
        [&](Vec base, Vec exp) -> Vec {
          return base.pow(exp);
        }
      );
    });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "pow", [&]() {
      cpu_kernel(iter,
        [=](scalar_t base, scalar_t exp) -> scalar_t {
          return native::powi(base, exp);
        }
      );
    });
  }
}
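
// Note on the integral branch above: native::powi (from ATen/native/Pow.h)
// computes integer powers by repeated squaring. A hedged sketch of its
// semantics as implemented there (illustrative, not a spec):
//   powi(int64_t{3}, int64_t{4});   // 81
//   powi(int64_t{2}, int64_t{-2});  // 0: negative exponents truncate to 0
//                                   // unless the base is 1 or -1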

// The kernels for float, double and complex types are nearly identical, with
// one distinction: even when the output dtype is float, a double exponent may
// be used. Complex types, however, cannot mix precisions, because std::pow
// accepts either complex64 operands or complex128 operands, but not a mix of
// the two. To provide a common path for float, double and complex types, the
// template parameter cast_scalar_t resolves this distinction. This approach
// also lets BFloat16 use the common path. Half cannot use it yet, since AVX2
// support for sqrt and rsqrt does not currently exist for it.
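//
// A minimal sketch of that distinction (illustrative only, not part of the
// kernel): with scalar_t = float and cast_scalar_t = double, the exponent
// keeps its double precision,
//   float base = 2.0f;
//   double exp = 0.5;
//   auto r = std::pow(base, exp);  // OK: base is promoted to double
// whereas with scalar_t = c10::complex<float>, cast_scalar_t = scalar_t
// narrows the exponent to the matching complex type first:
//   c10::complex<float> cbase{2.0f, 0.0f};
//   auto cr = std::pow(cbase, static_cast<c10::complex<float>>(exp));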
template <typename scalar_t, typename cast_scalar_t, typename exp_scalar_t>
void pow_tensor_scalar_optimized_kernel(TensorIteratorBase& iter, const exp_scalar_t exp) {
  using Vec = Vectorized<scalar_t>;
  // .5 (sqrt), -.5 (rsqrt) and -1 (reciprocal) specializations are handled
  // in pow_tensor_scalar_kernel
  if (exp == 2.0) {
    cpu_kernel_vec(iter,
        [](scalar_t base) -> scalar_t {
          return base * base;
        },
        [](Vec base) -> Vec { return base * base; }
    );
  } else if (exp == 3.0) {
    cpu_kernel_vec(iter,
        [](scalar_t base) -> scalar_t {
          return base * base * base;
        },
        [](Vec base) -> Vec { return base * base * base; }
    );
  } else if (exp == -2.0) {
    cpu_kernel_vec(iter,
        [](scalar_t base) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
          return static_cast<cast_scalar_t>(1.0) / (base * base);
        },
        [](Vec base) -> Vec { return (base * base).reciprocal(); }
    );
  } else {
    cpu_kernel_vec(iter,
        [=](scalar_t base) -> scalar_t {
          return std::pow(base, static_cast<cast_scalar_t>(exp));
        },
        [=](Vec base) -> Vec {
          return base.pow(static_cast<cast_scalar_t>(exp));
        }
    );
  }
}
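
// A hedged sketch of how the branches above are reached (hypothetical
// invocation on the double path, matching the dispatch in
// pow_tensor_scalar_kernel below):
//   pow_tensor_scalar_optimized_kernel<double, double>(iter, 3.0);
// selects the exp == 3.0 branch and computes base * base * base per element;
// exponents without a specialized branch fall through to std::pow / Vec::pow.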

static void pow_tensor_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& exp_scalar) {
  // Prevent multiple calls to iter.common_dtype()
  const auto dtype = iter.common_dtype();

  if (dtype == ScalarType::Float || dtype == ScalarType::Double ||
      dtype == kBFloat16 || isComplexType(dtype)) {
    // Dispatch to fast specialization for sqrt, rsqrt and reciprocal
    if (exp_scalar.equal(0.5)) {
      return sqrt_kernel(iter);
    } else if (exp_scalar.equal(-0.5)) {
      return rsqrt_kernel(iter);
    } else if (exp_scalar.equal(-1.0)) {
      return reciprocal_kernel(iter);
    }
  }

  if (dtype == ScalarType::Float || dtype == ScalarType::Double) {
    AT_DISPATCH_FLOATING_TYPES(dtype, "pow", [&]() {
      pow_tensor_scalar_optimized_kernel<scalar_t, double>(
          iter, exp_scalar.to<double>());
    });
  } else if (isComplexType(dtype)) {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "pow", [&]() {
      pow_tensor_scalar_optimized_kernel<scalar_t, scalar_t>(
          iter, exp_scalar.to<c10::complex<double>>());
    });
  } else if (dtype == ScalarType::Half) {
    [&]() {
      using scalar_t =
          decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
      const auto exp = exp_scalar.to<scalar_t>();
      using Vec = Vectorized<scalar_t>;
      cpu_kernel_vec(iter,
          [=](scalar_t base) -> scalar_t {
            return std::pow(base, exp);
          },
          [=](Vec base) -> Vec { return base.pow(exp); }
      );
    }();
  } else if (dtype == ScalarType::BFloat16) {
    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, dtype, "pow", [&]() {
      pow_tensor_scalar_optimized_kernel<scalar_t, scalar_t>(
          iter, exp_scalar.to<scalar_t>());
    });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "pow", [&]() {
      const scalar_t exp = exp_scalar.to<scalar_t>();
      cpu_kernel(iter, [=](scalar_t base) -> scalar_t {
        return native::powi(base, exp);
      });
    });
  }
}
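
// Routing sketch for the scalar-exponent path above (float input): an
// exponent of 0.5 is forwarded to sqrt_kernel, -0.5 to rsqrt_kernel, and
// -1.0 to reciprocal_kernel; 2.0 reaches pow_tensor_scalar_optimized_kernel's
// base * base branch; anything else ends in std::pow.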

} // namespace CPU_CAPABILITY

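// ALSO_REGISTER_AVX512_DISPATCH opts these kernels in to AVX512 registration
// in addition to the capabilities covered by plain REGISTER_DISPATCH; AVX512
// CPU kernels are registered on an opt-in basis in ATen's DispatchStub.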
ALSO_REGISTER_AVX512_DISPATCH(pow_tensor_tensor_stub, &CPU_CAPABILITY::pow_tensor_tensor_kernel);
ALSO_REGISTER_AVX512_DISPATCH(pow_tensor_scalar_stub, &CPU_CAPABILITY::pow_tensor_scalar_kernel);

} // namespace at::native