xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/UnaryOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/MemoryOverlap.h>
5 #include <ATen/NamedTensorUtils.h>
6 #include <ATen/Parallel.h>
7 #include <ATen/ScalarOps.h>
8 #include <ATen/TensorIterator.h>
9 #include <ATen/TensorOperators.h>
10 #include <ATen/WrapDimUtils.h>
11 
12 #include <ATen/native/Resize.h>
13 #include <ATen/native/UnaryOps.h>
14 #include <ATen/native/ComplexHelper.h>
15 
16 #include <c10/util/MathConstants.h>
17 
18 #ifndef AT_PER_OPERATOR_HEADERS
19 #include <ATen/Functions.h>
20 #include <ATen/NativeFunctions.h>
21 #else
22 #include <ATen/ops/_conj_native.h>
23 #include <ATen/ops/_conj_physical.h>
24 #include <ATen/ops/_conj_physical_native.h>
25 #include <ATen/ops/_neg_view_native.h>
26 #include <ATen/ops/abs.h>
27 #include <ATen/ops/abs_native.h>
28 #include <ATen/ops/absolute_native.h>
29 #include <ATen/ops/acos.h>
30 #include <ATen/ops/acos_native.h>
31 #include <ATen/ops/acosh.h>
32 #include <ATen/ops/acosh_native.h>
33 #include <ATen/ops/angle.h>
34 #include <ATen/ops/angle_native.h>
35 #include <ATen/ops/arange_native.h>
36 #include <ATen/ops/arccos_native.h>
37 #include <ATen/ops/arccosh_native.h>
38 #include <ATen/ops/arcsin_native.h>
39 #include <ATen/ops/arcsinh_native.h>
40 #include <ATen/ops/arctan_native.h>
41 #include <ATen/ops/arctanh_native.h>
42 #include <ATen/ops/asin.h>
43 #include <ATen/ops/asin_native.h>
44 #include <ATen/ops/asinh.h>
45 #include <ATen/ops/asinh_native.h>
46 #include <ATen/ops/atan.h>
47 #include <ATen/ops/atan_native.h>
48 #include <ATen/ops/atanh.h>
49 #include <ATen/ops/atanh_native.h>
50 #include <ATen/ops/bitwise_not_native.h>
51 #include <ATen/ops/can_cast.h>
52 #include <ATen/ops/ceil_native.h>
53 #include <ATen/ops/conj_native.h>
54 #include <ATen/ops/conj_physical.h>
55 #include <ATen/ops/conj_physical_native.h>
56 #include <ATen/ops/cos_native.h>
57 #include <ATen/ops/cosh_native.h>
58 #include <ATen/ops/deg2rad.h>
59 #include <ATen/ops/deg2rad_native.h>
60 #include <ATen/ops/digamma.h>
61 #include <ATen/ops/digamma_native.h>
62 #include <ATen/ops/empty.h>
63 #include <ATen/ops/empty_like.h>
64 #include <ATen/ops/erf.h>
65 #include <ATen/ops/erf_native.h>
66 #include <ATen/ops/erfc.h>
67 #include <ATen/ops/erfc_native.h>
68 #include <ATen/ops/erfinv.h>
69 #include <ATen/ops/erfinv_native.h>
70 #include <ATen/ops/exp2.h>
71 #include <ATen/ops/exp2_native.h>
72 #include <ATen/ops/exp_native.h>
73 #include <ATen/ops/expm1.h>
74 #include <ATen/ops/expm1_native.h>
75 #include <ATen/ops/fix_native.h>
76 #include <ATen/ops/floor_native.h>
77 #include <ATen/ops/frac_native.h>
78 #include <ATen/ops/frexp.h>
79 #include <ATen/ops/frexp_native.h>
80 #include <ATen/ops/i0.h>
81 #include <ATen/ops/i0_native.h>
82 #include <ATen/ops/imag_native.h>
83 #include <ATen/ops/lgamma.h>
84 #include <ATen/ops/lgamma_native.h>
85 #include <ATen/ops/log10_native.h>
86 #include <ATen/ops/log1p.h>
87 #include <ATen/ops/log1p_native.h>
88 #include <ATen/ops/log2_native.h>
89 #include <ATen/ops/log_native.h>
90 #include <ATen/ops/logical_not.h>
91 #include <ATen/ops/logical_not_native.h>
92 #include <ATen/ops/logit.h>
93 #include <ATen/ops/logit_native.h>
94 #include <ATen/ops/mul.h>
95 #include <ATen/ops/mvlgamma.h>
96 #include <ATen/ops/mvlgamma_native.h>
97 #include <ATen/ops/nan_to_num.h>
98 #include <ATen/ops/nan_to_num_native.h>
99 #include <ATen/ops/neg.h>
100 #include <ATen/ops/neg_native.h>
101 #include <ATen/ops/negative_native.h>
102 #include <ATen/ops/polygamma.h>
103 #include <ATen/ops/polygamma_native.h>
104 #include <ATen/ops/positive_native.h>
105 #include <ATen/ops/pow.h>
106 #include <ATen/ops/rad2deg.h>
107 #include <ATen/ops/rad2deg_native.h>
108 #include <ATen/ops/real.h>
109 #include <ATen/ops/real_native.h>
110 #include <ATen/ops/reciprocal_native.h>
111 #include <ATen/ops/resolve_conj_native.h>
112 #include <ATen/ops/resolve_neg_native.h>
113 #include <ATen/ops/round.h>
114 #include <ATen/ops/round_native.h>
115 #include <ATen/ops/rsqrt_native.h>
116 #include <ATen/ops/select.h>
117 #include <ATen/ops/sgn_native.h>
118 #include <ATen/ops/sigmoid.h>
119 #include <ATen/ops/sigmoid_native.h>
120 #include <ATen/ops/sign_native.h>
121 #include <ATen/ops/signbit_native.h>
122 #include <ATen/ops/sin_native.h>
123 #include <ATen/ops/sinc.h>
124 #include <ATen/ops/sinc_native.h>
125 #include <ATen/ops/sinh_native.h>
126 #include <ATen/ops/special_airy_ai_native.h>
127 #include <ATen/ops/special_bessel_j0_native.h>
128 #include <ATen/ops/special_bessel_j1_native.h>
129 #include <ATen/ops/special_bessel_y0_native.h>
130 #include <ATen/ops/special_bessel_y1_native.h>
131 #include <ATen/ops/special_digamma_native.h>
132 #include <ATen/ops/special_entr_native.h>
133 #include <ATen/ops/special_erf_native.h>
134 #include <ATen/ops/special_erfc_native.h>
135 #include <ATen/ops/special_erfcx_native.h>
136 #include <ATen/ops/special_erfinv_native.h>
137 #include <ATen/ops/special_exp2_native.h>
138 #include <ATen/ops/special_expit_native.h>
139 #include <ATen/ops/special_expm1_native.h>
140 #include <ATen/ops/special_gammaln_native.h>
141 #include <ATen/ops/special_i0_native.h>
142 #include <ATen/ops/special_i0e_native.h>
143 #include <ATen/ops/special_i1_native.h>
144 #include <ATen/ops/special_i1e_native.h>
145 #include <ATen/ops/special_log1p_native.h>
146 #include <ATen/ops/special_log_ndtr_native.h>
147 #include <ATen/ops/special_logit_native.h>
148 #include <ATen/ops/special_modified_bessel_i0_native.h>
149 #include <ATen/ops/special_modified_bessel_i1_native.h>
150 #include <ATen/ops/special_modified_bessel_k0_native.h>
151 #include <ATen/ops/special_modified_bessel_k1_native.h>
152 #include <ATen/ops/special_multigammaln_native.h>
153 #include <ATen/ops/special_ndtr_native.h>
154 #include <ATen/ops/special_ndtri_native.h>
155 #include <ATen/ops/special_polygamma_native.h>
156 #include <ATen/ops/special_psi_native.h>
157 #include <ATen/ops/special_round_native.h>
158 #include <ATen/ops/special_scaled_modified_bessel_k0_native.h>
159 #include <ATen/ops/special_scaled_modified_bessel_k1_native.h>
160 #include <ATen/ops/special_sinc_native.h>
161 #include <ATen/ops/special_spherical_bessel_j0_native.h>
162 #include <ATen/ops/sqrt_native.h>
163 #include <ATen/ops/square_native.h>
164 #include <ATen/ops/tan_native.h>
165 #include <ATen/ops/tanh_native.h>
166 #include <ATen/ops/trunc.h>
167 #include <ATen/ops/trunc_native.h>
168 #include <ATen/ops/view_as_real.h>
169 #endif
170 
171 #include <cmath>
172 
173 namespace at::meta {
174 
// Unary float operations always produce floating point
// outputs for floating point and integral types
// For complex inputs, the output type should be the same as input type.
// The meta function only configures the structured kernel's TensorIterator
// (with float promotion); the device kernels live behind the *_stub's.
#define CREATE_UNARY_FLOAT_META_FUNC(func)                  \
  TORCH_META_FUNC(func) (const Tensor& self) {        \
    build_borrowing_unary_float_op(maybe_get_output(), self);   \
  }

CREATE_UNARY_FLOAT_META_FUNC(acos)
CREATE_UNARY_FLOAT_META_FUNC(acosh)
CREATE_UNARY_FLOAT_META_FUNC(asin)
CREATE_UNARY_FLOAT_META_FUNC(asinh)
CREATE_UNARY_FLOAT_META_FUNC(atan)
CREATE_UNARY_FLOAT_META_FUNC(atanh)
CREATE_UNARY_FLOAT_META_FUNC(cos)
CREATE_UNARY_FLOAT_META_FUNC(cosh)
CREATE_UNARY_FLOAT_META_FUNC(digamma)
CREATE_UNARY_FLOAT_META_FUNC(erf)
CREATE_UNARY_FLOAT_META_FUNC(erfc)
CREATE_UNARY_FLOAT_META_FUNC(erfinv)
CREATE_UNARY_FLOAT_META_FUNC(exp)
CREATE_UNARY_FLOAT_META_FUNC(exp2)
CREATE_UNARY_FLOAT_META_FUNC(expm1)
CREATE_UNARY_FLOAT_META_FUNC(i0)
CREATE_UNARY_FLOAT_META_FUNC(lgamma)
CREATE_UNARY_FLOAT_META_FUNC(log)
CREATE_UNARY_FLOAT_META_FUNC(log10)
CREATE_UNARY_FLOAT_META_FUNC(log1p)
CREATE_UNARY_FLOAT_META_FUNC(log2)
CREATE_UNARY_FLOAT_META_FUNC(reciprocal)
CREATE_UNARY_FLOAT_META_FUNC(rsqrt)
CREATE_UNARY_FLOAT_META_FUNC(sigmoid)
CREATE_UNARY_FLOAT_META_FUNC(sin)
CREATE_UNARY_FLOAT_META_FUNC(sinc)
CREATE_UNARY_FLOAT_META_FUNC(sinh)
CREATE_UNARY_FLOAT_META_FUNC(special_entr)
CREATE_UNARY_FLOAT_META_FUNC(special_erfcx)
CREATE_UNARY_FLOAT_META_FUNC(special_i0e)
CREATE_UNARY_FLOAT_META_FUNC(special_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_i1e)
CREATE_UNARY_FLOAT_META_FUNC(special_ndtri)
CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr)
CREATE_UNARY_FLOAT_META_FUNC(sqrt)
CREATE_UNARY_FLOAT_META_FUNC(tan)
CREATE_UNARY_FLOAT_META_FUNC(tanh)
CREATE_UNARY_FLOAT_META_FUNC(special_airy_ai)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k1)
CREATE_UNARY_FLOAT_META_FUNC(special_scaled_modified_bessel_k0)
CREATE_UNARY_FLOAT_META_FUNC(special_scaled_modified_bessel_k1)
CREATE_UNARY_FLOAT_META_FUNC(special_spherical_bessel_j0)
232 
// polygamma(n, x): only n >= 0 (the n-th derivative of digamma) is defined.
TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
  TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
  build_borrowing_unary_float_op(maybe_get_output(), self);
}
237 
// These are normal unary ops that preserve dtype
#define CREATE_UNARY_META_FUNC(func)                  \
  TORCH_META_FUNC(func) (const Tensor& self) {        \
    build_borrowing_unary_op(maybe_get_output(), self);   \
  }
CREATE_UNARY_META_FUNC(bitwise_not)
CREATE_UNARY_META_FUNC(frac)
CREATE_UNARY_META_FUNC(round)
CREATE_UNARY_META_FUNC(sgn)

// round.decimals also preserves dtype.
// NOTE(review): this uses build_unary_op rather than the borrowing variant
// used above — presumably intentional; confirm before changing.
TORCH_META_FUNC2(round, decimals)(const Tensor& self, int64_t decimals){
  build_unary_op(maybe_get_output(), self);
}
251 
// neg: negation is rejected for bool; users should invert masks with
// `~` / logical_not instead.
TORCH_META_FUNC(neg)(const Tensor& self) {
  TORCH_CHECK(self.scalar_type() != kBool,
              "Negation, the `-` operator, on a bool tensor is not supported. "
              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(trunc) (const Tensor& self) {
  // Note: this is consistent with NumPy
  TORCH_CHECK(!self.is_complex(),
    "trunc is not supported for complex inputs");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(floor) (const Tensor& self) {
  // Note: this is consistent with NumPy
  TORCH_CHECK(!self.is_complex(),
    "floor is not supported for complex inputs");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(sign) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(),
              "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(signbit) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), "signbit is not implemented for complex tensors.");
  // If the caller supplied an `out` tensor it must already be bool; the
  // iterator below then forces a boolean output dtype.
  TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
              "signbit does not support non-boolean outputs.");
  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}

TORCH_META_FUNC(ceil) (const Tensor& self) {
  // Note: this is consistent with NumPy
  TORCH_CHECK(!self.is_complex(),
    "ceil is not supported for complex inputs");
  build_borrowing_unary_op(maybe_get_output(), self);
}
292 
293 } // namespace at::meta
294 
295 namespace at::native {
// NOTE: These are helper functions that reduce redundant code in implementing the most typical kind of unary operators.
// YOU ARE NOT OBLIGED TO USE THESE HELPERS---if you're writing something more specialized, please don't try to make
// them work for your case, but just write something new instead. Here we use helper functions instead of a flat fat
// macro that implements everything, because the former allows some simple preprocessing that are unique to some
// operators (more is foreseeable) and is more flexible and elegant than the latter.

// Expands to the structured-kernel "impl" function: the meta function has
// already set up *this, so the impl only invokes the device-dispatched stub.
#define CREATE_UNARY_TORCH_IMPL_FUNC(func_out, func_stub)                                \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& result) {  \
  func_stub(device_type(), *this);                                      \
}

// This macro is as optional as the one above. torch.(ceil|floor|round|trunc) are no-ops for integers
// See gh-70918
#define CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(func_out, func_stub)                                \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& result) {  \
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/false)) {                                      \
    result.copy_(self);                                                 \
  } else {                                                              \
    func_stub(device_type(), *this);                                    \
  }                                                                     \
}
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(ceil_out, ceil_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(floor_out, floor_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(round_out, round_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(trunc_out, trunc_stub)
320 
// Structured-kernel impl functions for the ordinary unary ops; each forwards
// to its device-dispatched stub (see CREATE_UNARY_TORCH_IMPL_FUNC above).
CREATE_UNARY_TORCH_IMPL_FUNC(acos_out, acos_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(acosh_out, acosh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(asin_out, asin_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(asinh_out, asinh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(atan_out, atan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(atanh_out, atanh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(bitwise_not_out, bitwise_not_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(cos_out, cos_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(cosh_out, cosh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(digamma_out, digamma_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(erf_out, erf_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(erfc_out, erfc_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(erfinv_out, erfinv_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(exp_out, exp_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(exp2_out, exp2_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(expm1_out, expm1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(frac_out, frac_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(i0_out, i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(lgamma_out, lgamma_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log_out, log_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log10_out, log10_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log1p_out, log1p_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log2_out, log2_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(neg_out, neg_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(reciprocal_out, reciprocal_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(rsqrt_out, rsqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sigmoid_out, sigmoid_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sign_out, sign_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sin_out, sin_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sinc_out, sinc_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sinh_out, sinh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_entr_out, special_entr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_erfcx_out, special_erfcx_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i0e_out, special_i0e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1e_out, special_i1e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1_out, special_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_ndtri_out, special_ndtri_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_log_ndtr_out, special_log_ndtr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_airy_ai_out, special_airy_ai_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y1_out, special_bessel_y1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_bessel_i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k1_out, special_modified_bessel_k1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_scaled_modified_bessel_k0_out, special_scaled_modified_bessel_k0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_scaled_modified_bessel_k1_out, special_scaled_modified_bessel_k1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_spherical_bessel_j0_out, special_spherical_bessel_j0_stub)
374 
375 TORCH_IMPL_FUNC(round_decimals_out)
376 (const Tensor& self, int64_t decimals, const Tensor& result) {
377   if (decimals != 0) {
378     round_decimals_stub(device_type(), *this, decimals);
379   } else {
380     round_stub(device_type(), *this);
381   }
382 }
383 
// polygamma impl: the derivative order `n` is forwarded to the stub
// alongside the pre-built iterator (*this).
TORCH_IMPL_FUNC(polygamma_out)
(int64_t n, const Tensor& self, const Tensor& result) {
  polygamma_stub(device_type(), *this, n);
}
388 
// signbit impl: bool tensors carry no sign bit, so their result is
// uniformly false; everything else goes through the device stub.
TORCH_IMPL_FUNC(signbit_out) (const Tensor& self, const Tensor& result) {
  if (self.dtype() != at::kBool) {
    signbit_stub(device_type(), *this);
  } else {
    result.fill_(false);
  }
}
396 
// since polygamma_ has different signature from its
// out and functional variant, we explicitly
// define it (instead of using structured kernel).
// Note: goes through at:: dispatch so out-of-source backends work.
Tensor& polygamma_(Tensor& self, int64_t n) {
  return at::polygamma_out(self, n, self);
}
403 
// Runs `stub` over a same-dtype unary TensorIterator that writes into
// `result`; returns `result` for chaining.
template <typename Stub>
static inline Tensor& unary_op_impl_out(Tensor& result, const Tensor& self, Stub& stub) {
  auto iter = TensorIterator::unary_op(result, self);
  stub(iter.device_type(), iter);
  return result;
}

// Like unary_op_impl_out, but builds a float-promoting iterator (integral
// inputs produce floating-point results). Extra `args` are forwarded to the
// stub after the iterator.
template <typename Stub, typename ...Args>
static inline Tensor& unary_op_impl_float_out(Tensor& result, const Tensor& self, Stub& stub, Args... args) {
  auto iter = TensorIterator::unary_float_op(result, self);
  stub(iter.device_type(), iter, args...);
  return result;
}

// Functional float-promoting variant: `result` is left undefined so the
// iterator allocates the (possibly promoted) output, which is returned via
// iter.output().
template <typename Stub, typename ...Args>
static inline Tensor unary_op_impl_float(const Tensor& self, Stub& stub, Args... args) {
  Tensor result;
  auto iter = TensorIterator::unary_float_op(result, self);
  stub(iter.device_type(), iter, args...);
  return iter.output();
}
425 
// An alternate version of unary_op_impl_out that follows the same pattern
// for non-complex inputs, but returns a floating point tensor
// for complex inputs by default.
// Note: This is done by running the operation as usual and then copying the
// operation's result to the expected result type.
template <typename Stub>
static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub, bool promotes_integer_to_float) {
    if (self.is_complex() && !result.is_complex()) {
      // Checks if the corresponding float type can be cast to the desired dtype
      const auto float_type = c10::toRealValueType(self.scalar_type());
      TORCH_CHECK(canCast(float_type, result.scalar_type()),
            "result type ", float_type, " can't be cast to the desired output type ",
            result.scalar_type());

      // Runs the function complex->complex, as TensorIterator expects
      Tensor complex_result = at::empty({0}, self.options());
      auto iter = TensorIterator::unary_op(complex_result, self);
      stub(iter.device_type(), iter);

      // Copies the complex result to the actual result and returns it
      // (resize first so copy_ has a correctly-shaped destination).
      at::native::resize_output(result, complex_result.sizes());
      result.copy_(at::real(complex_result));
      return result;
    }

    if (promotes_integer_to_float) {
      return unary_op_impl_float_out(result, self, stub);
    }

    return unary_op_impl_out(result, self, stub);
}
457 
// out_impl passed into unary_op_impl and unary_op_impl_  must go through at:: device dispatch
// otherwise it won't dispatch to out-of-source devices like XLA.
// For example it must be at::bitwise_not_out instead of bitwise_not_out(which is at::native!).
// Functional variant: allocates an empty result and delegates to the out=
// implementation (which resizes it).
template <typename OutImpl>
static inline Tensor unary_op_impl(const Tensor& self, OutImpl& out_impl) {
  Tensor result = at::empty({0}, self.options());
  return out_impl(result, self);
}

// An alternate version of unary_op_impl that follows the same pattern
// for non-complex inputs, but returns a floating point tensor
// for complex inputs by default.
template <typename OutImpl>
static inline Tensor unary_op_impl_with_complex_to_float(const Tensor& self, OutImpl& out_impl) {
  if (self.is_complex()) {
    // For complex input, the result dtype is the corresponding real dtype.
    const auto float_type = c10::toRealValueType(self.scalar_type());
    Tensor result = at::empty_like(self, self.options().dtype(float_type));
    return out_impl(result, self);
  }

  Tensor result = at::empty({0}, self.options());
  return out_impl(result, self);
}

// In-place variant: `self` serves as both input and output.
template <typename OutImpl>
static inline Tensor& unary_op_impl_(Tensor& self, OutImpl& out_impl) {
  return out_impl(self, self);
}
486 
// arccos, alias for acos (out=, functional, and in-place variants).
Tensor& arccos_out(const Tensor& self, Tensor& result) { return at::acos_out(result, self); }
Tensor arccos(const Tensor& self) { return self.acos(); }
Tensor& arccos_(Tensor& self) { return self.acos_(); }
491 
// rad2deg: converts angles in radians to degrees (multiplies by 180 / pi).
Tensor& rad2deg_out(const Tensor& self, Tensor& result) {
  TORCH_CHECK(!self.is_complex(), "rad2deg is not supported for complex tensors.");
  // 180 / pi, written out to full double precision.
  constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564;
  return at::mul_out(result, self, wrapped_scalar_tensor(Scalar(M_180_PI)));
}
Tensor rad2deg(const Tensor& self) {
  // Integral (and bool) inputs are promoted to the default float dtype here
  // by hand, since this op bypasses the usual TensorIterator + kernel
  // dispatch pattern and its promotion machinery.
  auto opts = self.options();
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    opts = opts.dtype(c10::get_default_dtype());
  }
  auto out = at::empty_like(self, opts);
  at::rad2deg_out(out, self);
  return out;
}
Tensor& rad2deg_(Tensor& self) { return unary_op_impl_(self, at::rad2deg_out); }
509 
// deg2rad: converts angles in degrees to radians (multiplies by pi / 180).
Tensor& deg2rad_out(const Tensor& self, Tensor& result) {
  TORCH_CHECK(!self.is_complex(), "deg2rad is not supported for complex tensors.");
  // pi / 180, written out to full double precision.
  constexpr double M_PI_180 = 0.017453292519943295769236907684886127134428718885417;
  return at::mul_out(result, self, wrapped_scalar_tensor(Scalar(M_PI_180)));
}
Tensor deg2rad(const Tensor& self) {
  // Integral (and bool) inputs are promoted to the default float dtype here
  // by hand, since this op bypasses the usual TensorIterator + kernel
  // dispatch pattern and its promotion machinery.
  auto opts = self.options();
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    opts = opts.dtype(c10::get_default_dtype());
  }
  auto out = at::empty_like(self, opts);
  at::deg2rad_out(out, self);
  return out;
}
Tensor& deg2rad_(Tensor& self) { return unary_op_impl_(self, at::deg2rad_out); }
527 
// arcsin, alias of asin (out=, functional, and in-place variants).
Tensor& arcsin_out(const Tensor& self, Tensor& result) { return at::asin_out(result, self); }
Tensor arcsin(const Tensor& self) { return self.asin(); }
Tensor& arcsin_(Tensor& self) { return self.asin_(); }

// arctan, alias of atan (out=, functional, and in-place variants).
Tensor& arctan_out(const Tensor& self, Tensor& result) { return at::atan_out(result, self); }
Tensor arctan(const Tensor& self) { return self.atan(); }
Tensor& arctan_(Tensor& self) { return self.atan_(); }
537 
// Note [Complex abs and angle]
// Complex inputs to abs and angle return float results by default.
// abs and angle, in both NumPy and C++, returns a float result when given a
// complex input. This makes sense mathematically since the absolute value
// and angle of a complex number has no imaginary part.
Tensor& abs_out(const Tensor& self, Tensor& result) {
  return unary_op_impl_with_complex_to_float_out(result, self, abs_stub, /*promotes_integer_to_float=*/false);
}
Tensor abs(const Tensor& self) {
  return unary_op_impl_with_complex_to_float(self, at::abs_out);
}
Tensor& abs_(Tensor& self) {
  // In-place is impossible for complex input: the result dtype would be the
  // real dtype, which cannot be stored back into a complex tensor.
  TORCH_CHECK(!self.is_complex(), "In-place abs is not supported for complex tensors.");
  return unary_op_impl_(self, at::abs_out);
}

// Absolute, alias for abs
Tensor& absolute_out(const Tensor& self, Tensor& result) {
  return at::abs_out(result, self);
}
Tensor absolute(const Tensor& self) {
  return self.abs();
}
Tensor& absolute_(Tensor& self) {
  return self.abs_();
}
564 
// See Note [Complex abs and angle]: complex inputs yield real results, and
// integral inputs are promoted to float.
Tensor& angle_out(const Tensor& self, Tensor& result) {
  return unary_op_impl_with_complex_to_float_out(result, self, angle_stub, /*promotes_integer_to_float=*/true);
}
// angle: see Note [Complex abs and angle] — complex inputs produce a real
// result; non-complex inputs go through the float-promoting path.
Tensor angle(const Tensor& self) {
  if (!self.is_complex()) {
    return unary_op_impl_float(self, angle_stub);
  }
  // Complex input: the angle is real, so allocate a real-valued result and
  // fill it through the out= overload.
  const auto real_dtype = c10::toRealValueType(self.scalar_type());
  Tensor out = at::empty({0}, self.options().dtype(real_dtype));
  return at::angle_out(out, self);
}
577 
// Returns the real component of `self`. Complex tensors get a view (the
// last dimension of view_as_real, index 0); real tensors are returned as-is.
Tensor real(const Tensor& self) {
  if (!self.is_complex()) {
    return self;
  }
  // The real part is unaffected by conjugation, so for a lazily-conjugated
  // input we view the alias with the conj bit cleared (self._conj()).
  auto as_real = self.is_conj() ? at::view_as_real(self._conj())
                                : at::view_as_real(self);
  return at::select(as_real, as_real.dim() - 1, 0);
}
591 
// Returns an alias of `self` with the negation bit flipped (lazy negation);
// named dimensions are propagated to the view.
Tensor _neg_view(const Tensor& self) {
  Tensor out = self.alias();
  out._set_neg(!self.is_neg());
  namedinference::propagate_names(out, self);
  return out;
}
598 
// Returns the imaginary component of a complex tensor as a view (the last
// dimension of view_as_real, index 1). Raises for non-complex inputs.
Tensor imag(const Tensor& self) {
  TORCH_CHECK(self.is_complex(), "imag is not implemented for tensors with non-complex dtypes.");
  Tensor as_real;
  if (self.is_conj()) {
    // imag(conj(z)) == -imag(z): view the alias with the conj bit cleared
    // and preemptively set the negative flag for the final imag tensor.
    as_real = at::view_as_real(self._conj())._neg_view();
  } else {
    as_real = at::view_as_real(self);
  }
  return at::select(as_real, as_real.dim() - 1, 1);
}
614 
// Eagerly writes the elementwise conjugate of `self` into `result`.
Tensor& conj_physical_out(const Tensor& self, Tensor& result) {
  return unary_op_impl_out(result, self, conj_physical_stub);
}
618 
// Materializes the conjugate into fresh memory. If the lazy conj bit is
// already set, clone() materializes the pending conjugation via copy_();
// otherwise run the conj_physical kernel.
Tensor _conj_physical(const Tensor& self) {
  return self.is_conj() ? self.conj().clone()
                        : unary_op_impl(self, at::conj_physical_out);
}
625 
// Eager (materialized) conjugation; identity for non-complex tensors.
Tensor conj_physical(const Tensor& self) {
  if (self.is_complex()) {
    return at::_conj_physical(self);
  }
  return self;
}

// In-place eager conjugation; no-op for non-complex tensors.
Tensor& conj_physical_(Tensor& self) {
  if (self.is_complex()) {
    return unary_op_impl_out(self, self, conj_physical_stub);
  }
  return self;
}
635 
636 // No op if the neg bit is not set
637 // else returns a new negated tensor with neg bit set to 0
resolve_neg(const Tensor & self)638 Tensor resolve_neg(const Tensor& self) {
639   if (!self.is_neg()) { return self; }
640   // negation is materialized in `copy_()` that clone ultimately calls into
641   return self.clone();
642 }
643 
644 // No op if the conj bit is not set
645 // else returns a new negated tensor with neg bit set to 0
resolve_conj(const Tensor & self)646 Tensor resolve_conj(const Tensor& self) {
647   if (!self.is_conj()) { return self; }
648   // conjugation is materialized in `copy_()` that clone ultimately calls into
649   return self.clone();
650 }
651 
// Returns a lazily-conjugated view: an alias of `self` with the conj bit
// flipped; named dimensions are propagated to the view.
Tensor _conj(const Tensor& self) {
  Tensor out = self.alias();
  out._set_conj(!self.is_conj());
  namedinference::propagate_names(out, self);
  return out;
}

// Dispatches to Tensor::conj(). This is not infinite recursion: the method
// invoked here is the one defined on the Tensor class itself.
Tensor conj(const Tensor& self) {
  return self.conj();
}
664 
// torch.special.* aliases: each pair forwards to the canonical op's out=
// overload and Tensor method, respectively.

// special_exp2, alias for exp2
Tensor& special_exp2_out(const Tensor& self, Tensor& result) { return at::exp2_out(result, self); }
Tensor special_exp2(const Tensor& self) { return self.exp2(); }

// special_expm1, alias for expm1
Tensor& special_expm1_out(const Tensor& self, Tensor& result) { return at::expm1_out(result, self); }
Tensor special_expm1(const Tensor& self) { return self.expm1(); }

// special_erf, alias for erf
Tensor& special_erf_out(const Tensor& self, Tensor& result) { return at::erf_out(result, self); }
Tensor special_erf(const Tensor& self) { return self.erf(); }

// special_erfc, alias for erfc
Tensor& special_erfc_out(const Tensor& self, Tensor& result) { return at::erfc_out(result, self); }
Tensor special_erfc(const Tensor& self) { return self.erfc(); }

// special_erfinv, alias for erfinv
Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }

// special_polygamma, alias for polygamma
Tensor& special_polygamma_out(int64_t n, const Tensor& self, Tensor& result) { return at::polygamma_out(result, n, self); }
Tensor special_polygamma(int64_t n, const Tensor& self) { return self.polygamma(n); }

// special_psi, alias for digamma
Tensor& special_psi_out(const Tensor& self, Tensor& result) { return at::digamma_out(result, self); }
Tensor special_psi(const Tensor& self) { return self.digamma(); }
// special_digamma, alias for digamma
Tensor& special_digamma_out(const Tensor& self, Tensor& result) { return at::digamma_out(result, self); }
Tensor special_digamma(const Tensor& self) { return self.digamma(); }

// special_i0, alias for i0
Tensor& special_i0_out(const Tensor& self, Tensor& result) { return at::i0_out(result, self); }
Tensor special_i0(const Tensor& self) { return self.i0(); }

// special_log1p, alias for log1p
Tensor& special_log1p_out(const Tensor& self, Tensor& result) { return at::log1p_out(result, self); }
Tensor special_log1p(const Tensor& self) { return self.log1p(); }

// special_round, alias for round
Tensor& special_round_out(const Tensor& self, int64_t decimals, Tensor& result) { return at::round_out(result, self, decimals); }
special_round(const Tensor & self,int64_t decimals)706 Tensor special_round(const Tensor& self, int64_t decimals) { return self.round(decimals); }
707 
708 // special_sinc, alias for sinc
special_sinc_out(const Tensor & self,Tensor & result)709 Tensor& special_sinc_out(const Tensor& self, Tensor& result) { return at::sinc_out(result, self); }
special_sinc(const Tensor & self)710 Tensor special_sinc(const Tensor& self) { return self.sinc(); }
711 
712 namespace {
713 
calc_ndtr(const Tensor & self)714 inline Tensor calc_ndtr(const Tensor& self) {
715   auto x_sqrt_2 = self * M_SQRT1_2;
716   return (1 + at::erf(x_sqrt_2)) * 0.5;
717 }
718 
719 } // namespace
720 
721 // special_ndtr
special_ndtr_out(const Tensor & self,Tensor & result)722 Tensor& special_ndtr_out(const Tensor& self, Tensor& result) {
723   TORCH_CHECK(
724       self.device() == result.device(),
725       "Expected all tensors to be on the same device, but found at least two devices, ",
726       self.device(),
727       " and ",
728       result.device(),
729       "!");
730 
731   auto ndtr = calc_ndtr(self);
732   TORCH_CHECK(
733       at::can_cast(ndtr.scalar_type(), result.scalar_type()),
734       "result type ",
735       ndtr.scalar_type(),
736       " can't be cast to the desired output type ",
737       result.scalar_type());
738 
739   at::native::resize_output(result, ndtr.sizes());
740   return result.copy_(ndtr);
741 }
// Functional variant: just the standard normal CDF of `self`.
Tensor special_ndtr(const Tensor& self) {
  return calc_ndtr(self);
}
745 
// FIXME: remove const_cast once unary_op_impl_out is updated
// Structured implementation of sgn: complex inputs dispatch to the sgn
// kernel, real inputs to the sign kernel. The stubs consume `*this`
// (the structured kernel's pre-built iterator over result/self).
TORCH_IMPL_FUNC(sgn_out) (const Tensor& self, const Tensor& result) {
  if (self.is_complex()) {
    sgn_stub(device_type(), *this);
  } else {
    sign_stub(device_type(), *this);
  }
}
754 
// arccosh, alias for acosh
Tensor& arccosh_out(const Tensor& self, Tensor& result) { return at::acosh_out(result, self); }
Tensor arccosh(const Tensor& self) { return self.acosh(); }
Tensor& arccosh_(Tensor& self) { return self.acosh_(); }
759 
// arcsinh, alias for asinh
Tensor& arcsinh_out(const Tensor& self, Tensor& result) { return at::asinh_out(result, self); }
Tensor arcsinh(const Tensor& self) { return at::asinh(self); }
Tensor& arcsinh_(Tensor& self) { return at::asinh_(self); }
764 
// arctanh, alias for atanh
Tensor& arctanh_out(const Tensor& self, Tensor& result) { return at::atanh_out(result, self); }
Tensor arctanh(const Tensor& self) { return at::atanh(self); }
Tensor& arctanh_(Tensor& self) { return at::atanh_(self); }
769 
// square(x) == pow(x, 2); all three variants defer to pow.
Tensor& square_out(const Tensor& self, Tensor& result) { return at::pow_out(result, self, 2); }
Tensor square(const Tensor& self) { return self.pow(2); }
Tensor& square_(Tensor& self) { return self.pow_(2); }
773 
// logit family. An absent eps is encoded as the sentinel -1.0 before being
// handed to the kernel.
Tensor& logit_out(const Tensor& self,
    std::optional<double> eps,
    Tensor& result) {
  return unary_op_impl_float_out(
      result, self, logit_stub, Scalar(eps.value_or(-1.0)));
}
Tensor logit(const Tensor& self, std::optional<double> eps) {
  return unary_op_impl_float(self, logit_stub, Scalar(eps.value_or(-1.0)));
}
Tensor& logit_(Tensor& self, std::optional<double> eps) {
  // In-place: reuse `self` as the output of the out-variant.
  return at::logit_out(self, self, eps);
}
787 
// special_logit, alias for logit
Tensor& special_logit_out(const Tensor& self, std::optional<double> eps, Tensor& result) {
  return at::logit_out(result, self, eps);
}
Tensor special_logit(const Tensor& self, std::optional<double> eps) {
  return at::logit(self, eps);
}
794 
// special_expit, alias for sigmoid
Tensor& special_expit_out(const Tensor& self, Tensor& result) {
  return at::sigmoid_out(result, self);
}
Tensor special_expit(const Tensor& self) {
  return at::sigmoid(self);
}
802 
// Replaces NaN / +inf / -inf in `self` with the given replacement values,
// writing into `result`. std::nullopt values are forwarded unchanged to the
// kernel, which supplies its own defaults. Out dtype must match the input.
Tensor& nan_to_num_out(const Tensor& self,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf,
    Tensor& result) {
  TORCH_CHECK(
      self.scalar_type() == result.scalar_type(),
      "nan_to_num: dtype of out: ",
      result.scalar_type(),
      " should be same as input: ",
      self.scalar_type());

  // Integral (and bool) tensors cannot hold NaN/inf, so this reduces to a copy.
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    at::native::resize_output(result, self.sizes());
    result.copy_(self);
    return result;
  }

  auto iter = TensorIterator::unary_op(result, self);
  nan_to_num_stub(iter.device_type(), iter, nan, pos_inf, neg_inf);
  return result;
}
825 
// Functional variant: allocate a fresh output and defer to the out-variant.
Tensor nan_to_num(
    const Tensor& self,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf) {
  Tensor output = at::empty_like(self);
  return at::nan_to_num_out(output, self, nan, pos_inf, neg_inf);
}
834 
nan_to_num_(Tensor & self,std::optional<double> nan,std::optional<double> pos_inf,std::optional<double> neg_inf)835 Tensor& nan_to_num_(
836     Tensor& self,
837     std::optional<double> nan,
838     std::optional<double> pos_inf,
839     std::optional<double> neg_inf) {
840   return at::nan_to_num_out(self, self, nan, pos_inf, neg_inf);
841 }
842 
// Alias for trunc
Tensor& fix_out(const Tensor& self, Tensor& result) { return at::trunc_out(result, self); }
Tensor fix(const Tensor& self) { return at::trunc(self); }
Tensor& fix_(Tensor& self) { return at::trunc_(self); }
847 
// Unary `+` is the identity, except that it is rejected on bool tensors.
Tensor positive(const Tensor& self) {
  TORCH_CHECK(self.scalar_type() != kBool, "The `+` operator, on a bool tensor is not supported.");
  return self;
}
852 
// negative, alias for neg
Tensor& negative_out(const Tensor& self, Tensor& result) { return at::neg_out(result, self); }
Tensor negative(const Tensor& self) { return at::neg(self); }
Tensor& negative_(Tensor& self) { return at::neg_(self); }
856 
// Output is always bool, whatever the input dtype.
Tensor logical_not(const Tensor& self) {
  Tensor out = at::empty({0}, self.options().dtype(kBool));
  return at::logical_not_out(out, self);
}
861 
logical_not_(Tensor & self)862 Tensor& logical_not_(Tensor& self) {
863   return at::logical_not_out(self, self);
864 }
865 
// `result` may have any dtype (commonly bool), so input/output dtype
// matching is explicitly disabled on the iterator.
Tensor& logical_not_out(const Tensor& self, Tensor& result) {
  auto iter = TensorIteratorConfig()
    .add_output(result)
    .add_const_input(self)
    .check_all_same_dtype(false)
    .build();
  logical_not_stub(iter.device_type(), iter);
  return result;
}
875 
namespace {
// Step and scale constants used by the mvlgamma implementations below.
constexpr double HALF = 0.5;
constexpr double QUARTER = 0.25;
}
880 
// Shared argument validation for mvlgamma / mvlgamma_ / mvlgamma_out:
// bool inputs are rejected, and the order p must be at least 1.
static inline void mvlgamma_check(const Tensor& self, int64_t p) {
  TORCH_CHECK(self.scalar_type() != kBool, "The input tensor may not be a boolean tensor.");
  TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");
}
885 
// Multivariate log-gamma:
//   mvlgamma(x, p) = p*(p-1)/4 * log(pi) + sum_{i=0..p-1} lgamma(x - i/2)
Tensor mvlgamma(const Tensor& self, int64_t p) {
  mvlgamma_check(self, p);
  auto dtype = c10::scalarTypeToTypeMeta(self.scalar_type());
  if (at::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
    // int -> float promotion
    dtype = c10::get_default_dtype();
  }
  // Offsets -(p-1)/2, ..., -1/2, 0 in half-steps (upper bound HALF is exclusive).
  Tensor args = native::arange(
      -p * HALF + HALF,
      HALF,
      HALF,
      optTypeMetaToScalarType(dtype),
      self.options().layout_opt(),
      self.options().device_opt(),
      self.options().pinned_memory_opt());
  // Broadcast: one lgamma term per offset, summed over the trailing dim.
  args = args.add(self.unsqueeze(-1));
  const auto p2_sub_p = static_cast<double>(p * (p - 1));
  return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
}
905 
// In-place variant of mvlgamma. Unlike the functional variant there is no
// int->float promotion: the arange uses self's own dtype and the final
// result is copied back into `self`.
Tensor& mvlgamma_(Tensor& self, int64_t p) {
  mvlgamma_check(self, p);
  Tensor args = native::arange(
      -p * HALF + HALF,
      HALF,
      HALF,
      optTypeMetaToScalarType(self.options().dtype_opt()),
      self.options().layout_opt(),
      self.options().device_opt(),
      self.options().pinned_memory_opt());
  args = args.add(self.unsqueeze(-1));
  const auto p2_sub_p = static_cast<double>(p * (p - 1));
  return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
}
920 
// Out variant: compute via the functional mvlgamma, then cast-copy into
// `result` (which only needs a dtype the computed value can be cast to).
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {
  auto out = self.mvlgamma(p);
  // Bug fix: the message previously printed self's and out's dtypes, which
  // do not match the condition being checked (out -> result castability).
  TORCH_CHECK(
      at::can_cast(out.scalar_type(), result.scalar_type()),
      "mvlgamma: result type ",
      out.scalar_type(),
      " can't be cast to the desired output type ",
      result.scalar_type());
  at::native::resize_output(result, out.sizes());
  return result.copy_(out);
}
932 
// special_multigammaln, alias for mvlgamma
Tensor special_multigammaln(const Tensor& self, int64_t p) {
  return at::mvlgamma(self, p);
}
936 
// special_multigammaln out-variant, alias for mvlgamma_out
Tensor& special_multigammaln_out(const Tensor& self, int64_t p, Tensor& result) {
  return at::mvlgamma_out(result, self, p);
}
940 
frexp(const Tensor & self)941 std::tuple<Tensor, Tensor> frexp(const Tensor& self) {
942   Tensor mantissa = at::empty_like(self);
943   Tensor exponent = at::empty_like(self, self.options().dtype(at::kInt));
944 
945   at::frexp_out(mantissa, exponent, self);
946   return std::tuple<Tensor, Tensor>(mantissa, exponent);
947 }
948 
// Decomposes `self` as mantissa * 2^exponent (C-library frexp semantics),
// writing into the caller-provided outputs.
std::tuple<Tensor&, Tensor&> frexp_out(const Tensor& self,
                                       Tensor& mantissa, Tensor& exponent) {
  // torch.frexp is implemented for floating-point dtypes for now,
  // should add support for integral dtypes in the future.
  TORCH_CHECK(at::isFloatingType(self.scalar_type()),
              "torch.frexp() only supports floating-point dtypes");

  // Output dtypes are fixed: mantissa matches the input, exponent is int32.
  TORCH_CHECK(mantissa.dtype() == self.dtype(),
              "torch.frexp() expects mantissa to have dtype ", self.dtype(),
              " but got ", mantissa.dtype());
  TORCH_CHECK(exponent.dtype() == at::kInt,
              "torch.frexp() expects exponent to have int dtype "
              "but got ", exponent.dtype());

  // The two outputs deliberately differ in dtype, so same-dtype checking
  // is disabled on the iterator; overlap checking stays on.
  auto iter = TensorIteratorConfig()
    .add_output(mantissa)
    .add_output(exponent)
    .add_const_input(self)
    .check_all_same_dtype(false)
    .set_check_mem_overlap(true)
    .build();
  frexp_stub(iter.device_type(), iter);

  return std::tuple<Tensor&, Tensor&>(mantissa, exponent);
}
974 
// alias for lgamma, implements special.gammaln equivalent to
// scipy.special.gammaln
Tensor special_gammaln(const Tensor& self) { return at::lgamma(self); }
Tensor& special_gammaln_out(const Tensor& self, Tensor& result) { return at::lgamma_out(result, self); }
979 
// Dispatch-table definitions for the unary kernels used in this file.
// Per-backend implementations are registered elsewhere via REGISTER_DISPATCH.
DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(conj_physical_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(atanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(asin_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(atan_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(bitwise_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(ceil_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(cos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(cosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(digamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_entr_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_erfcx_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(erf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(erfc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(erfinv_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(exp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(exp2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(expm1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(floor_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(frac_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(frexp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_i0e_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_i1e_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log10_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logical_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_ndtri_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_log_ndtr_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(neg_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(nan_to_num_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(polygamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(reciprocal_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(round_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(round_decimals_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(rsqrt_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sigmoid_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logit_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sign_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(signbit_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sgn_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sin_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sinc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sqrt_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(tan_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_airy_ai_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_scaled_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_scaled_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_spherical_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
1049 
1050 } // namespace at::native
1051