#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Pow.h>

#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/float_power_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/pow_native.h>
#include <ATen/ops/result_type.h>
#endif
namespace at::meta {

TORCH_META_FUNC2(pow, Tensor_Tensor) (const Tensor& base, const Tensor& exp) {
  build_borrowing_binary_op(maybe_get_output(), base, exp);
}
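
// The meta function above only sets up broadcasting, type promotion, and the
// output for the structured kernel; the actual computation is performed by
// the dispatch stubs invoked from the impl functions further below.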

TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
  // NumPy compatibility check:
  TORCH_CHECK(!(isIntegralType(base.scalar_type(), true) &&
              exp.isIntegral(true) && exp.toLong() < 0),
              "Integers to negative integer powers are not allowed.");

  auto common_dtype = at::result_type(base, exp);
  build_output_borrowing_argument_owning_unary_op(maybe_get_output(), base.to(common_dtype));
}
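
// Illustrative consequence of the check above (sketch, matching NumPy):
//   at::pow(at::ones({3}, at::kLong), -2);    // error: integral base, negative integral exponent
//   at::pow(at::ones({3}, at::kDouble), -2);  // fine: floating-point base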

TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {
  // This overload doesn't directly use TensorIterator. It attempts to short-circuit,
  // but otherwise redispatches to the Tensor_Tensor overload.
  auto dtype = maybe_get_output().defined() ? maybe_get_output().scalar_type() : at::result_type(base, exp);
  set_output_raw_strided(0, exp.sizes(), {}, exp.options().dtype(dtype), exp.has_names() ? exp.names() : ArrayRef<Dimname>());
}
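
// Note that the output metadata mirrors `exp` exactly (sizes and names):
// broadcasting a 0-dim scalar base against `exp` cannot change the shape.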

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(pow_tensor_tensor_stub);
DEFINE_DISPATCH(pow_tensor_scalar_stub);

TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, const Tensor& out) {
  pow_tensor_tensor_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
  if (exp.equal(0.0) || exp.equal(false)) {
    out.fill_(1);
  } else if (exp.equal(1.0) || exp.equal(true)) {
    out.copy_(base);
  } else {
    pow_tensor_scalar_stub(device_type(), *this, exp);
  }
}
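
// The branches above short-circuit the trivial exponents: an exponent equal
// to 0 (or false) fills the output with ones without reading `base`, and one
// equal to 1 (or true) reduces to a copy; only non-trivial exponents reach
// the dispatch stub.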

TORCH_IMPL_FUNC(pow_Scalar_out) (const Scalar& base, const Tensor& exp, const Tensor& out) {
  if (base.equal(1.0)) {
    out.fill_(1);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    at::pow_out(const_cast<Tensor&>(out), wrapped_scalar_tensor(base, exp.device()), exp); // redispatch!
  }
}
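
// wrapped_scalar_tensor materializes `base` as a 0-dim tensor on `exp`'s
// device, so this overload reuses the Tensor_Tensor kernel rather than
// carrying a dedicated scalar-base implementation.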

Tensor& float_power_out(const Tensor& base, const Tensor& exp, Tensor& result) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
                at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  return at::pow_out(result, base.to(dtype), exp.to(dtype));
}
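
// float_power always computes in double (or complex double) precision,
// whatever the input dtypes. Illustrative sketch:
//   auto ok = at::empty({3}, at::kDouble);
//   at::float_power_out(ok, at::ones({3}, at::kInt), at::ones({3}, at::kInt));
//   auto bad = at::empty({3}, at::kFloat);  // would fail the dtype check above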

Tensor& float_power_out(const Tensor& base, const Scalar& exp, Tensor& result) {
  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  // Note: need the casts inside the ternary because conversion functions return e.g. c10::complex,
  // which causes a complex scalar to always be returned.
  auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
  return at::pow_out(result, base.to(dtype), casted_exp);
}
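
// Without the explicit Scalar(...) constructions, the ternary's common type
// would be deduced as c10::complex<double> (double converts to it), so a
// real exponent would silently come back as a complex Scalar.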

Tensor& float_power_out(const Scalar& base, const Tensor& exp, Tensor& result) {
  auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
              "the output given to float_power has dtype ", result.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  auto casted_base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
  return at::pow_out(result, casted_base, exp.to(dtype));
}

Tensor float_power(const Tensor& base, const Scalar& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
  auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
  return at::pow(base.to(dtype), casted_exp);
}

Tensor float_power(const Scalar& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble;
  auto casted_base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
  return at::pow(casted_base, exp.to(dtype));
}

Tensor float_power(const Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  return at::pow(base.to(dtype), exp.to(dtype));
}
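
// Usage sketch for the functional variants above (illustrative):
//   at::float_power(at::ones({3}, at::kInt), 2);                       // -> kDouble
//   at::float_power(at::ones({3}, at::kComplexFloat), at::ones({3}));  // -> kComplexDouble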

Tensor& float_power_(Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(base.scalar_type() == dtype,
              "the base given to float_power_ has dtype ", base.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  return base.pow_(exp.to(dtype));
}
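
// Since pow_ cannot change `base`'s dtype in place, the in-place variants
// require the caller to pass a tensor that already has the promoted dtype
// (kDouble or kComplexDouble); anything else fails the check above.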

Tensor& float_power_(Tensor& base, const Scalar& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(base.scalar_type() == dtype,
              "the base given to float_power_ has dtype ", base.scalar_type(),
              " but the operation's result requires dtype ", dtype);

  auto casted_exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
  return base.pow_(casted_exp);
}

} // namespace at::native