xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Pow.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/Pow.h>
3 
4 #include <ATen/core/Tensor.h>
5 #include <ATen/ScalarOps.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/float_power_native.h>
12 #include <ATen/ops/pow.h>
13 #include <ATen/ops/pow_native.h>
14 #include <ATen/ops/result_type.h>
15 #endif
16 
17 namespace at::meta {
18 
// Meta function for pow(Tensor base, Tensor exp): hands the (possibly
// undefined) output together with both operands to the borrowing binary-op
// builder.
TORCH_META_FUNC2(pow, Tensor_Tensor) (const Tensor& base, const Tensor& exp) {
  const auto& maybe_out = maybe_get_output();
  build_borrowing_binary_op(maybe_out, base, exp);
}
22 
// Meta function for pow(Tensor base, Scalar exp): validates the operands and
// sets up a unary op over `base` converted to the common result dtype.
TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
  // Numpy compatibility check: an integral base combined with a negative
  // integral exponent is rejected. The clauses are evaluated left to right, so
  // the exponent is only inspected when the base is integral.
  TORCH_CHECK(!isIntegralType(base.scalar_type(), true) ||
              !exp.isIntegral(true) ||
              exp.toLong() >= 0,
              "Integers to negative integer powers are not allowed.");

  const auto common_dtype = at::result_type(base, exp);
  const auto& maybe_out = maybe_get_output();
  build_output_borrowing_argument_owning_unary_op(maybe_out, base.to(common_dtype));
}
32 
// Meta function for pow(Scalar base, Tensor exp).
// This overload doesn't directly use TensorIterator. It attempts to
// short-circuit, but otherwise redispatches to the Tensor_Tensor overload.
TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {
  // Keep the dtype of an already-defined output; otherwise use the promoted
  // type of the two operands.
  const auto& maybe_out = maybe_get_output();
  const auto dtype = maybe_out.defined() ? maybe_out.scalar_type()
                                         : at::result_type(base, exp);

  // The output takes the sizes of `exp`, and its dimension names when present.
  const auto out_options = exp.options().dtype(dtype);
  const auto out_names = exp.has_names() ? exp.names() : ArrayRef<Dimname>();
  set_output_raw_strided(0, exp.sizes(), {}, out_options, out_names);
}
39 
40 } // namespace at::meta
41 
42 namespace at::native {
43 
44 DEFINE_DISPATCH(pow_tensor_tensor_stub);
45 DEFINE_DISPATCH(pow_tensor_scalar_stub);
46 
// Implementation of pow(Tensor, Tensor): forwards this structured op to the
// tensor-tensor stub for the current device type. The named arguments are not
// touched directly here.
TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, const Tensor& out) {
  const auto dev_type = device_type();
  pow_tensor_tensor_stub(dev_type, *this);
}
50 
// Implementation of pow(Tensor base, Scalar exp). Exponents equal to zero or
// one are handled without calling the stub; everything else is forwarded to
// the tensor-scalar stub.
TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
  // Exponent 0 (or boolean false): the output is filled with ones.
  if (exp.equal(0.0) || exp.equal(false)) {
    out.fill_(1);
    return;
  }

  // Exponent 1 (or boolean true): the output is a copy of the base.
  if (exp.equal(1.0) || exp.equal(true)) {
    out.copy_(base);
    return;
  }

  pow_tensor_scalar_stub(device_type(), *this, exp);
}
60 
// Implementation of pow(Scalar base, Tensor exp). A base equal to one fills
// the output with ones; any other base is wrapped into a tensor on the
// exponent's device and handed to at::pow_out.
TORCH_IMPL_FUNC(pow_Scalar_out) (const Scalar& base, const Tensor& exp, const Tensor& out) {
  if (base.equal(1.0)) {
    out.fill_(1);
    return;
  }

  const auto base_tensor = wrapped_scalar_tensor(base, exp.device());
  // at::pow_out takes a mutable output reference, hence the const_cast.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  at::pow_out(const_cast<Tensor&>(out), base_tensor, exp); // redispatch!
}
69 
// float_power with an explicit output, tensor base and tensor exponent.
// The computation dtype is complex double when either operand is complex and
// double otherwise; `result` must already have that dtype. Both operands are
// converted to it before calling at::pow_out. Returns `result`.
Tensor& float_power_out(const Tensor& base, const Tensor& exp, Tensor& result) {
  const bool any_complex =
      at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type());
  const auto dtype = any_complex ? at::kComplexDouble : at::kDouble;

  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  const auto base_converted = base.to(dtype);
  const auto exp_converted = exp.to(dtype);
  return at::pow_out(result, base_converted, exp_converted);
}
79 
// float_power with an explicit output, tensor base and scalar exponent.
// The computation dtype is complex double when either operand is complex and
// double otherwise; `result` must already have that dtype. Returns `result`.
Tensor& float_power_out(const Tensor& base, const Scalar& exp, Tensor& result) {
  const bool any_complex = at::isComplexType(base.scalar_type()) || exp.isComplex();
  const auto dtype = any_complex ? at::kComplexDouble : at::kDouble;

  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  // Note: each branch wraps its value in a Scalar on its own. Mixing the raw
  // conversion results in a single conditional expression would always yield a
  // complex scalar, because the conversion functions return e.g. c10::complex.
  Scalar casted_exp;
  if (dtype == at::kComplexDouble) {
    casted_exp = Scalar(exp.toComplexDouble());
  } else {
    casted_exp = Scalar(exp.toDouble());
  }
  return at::pow_out(result, base.to(dtype), casted_exp);
}
91 
// float_power with an explicit output, scalar base and tensor exponent.
// The computation dtype is complex double when either operand is complex and
// double otherwise; `result` must already have that dtype. Returns `result`.
Tensor& float_power_out(const Scalar& base, const Tensor& exp, Tensor& result) {
  const bool any_complex = at::isComplexType(exp.scalar_type()) || base.isComplex();
  const auto dtype = any_complex ? at::kComplexDouble : at::kDouble;

  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  // Each branch wraps its value in a Scalar on its own so that the real case
  // is not widened to a complex scalar.
  Scalar casted_base;
  if (dtype == at::kComplexDouble) {
    casted_base = Scalar(base.toComplexDouble());
  } else {
    casted_base = Scalar(base.toDouble());
  }
  return at::pow_out(result, casted_base, exp.to(dtype));
}
101 
// float_power for a tensor base and scalar exponent. Computes at::pow in
// complex double when either operand is complex, and in double otherwise.
Tensor float_power(const Tensor& base, const Scalar& exp) {
  const bool any_complex = at::isComplexType(base.scalar_type()) || exp.isComplex();
  const auto dtype = any_complex ? at::kComplexDouble : at::kDouble;

  // Each branch wraps its value in a Scalar on its own so that the real case
  // is not widened to a complex scalar.
  Scalar casted_exp;
  if (dtype == at::kComplexDouble) {
    casted_exp = Scalar(exp.toComplexDouble());
  } else {
    casted_exp = Scalar(exp.toDouble());
  }
  return at::pow(base.to(dtype), casted_exp);
}
107 
// float_power for a scalar base and tensor exponent. Computes at::pow in
// complex double when either operand is complex, and in double otherwise.
Tensor float_power(const Scalar& base, const Tensor& exp) {
  const bool any_complex = at::isComplexType(exp.scalar_type()) || base.isComplex();
  const auto dtype = any_complex ? at::kComplexDouble : at::kDouble;

  // Each branch wraps its value in a Scalar on its own so that the real case
  // is not widened to a complex scalar.
  Scalar casted_base;
  if (dtype == at::kComplexDouble) {
    casted_base = Scalar(base.toComplexDouble());
  } else {
    casted_base = Scalar(base.toDouble());
  }
  return at::pow(casted_base, exp.to(dtype));
}
113 
// float_power for a tensor base and tensor exponent. Both operands are
// converted to complex double when either is complex, and to double otherwise,
// before calling at::pow.
Tensor float_power(const Tensor& base, const Tensor& exp) {
  const bool any_complex =
      at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type());
  const auto dtype = any_complex ? at::kComplexDouble : at::kDouble;

  const auto base_converted = base.to(dtype);
  const auto exp_converted = exp.to(dtype);
  return at::pow(base_converted, exp_converted);
}
118 
float_power_(Tensor & base,const Tensor & exp)119 Tensor& float_power_(Tensor& base, const Tensor& exp) {
120   auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
121   TORCH_CHECK(base.scalar_type() == dtype,
122               "the base given to float_power_ has dtype ", base.scalar_type(),
123               " but the operation's result requires dtype ", dtype);
124 
125   return base.pow_(exp.to(dtype));
126 }
127 
float_power_(Tensor & base,const Scalar & exp)128 Tensor& float_power_(Tensor& base, const Scalar& exp) {
129   auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
130   TORCH_CHECK(base.scalar_type() == dtype,
131               "the base given to float_power_ has dtype ", base.scalar_type(),
132               " but the operation's result requires dtype ", dtype);
133 
134   auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
135   return base.pow_(casted_exp);
136 }
137 
138 } // namespace at::native
139