1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <cmath>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/cpu/vec/vec.h>
6 #include <ATen/native/TensorIterator.h>
7 #include <ATen/native/Pow.h>
8 #include <ATen/native/UnaryOps.h>
9 #include <ATen/native/cpu/Loops.h>
10
11 #include <c10/core/Scalar.h>
12
13 namespace at::native {
14
15 inline namespace CPU_CAPABILITY {
16
pow_tensor_tensor_kernel(TensorIteratorBase & iter)17 static void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
18 const auto dtype = iter.common_dtype();
19 if (isFloatingType(dtype) || isComplexType(dtype)) {
20 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, dtype, "pow", [&]() {
21
22 using Vec = Vectorized<scalar_t>;
23 cpu_kernel_vec(iter,
24 [=](scalar_t base, scalar_t exp) -> scalar_t {
25 return std::pow(base, exp);
26 },
27 [&](Vec base, Vec exp) -> Vec {
28 return base.pow(exp);
29 }
30 );
31 });
32 } else {
33 AT_DISPATCH_INTEGRAL_TYPES(dtype, "pow", [&]() {
34 cpu_kernel(iter,
35 [=](scalar_t base, scalar_t exp) -> scalar_t {
36 return native::powi(base, exp);
37 }
38 );
39 });
40 }
41 }
42
// The kernels for float, double and complex types are nearly identical,
// with one small distinction: even if the output dtype is float, a double
// exponent can be used. Complex types, however, cannot mix single and
// double precision, since std::pow takes either complex64 inputs or
// complex128 inputs, but not both. So, in order to provide a common path for
// float, double & complex types, template parameter cast_scalar_t is used
// to resolve the aforementioned distinction. This approach also allows BFloat16
// to use this common path. Half cannot currently use it, as AVX2 support for
// sqrt & rsqrt doesn't currently exist for it.
52 template <typename scalar_t, typename cast_scalar_t, typename exp_scalar_t>
pow_tensor_scalar_optimized_kernel(TensorIteratorBase & iter,const exp_scalar_t exp)53 void pow_tensor_scalar_optimized_kernel(TensorIteratorBase& iter, const exp_scalar_t exp) {
54 using Vec = Vectorized<scalar_t>;
55 // .5 (sqrt), -.5 (rsqrt) and -1 (reciprocal) specializations are handled
56 // in pow_tensor_scalar_kernel
57 if (exp == 2.0) {
58 cpu_kernel_vec(iter,
59 [](scalar_t base) -> scalar_t {
60 return base * base;
61 },
62 [](Vec base) -> Vec { return base * base; }
63 );
64 } else if (exp == 3.0) {
65 cpu_kernel_vec(iter,
66 [](scalar_t base) -> scalar_t {
67 return base * base * base;
68 },
69 [](Vec base) -> Vec { return base * base * base; }
70 );
71 } else if (exp == -2.0) {
72 cpu_kernel_vec(iter,
73 [](scalar_t base) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
74 return static_cast<cast_scalar_t>(1.0) / (base * base); },
75 [](Vec base) -> Vec { return (base * base).reciprocal(); }
76 );
77 } else {
78 cpu_kernel_vec(iter,
79 [=](scalar_t base) -> scalar_t {
80 return std::pow(base, static_cast<cast_scalar_t>(exp));
81 },
82 [=](Vec base) -> Vec {
83 return base.pow(static_cast<cast_scalar_t>(exp));
84 }
85 );
86 }
87 }
88
pow_tensor_scalar_kernel(TensorIteratorBase & iter,const Scalar & exp_scalar)89 static void pow_tensor_scalar_kernel(
90 TensorIteratorBase& iter,
91 const Scalar& exp_scalar) {
92 // prevent multiple calls to iter.common_dtype()
93 const auto dtype = iter.common_dtype();
94
95 if (dtype == ScalarType::Float || dtype == ScalarType::Double ||
96 dtype == kBFloat16 || isComplexType(dtype)) {
97 // Dispatch to fast specialization for sqrt, rsqrt and reciprocal
98 if (exp_scalar.equal(.5)) {
99 return sqrt_kernel(iter);
100 } else if (exp_scalar.equal(-0.5)) {
101 return rsqrt_kernel(iter);
102 } else if (exp_scalar.equal(-1.0)) {
103 return reciprocal_kernel(iter);
104 }
105 }
106
107 if (dtype == ScalarType::Float || dtype == ScalarType::Double) {
108 AT_DISPATCH_FLOATING_TYPES(dtype, "pow", [&]() {
109 pow_tensor_scalar_optimized_kernel<scalar_t, double>(
110 iter, exp_scalar.to<double>());
111 });
112 } else if (isComplexType(dtype)) {
113 AT_DISPATCH_COMPLEX_TYPES(dtype, "pow", [&]() {
114 pow_tensor_scalar_optimized_kernel<scalar_t, scalar_t>(
115 iter, exp_scalar.to<c10::complex<double>>());
116 });
117 } else if (dtype == ScalarType::Half) {
118 [&]() {
119 using scalar_t =
120 decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
121 const auto exp = exp_scalar.to<scalar_t>();
122 using Vec = Vectorized<scalar_t>;
123 cpu_kernel_vec(iter,
124 [=](scalar_t base) -> scalar_t {
125 return std::pow(base, exp);
126 },
127 [=](Vec base) -> Vec { return base.pow(exp); }
128 );
129 }();
130 } else if (dtype == ScalarType::BFloat16) {
131 AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, dtype, "pow", [&]() {
132 pow_tensor_scalar_optimized_kernel<scalar_t, scalar_t>(
133 iter, exp_scalar.to<scalar_t>());
134 });
135 } else {
136 AT_DISPATCH_INTEGRAL_TYPES(dtype, "pow", [&]() {
137 const scalar_t exp = exp_scalar.to<scalar_t>();
138 cpu_kernel(iter, [=](scalar_t base) -> scalar_t {
139 return native::powi(base, exp);
140 });
141 });
142 }
143 }
144
} // namespace CPU_CAPABILITY
146
147 ALSO_REGISTER_AVX512_DISPATCH(pow_tensor_tensor_stub, &CPU_CAPABILITY::pow_tensor_tensor_kernel);
148 ALSO_REGISTER_AVX512_DISPATCH(pow_tensor_scalar_stub, &CPU_CAPABILITY::pow_tensor_scalar_kernel);
149
150 } // namespace at::native
151