1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/MemoryOverlap.h>
5 #include <ATen/NamedTensorUtils.h>
6 #include <ATen/Parallel.h>
7 #include <ATen/ScalarOps.h>
8 #include <ATen/TensorIterator.h>
9 #include <ATen/TensorOperators.h>
10 #include <ATen/WrapDimUtils.h>
11
12 #include <ATen/native/Resize.h>
13 #include <ATen/native/UnaryOps.h>
14 #include <ATen/native/ComplexHelper.h>
15
16 #include <c10/util/MathConstants.h>
17
18 #ifndef AT_PER_OPERATOR_HEADERS
19 #include <ATen/Functions.h>
20 #include <ATen/NativeFunctions.h>
21 #else
22 #include <ATen/ops/_conj_native.h>
23 #include <ATen/ops/_conj_physical.h>
24 #include <ATen/ops/_conj_physical_native.h>
25 #include <ATen/ops/_neg_view_native.h>
26 #include <ATen/ops/abs.h>
27 #include <ATen/ops/abs_native.h>
28 #include <ATen/ops/absolute_native.h>
29 #include <ATen/ops/acos.h>
30 #include <ATen/ops/acos_native.h>
31 #include <ATen/ops/acosh.h>
32 #include <ATen/ops/acosh_native.h>
33 #include <ATen/ops/angle.h>
34 #include <ATen/ops/angle_native.h>
35 #include <ATen/ops/arange_native.h>
36 #include <ATen/ops/arccos_native.h>
37 #include <ATen/ops/arccosh_native.h>
38 #include <ATen/ops/arcsin_native.h>
39 #include <ATen/ops/arcsinh_native.h>
40 #include <ATen/ops/arctan_native.h>
41 #include <ATen/ops/arctanh_native.h>
42 #include <ATen/ops/asin.h>
43 #include <ATen/ops/asin_native.h>
44 #include <ATen/ops/asinh.h>
45 #include <ATen/ops/asinh_native.h>
46 #include <ATen/ops/atan.h>
47 #include <ATen/ops/atan_native.h>
48 #include <ATen/ops/atanh.h>
49 #include <ATen/ops/atanh_native.h>
50 #include <ATen/ops/bitwise_not_native.h>
51 #include <ATen/ops/can_cast.h>
52 #include <ATen/ops/ceil_native.h>
53 #include <ATen/ops/conj_native.h>
54 #include <ATen/ops/conj_physical.h>
55 #include <ATen/ops/conj_physical_native.h>
56 #include <ATen/ops/cos_native.h>
57 #include <ATen/ops/cosh_native.h>
58 #include <ATen/ops/deg2rad.h>
59 #include <ATen/ops/deg2rad_native.h>
60 #include <ATen/ops/digamma.h>
61 #include <ATen/ops/digamma_native.h>
62 #include <ATen/ops/empty.h>
63 #include <ATen/ops/empty_like.h>
64 #include <ATen/ops/erf.h>
65 #include <ATen/ops/erf_native.h>
66 #include <ATen/ops/erfc.h>
67 #include <ATen/ops/erfc_native.h>
68 #include <ATen/ops/erfinv.h>
69 #include <ATen/ops/erfinv_native.h>
70 #include <ATen/ops/exp2.h>
71 #include <ATen/ops/exp2_native.h>
72 #include <ATen/ops/exp_native.h>
73 #include <ATen/ops/expm1.h>
74 #include <ATen/ops/expm1_native.h>
75 #include <ATen/ops/fix_native.h>
76 #include <ATen/ops/floor_native.h>
77 #include <ATen/ops/frac_native.h>
78 #include <ATen/ops/frexp.h>
79 #include <ATen/ops/frexp_native.h>
80 #include <ATen/ops/i0.h>
81 #include <ATen/ops/i0_native.h>
82 #include <ATen/ops/imag_native.h>
83 #include <ATen/ops/lgamma.h>
84 #include <ATen/ops/lgamma_native.h>
85 #include <ATen/ops/log10_native.h>
86 #include <ATen/ops/log1p.h>
87 #include <ATen/ops/log1p_native.h>
88 #include <ATen/ops/log2_native.h>
89 #include <ATen/ops/log_native.h>
90 #include <ATen/ops/logical_not.h>
91 #include <ATen/ops/logical_not_native.h>
92 #include <ATen/ops/logit.h>
93 #include <ATen/ops/logit_native.h>
94 #include <ATen/ops/mul.h>
95 #include <ATen/ops/mvlgamma.h>
96 #include <ATen/ops/mvlgamma_native.h>
97 #include <ATen/ops/nan_to_num.h>
98 #include <ATen/ops/nan_to_num_native.h>
99 #include <ATen/ops/neg.h>
100 #include <ATen/ops/neg_native.h>
101 #include <ATen/ops/negative_native.h>
102 #include <ATen/ops/polygamma.h>
103 #include <ATen/ops/polygamma_native.h>
104 #include <ATen/ops/positive_native.h>
105 #include <ATen/ops/pow.h>
106 #include <ATen/ops/rad2deg.h>
107 #include <ATen/ops/rad2deg_native.h>
108 #include <ATen/ops/real.h>
109 #include <ATen/ops/real_native.h>
110 #include <ATen/ops/reciprocal_native.h>
111 #include <ATen/ops/resolve_conj_native.h>
112 #include <ATen/ops/resolve_neg_native.h>
113 #include <ATen/ops/round.h>
114 #include <ATen/ops/round_native.h>
115 #include <ATen/ops/rsqrt_native.h>
116 #include <ATen/ops/select.h>
117 #include <ATen/ops/sgn_native.h>
118 #include <ATen/ops/sigmoid.h>
119 #include <ATen/ops/sigmoid_native.h>
120 #include <ATen/ops/sign_native.h>
121 #include <ATen/ops/signbit_native.h>
122 #include <ATen/ops/sin_native.h>
123 #include <ATen/ops/sinc.h>
124 #include <ATen/ops/sinc_native.h>
125 #include <ATen/ops/sinh_native.h>
126 #include <ATen/ops/special_airy_ai_native.h>
127 #include <ATen/ops/special_bessel_j0_native.h>
128 #include <ATen/ops/special_bessel_j1_native.h>
129 #include <ATen/ops/special_bessel_y0_native.h>
130 #include <ATen/ops/special_bessel_y1_native.h>
131 #include <ATen/ops/special_digamma_native.h>
132 #include <ATen/ops/special_entr_native.h>
133 #include <ATen/ops/special_erf_native.h>
134 #include <ATen/ops/special_erfc_native.h>
135 #include <ATen/ops/special_erfcx_native.h>
136 #include <ATen/ops/special_erfinv_native.h>
137 #include <ATen/ops/special_exp2_native.h>
138 #include <ATen/ops/special_expit_native.h>
139 #include <ATen/ops/special_expm1_native.h>
140 #include <ATen/ops/special_gammaln_native.h>
141 #include <ATen/ops/special_i0_native.h>
142 #include <ATen/ops/special_i0e_native.h>
143 #include <ATen/ops/special_i1_native.h>
144 #include <ATen/ops/special_i1e_native.h>
145 #include <ATen/ops/special_log1p_native.h>
146 #include <ATen/ops/special_log_ndtr_native.h>
147 #include <ATen/ops/special_logit_native.h>
148 #include <ATen/ops/special_modified_bessel_i0_native.h>
149 #include <ATen/ops/special_modified_bessel_i1_native.h>
150 #include <ATen/ops/special_modified_bessel_k0_native.h>
151 #include <ATen/ops/special_modified_bessel_k1_native.h>
152 #include <ATen/ops/special_multigammaln_native.h>
153 #include <ATen/ops/special_ndtr_native.h>
154 #include <ATen/ops/special_ndtri_native.h>
155 #include <ATen/ops/special_polygamma_native.h>
156 #include <ATen/ops/special_psi_native.h>
157 #include <ATen/ops/special_round_native.h>
158 #include <ATen/ops/special_scaled_modified_bessel_k0_native.h>
159 #include <ATen/ops/special_scaled_modified_bessel_k1_native.h>
160 #include <ATen/ops/special_sinc_native.h>
161 #include <ATen/ops/special_spherical_bessel_j0_native.h>
162 #include <ATen/ops/sqrt_native.h>
163 #include <ATen/ops/square_native.h>
164 #include <ATen/ops/tan_native.h>
165 #include <ATen/ops/tanh_native.h>
166 #include <ATen/ops/trunc.h>
167 #include <ATen/ops/trunc_native.h>
168 #include <ATen/ops/view_as_real.h>
169 #endif
170
171 #include <cmath>
172
173 namespace at::meta {
174
175 // Unary float operations always produce floating point
176 // outputs for floating point and integral types
177 // For complex inputs, the output type should be the same as input type.
// Unary float operations always produce floating point
// outputs for floating point and integral types
// For complex inputs, the output type should be the same as input type.
#define CREATE_UNARY_FLOAT_META_FUNC(func)                      \
  TORCH_META_FUNC(func) (const Tensor& self) {                  \
    build_borrowing_unary_float_op(maybe_get_output(), self);   \
  }

CREATE_UNARY_FLOAT_META_FUNC(acos)
CREATE_UNARY_FLOAT_META_FUNC(acosh)
CREATE_UNARY_FLOAT_META_FUNC(asin)
CREATE_UNARY_FLOAT_META_FUNC(asinh)
CREATE_UNARY_FLOAT_META_FUNC(atan)
CREATE_UNARY_FLOAT_META_FUNC(atanh)
CREATE_UNARY_FLOAT_META_FUNC(cos)
CREATE_UNARY_FLOAT_META_FUNC(cosh)
CREATE_UNARY_FLOAT_META_FUNC(digamma)
CREATE_UNARY_FLOAT_META_FUNC(erf)
CREATE_UNARY_FLOAT_META_FUNC(erfc)
CREATE_UNARY_FLOAT_META_FUNC(erfinv)
CREATE_UNARY_FLOAT_META_FUNC(exp)
CREATE_UNARY_FLOAT_META_FUNC(exp2)
CREATE_UNARY_FLOAT_META_FUNC(expm1)
CREATE_UNARY_FLOAT_META_FUNC(i0)
CREATE_UNARY_FLOAT_META_FUNC(lgamma)
CREATE_UNARY_FLOAT_META_FUNC(log)
CREATE_UNARY_FLOAT_META_FUNC(log10)
CREATE_UNARY_FLOAT_META_FUNC(log1p)
CREATE_UNARY_FLOAT_META_FUNC(log2)
CREATE_UNARY_FLOAT_META_FUNC(reciprocal)
CREATE_UNARY_FLOAT_META_FUNC(rsqrt)
CREATE_UNARY_FLOAT_META_FUNC(sigmoid)
CREATE_UNARY_FLOAT_META_FUNC(sin)
CREATE_UNARY_FLOAT_META_FUNC(sinc)
CREATE_UNARY_FLOAT_META_FUNC(sinh)
CREATE_UNARY_FLOAT_META_FUNC(special_entr)
CREATE_UNARY_FLOAT_META_FUNC(special_erfcx)
CREATE_UNARY_FLOAT_META_FUNC(special_i0e)
CREATE_UNARY_FLOAT_META_FUNC(special_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_i1e)
CREATE_UNARY_FLOAT_META_FUNC(special_ndtri)
CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr)
CREATE_UNARY_FLOAT_META_FUNC(sqrt)
CREATE_UNARY_FLOAT_META_FUNC(tan)
CREATE_UNARY_FLOAT_META_FUNC(tanh)
CREATE_UNARY_FLOAT_META_FUNC(special_airy_ai)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k1)
CREATE_UNARY_FLOAT_META_FUNC(special_scaled_modified_bessel_k0)
CREATE_UNARY_FLOAT_META_FUNC(special_scaled_modified_bessel_k1)
CREATE_UNARY_FLOAT_META_FUNC(special_spherical_bessel_j0)

// polygamma takes an extra order argument `n` in addition to the input
// tensor, so it cannot use the macro above; it is still a unary float op
// on `self`.
TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
  TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
  build_borrowing_unary_float_op(maybe_get_output(), self);
}
237
238 // These are normal unary ops that preserve dtype
// These are normal unary ops that preserve dtype
#define CREATE_UNARY_META_FUNC(func)                      \
  TORCH_META_FUNC(func) (const Tensor& self) {            \
    build_borrowing_unary_op(maybe_get_output(), self);   \
  }
CREATE_UNARY_META_FUNC(bitwise_not)
CREATE_UNARY_META_FUNC(frac)
CREATE_UNARY_META_FUNC(round)
CREATE_UNARY_META_FUNC(sgn)

// round.decimals carries an extra `decimals` argument, so it gets its own
// meta function; note it uses the non-borrowing build variant.
TORCH_META_FUNC2(round, decimals)(const Tensor& self, int64_t decimals){
  build_unary_op(maybe_get_output(), self);
}
251
// Meta functions below add op-specific input validation before building the
// (dtype-preserving) unary TensorIterator.

TORCH_META_FUNC(neg)(const Tensor& self) {
  TORCH_CHECK(self.scalar_type() != kBool,
              "Negation, the `-` operator, on a bool tensor is not supported. "
              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(trunc) (const Tensor& self) {
  // Note: this is consistent with NumPy
  TORCH_CHECK(!self.is_complex(),
              "trunc is not supported for complex inputs");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(floor) (const Tensor& self) {
  // Note: this is consistent with NumPy
  TORCH_CHECK(!self.is_complex(),
              "floor is not supported for complex inputs");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(sign) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(),
              "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
  build_borrowing_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(signbit) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), "signbit is not implemented for complex tensors.");
  // signbit always produces a boolean tensor; reject an explicitly supplied
  // `out` tensor of any other dtype.
  TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
              "signbit does not support non-boolean outputs.");
  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}

TORCH_META_FUNC(ceil) (const Tensor& self) {
  // Note: this is consistent with NumPy
  TORCH_CHECK(!self.is_complex(),
              "ceil is not supported for complex inputs");
  build_borrowing_unary_op(maybe_get_output(), self);
}
292
293 } // namespace at::meta
294
295 namespace at::native {
296 // NOTE: These are helper functions that reduce redundant code in implementing the most typical kind of unary operators.
297 // YOU ARE NOT OBLIGED TO USE THESE HELPERS---if you're writing something more specialized, please don't try to make
298 // them work for your case, but just write something new instead. Here we use helper functions instead of a flat fat
299 // macro that implements everything, because the former allows some simple preprocessing that are unique to some
300 // operators (more is foreseeable) and is more flexible and elegant than the latter.
// Typical unary impl: the structured-kernel framework has already set up the
// TensorIterator (`*this`), so the impl only dispatches to the device stub.
#define CREATE_UNARY_TORCH_IMPL_FUNC(func_out, func_stub)                 \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& result) {    \
  func_stub(device_type(), *this);                                        \
}

// This macro is as optional as the one above. torch.(ceil|floor|round|trunc) are no-ops for integers
// See gh-70918
#define CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(func_out, func_stub)   \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& result) {    \
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/false)) {   \
    result.copy_(self);                                                   \
  } else {                                                                \
    func_stub(device_type(), *this);                                      \
  }                                                                       \
}
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(ceil_out, ceil_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(floor_out, floor_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(round_out, round_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(trunc_out, trunc_stub)

CREATE_UNARY_TORCH_IMPL_FUNC(acos_out, acos_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(acosh_out, acosh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(asin_out, asin_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(asinh_out, asinh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(atan_out, atan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(atanh_out, atanh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(bitwise_not_out, bitwise_not_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(cos_out, cos_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(cosh_out, cosh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(digamma_out, digamma_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(erf_out, erf_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(erfc_out, erfc_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(erfinv_out, erfinv_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(exp_out, exp_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(exp2_out, exp2_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(expm1_out, expm1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(frac_out, frac_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(i0_out, i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(lgamma_out, lgamma_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log_out, log_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log10_out, log10_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log1p_out, log1p_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log2_out, log2_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(neg_out, neg_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(reciprocal_out, reciprocal_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(rsqrt_out, rsqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sigmoid_out, sigmoid_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sign_out, sign_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sin_out, sin_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sinc_out, sinc_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sinh_out, sinh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_entr_out, special_entr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_erfcx_out, special_erfcx_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i0e_out, special_i0e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1e_out, special_i1e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1_out, special_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_ndtri_out, special_ndtri_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_log_ndtr_out, special_log_ndtr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_airy_ai_out, special_airy_ai_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y1_out, special_bessel_y1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_bessel_i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k1_out, special_modified_bessel_k1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_scaled_modified_bessel_k0_out, special_scaled_modified_bessel_k0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_scaled_modified_bessel_k1_out, special_scaled_modified_bessel_k1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_spherical_bessel_j0_out, special_spherical_bessel_j0_stub)
374
// round.decimals: only use the (more expensive) decimals kernel when actually
// rounding to a non-zero number of decimal places; decimals == 0 is identical
// to plain round.
TORCH_IMPL_FUNC(round_decimals_out)
(const Tensor& self, int64_t decimals, const Tensor& result) {
  if (decimals != 0) {
    round_decimals_stub(device_type(), *this, decimals);
  } else {
    round_stub(device_type(), *this);
  }
}

// polygamma forwards the extra order argument `n` to the device stub.
TORCH_IMPL_FUNC(polygamma_out)
(int64_t n, const Tensor& self, const Tensor& result) {
  polygamma_stub(device_type(), *this, n);
}

// Bool inputs trivially have no sign bit set, so the result is all-false
// without dispatching to the kernel.
TORCH_IMPL_FUNC(signbit_out) (const Tensor& self, const Tensor& result) {
  if (self.dtype() == at::kBool) {
    result.fill_(false);
  } else {
    signbit_stub(device_type(), *this);
  }
}
396
397 // since polygamma_ has different signature from its
398 // out and functional variant, we explicitly
399 // define it (instead of using structured kernel).
polygamma_(Tensor & self,int64_t n)400 Tensor& polygamma_(Tensor& self, int64_t n) {
401 return at::polygamma_out(self, n, self);
402 }
403
404 template <typename Stub>
unary_op_impl_out(Tensor & result,const Tensor & self,Stub & stub)405 static inline Tensor& unary_op_impl_out(Tensor& result, const Tensor& self, Stub& stub) {
406 auto iter = TensorIterator::unary_op(result, self);
407 stub(iter.device_type(), iter);
408 return result;
409 }
410
411 template <typename Stub, typename ...Args>
unary_op_impl_float_out(Tensor & result,const Tensor & self,Stub & stub,Args...args)412 static inline Tensor& unary_op_impl_float_out(Tensor& result, const Tensor& self, Stub& stub, Args... args) {
413 auto iter = TensorIterator::unary_float_op(result, self);
414 stub(iter.device_type(), iter, args...);
415 return result;
416 }
417
418 template <typename Stub, typename ...Args>
unary_op_impl_float(const Tensor & self,Stub & stub,Args...args)419 static inline Tensor unary_op_impl_float(const Tensor& self, Stub& stub, Args... args) {
420 Tensor result;
421 auto iter = TensorIterator::unary_float_op(result, self);
422 stub(iter.device_type(), iter, args...);
423 return iter.output();
424 }
425
// An alternate version of unary_op_impl_out that follows the same pattern
// for non-complex inputs, but returns a floating point tensor
// for complex inputs by default.
// Note: This is done by running the operation as usual and then copying the
// operation's result to the expected result type.
//
// promotes_integer_to_float: when true, non-complex paths go through the
// float-promoting iterator (integral inputs produce floating point results);
// when false, the dtype-preserving iterator is used.
template <typename Stub>
static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub, bool promotes_integer_to_float) {
  if (self.is_complex() && !result.is_complex()) {
    // Checks if the corresponding float type can be cast to the desired dtype
    const auto float_type = c10::toRealValueType(self.scalar_type());
    TORCH_CHECK(canCast(float_type, result.scalar_type()),
          "result type ", float_type, " can't be cast to the desired output type ",
          result.scalar_type());

    // Runs the function complex->complex, as TensorIterator expects
    Tensor complex_result = at::empty({0}, self.options());
    auto iter = TensorIterator::unary_op(complex_result, self);
    stub(iter.device_type(), iter);

    // Copies the complex result to the actual result and returns it
    at::native::resize_output(result, complex_result.sizes());
    result.copy_(at::real(complex_result));
    return result;
  }

  if (promotes_integer_to_float) {
    return unary_op_impl_float_out(result, self, stub);
  }

  return unary_op_impl_out(result, self, stub);
}
457
458 // out_impl passed into unary_op_impl and unary_op_impl_ must go through at:: device dispatch
459 // otherwise it won't dispatch to out-of-source devices like XLA.
460 // For example it must be at::bitwise_not_out instead of bitwise_not_out(which is at::native!).
461 template <typename OutImpl>
unary_op_impl(const Tensor & self,OutImpl & out_impl)462 static inline Tensor unary_op_impl(const Tensor& self, OutImpl& out_impl) {
463 Tensor result = at::empty({0}, self.options());
464 return out_impl(result, self);
465 }
466
467 // An alternate version of unary_op_impl that follows the same pattern
468 // for non-complex inputs, but returns a floating point tensor
469 // for complex inputs by default.
470 template <typename OutImpl>
unary_op_impl_with_complex_to_float(const Tensor & self,OutImpl & out_impl)471 static inline Tensor unary_op_impl_with_complex_to_float(const Tensor& self, OutImpl& out_impl) {
472 if (self.is_complex()) {
473 const auto float_type = c10::toRealValueType(self.scalar_type());
474 Tensor result = at::empty_like(self, self.options().dtype(float_type));
475 return out_impl(result, self);
476 }
477
478 Tensor result = at::empty({0}, self.options());
479 return out_impl(result, self);
480 }
481
// In-place helper: runs the op's out-variant with `self` as both input and
// output.
template <typename OutImpl>
static inline Tensor& unary_op_impl_(Tensor& self, OutImpl& out_impl) {
  return out_impl(self, self);
}
486
// arccos, alias for acos (NumPy-compatible name); all three variants simply
// forward to the corresponding acos variant.
Tensor& arccos_out(const Tensor& self, Tensor& result) { return at::acos_out(result, self); }
Tensor arccos(const Tensor& self) { return self.acos(); }
Tensor& arccos_(Tensor& self) { return self.acos_(); }
491
// rad2deg: radians -> degrees, implemented as multiplication by 180/pi.
Tensor& rad2deg_out(const Tensor& self, Tensor& result) {
  TORCH_CHECK(!self.is_complex(), "rad2deg is not supported for complex tensors.");
  // 180 / pi to full double precision.
  constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564;
  return at::mul_out(result, self, wrapped_scalar_tensor(Scalar(M_180_PI)));
}
Tensor rad2deg(const Tensor& self) {
  // Note: int-> float promotion handled differently from other Unary ops,
  // as it does not use the usual TensorIterator + Kernel Dispatch pattern.
  auto options = self.options();
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    options = options.dtype(c10::get_default_dtype());
  }
  auto result = at::empty_like(self, options);
  at::rad2deg_out(result, self);
  return result;
}
Tensor& rad2deg_(Tensor& self) { return unary_op_impl_(self, at::rad2deg_out); }
509
// deg2rad: degrees -> radians, implemented as multiplication by pi/180.
Tensor& deg2rad_out(const Tensor& self, Tensor& result) {
  TORCH_CHECK(!self.is_complex(), "deg2rad is not supported for complex tensors.");
  // pi / 180 to full double precision.
  constexpr double M_PI_180 = 0.017453292519943295769236907684886127134428718885417;
  return at::mul_out(result, self, wrapped_scalar_tensor(Scalar(M_PI_180)));
}
Tensor deg2rad(const Tensor& self) {
  // Note: int-> float promotion handled differently from other Unary ops,
  // as it does not use the usual TensorIterator + Kernel Dispatch pattern.
  auto options = self.options();
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    options = options.dtype(c10::get_default_dtype());
  }
  auto result = at::empty_like(self, options);
  at::deg2rad_out(result, self);
  return result;
}
Tensor& deg2rad_(Tensor& self) { return unary_op_impl_(self, at::deg2rad_out); }
527
// NumPy-compatible aliases; each variant forwards to the base op.

// arcsin, alias of asin
Tensor& arcsin_out(const Tensor& self, Tensor& result) { return at::asin_out(result, self); }
Tensor arcsin(const Tensor& self) { return self.asin(); }
Tensor& arcsin_(Tensor& self) { return self.asin_(); }

// arctan, alias of atan
Tensor& arctan_out(const Tensor& self, Tensor& result) { return at::atan_out(result, self); }
Tensor arctan(const Tensor& self) { return self.atan(); }
Tensor& arctan_(Tensor& self) { return self.atan_(); }
537
// Note [Complex abs and angle]
// Complex inputs to abs and angle return float results by default.
// abs and angle, in both NumPy and C++, returns a float result when given a
// complex input. This makes sense mathematically since the absolute value
// and angle of a complex number has no imaginary part.
Tensor& abs_out(const Tensor& self, Tensor& result) {
  return unary_op_impl_with_complex_to_float_out(result, self, abs_stub, /*promotes_integer_to_float=*/false);
}
Tensor abs(const Tensor& self) {
  return unary_op_impl_with_complex_to_float(self, at::abs_out);
}
Tensor& abs_(Tensor& self) {
  // Complex abs returns a float result (see note above), which cannot be
  // stored back into a complex tensor in place.
  TORCH_CHECK(!self.is_complex(), "In-place abs is not supported for complex tensors.");
  return unary_op_impl_(self, at::abs_out);
}
553
// Absolute, alias for abs; all variants forward to abs.
Tensor& absolute_out(const Tensor& self, Tensor& result) {
  return at::abs_out(result, self);
}
Tensor absolute(const Tensor& self) {
  return self.abs();
}
Tensor& absolute_(Tensor& self) {
  return self.abs_();
}
564
// See Note [Complex abs and angle] above: complex inputs produce a
// real-valued result of the corresponding float type; integral inputs are
// promoted to float.
Tensor& angle_out(const Tensor& self, Tensor& result) {
  return unary_op_impl_with_complex_to_float_out(result, self, angle_stub, /*promotes_integer_to_float=*/true);
}
Tensor angle(const Tensor& self) {
  if (self.is_complex()) {
    const auto float_type = c10::toRealValueType(self.scalar_type());
    Tensor result = at::empty({0}, self.options().dtype(float_type));
    return at::angle_out(result, self);
  }

  return unary_op_impl_float(self, angle_stub);
}
577
// torch.real: for complex inputs, returns a view aliasing the real
// components; non-complex inputs are returned unchanged.
Tensor real(const Tensor& self) {
  if (!self.is_complex()) {
    return self;
  }
  // Strip a lazy conjugate bit before viewing as real — conjugation does not
  // affect the real part.
  auto real_view = self.is_conj() ? at::view_as_real(self._conj())
                                  : at::view_as_real(self);
  // view_as_real appends a trailing dimension of (real, imag) pairs;
  // index 0 selects the real component.
  return at::select(real_view, real_view.dim() - 1, 0);
}
591
// Returns an aliasing view of `self` with the negative bit toggled; no data
// is copied — the negation is materialized lazily (see resolve_neg).
Tensor _neg_view(const Tensor& self) {
  Tensor self_ = self.alias();
  self_._set_neg(!self.is_neg());
  namedinference::propagate_names(self_, self);
  return self_;
}
598
// torch.imag: returns a view aliasing the imaginary components of a complex
// tensor; errors for non-complex dtypes.
Tensor imag(const Tensor& self) {
  TORCH_CHECK(self.is_complex(), "imag is not implemented for tensors with non-complex dtypes.");
  Tensor real_view;
  if (self.is_conj()) {
    // Conjugation negates the imaginary part, so view the un-conjugated data
    // and preemptively set the negative flag on the resulting view.
    real_view = at::view_as_real(self._conj())._neg_view();
  } else {
    real_view = at::view_as_real(self);
  }
  // Index 1 of the trailing (real, imag) dimension is the imaginary part.
  return at::select(real_view, real_view.dim() - 1, 1);
}
614
// Physically (eagerly) conjugates, as opposed to the lazy bit-flip in _conj.
Tensor& conj_physical_out(const Tensor& self, Tensor& result) {
  return unary_op_impl_out(result, self, conj_physical_stub);
}

Tensor _conj_physical(const Tensor& self) {
  if (self.is_conj()) {
    // Input already carries the lazy conjugate bit; clone the conj view to
    // materialize it instead of running the kernel.
    return self.conj().clone();
  }
  return unary_op_impl(self, at::conj_physical_out);
}

Tensor conj_physical(const Tensor& self) {
  // Conjugation is the identity for non-complex tensors.
  if (!self.is_complex()) return self;
  return at::_conj_physical(self);
}

Tensor& conj_physical_(Tensor& self) {
  if (!self.is_complex()) return self;
  // In-place: run the kernel with self as both input and output.
  return unary_op_impl_out(self, self, conj_physical_stub);
}
635
// No op if the neg bit is not set
// else returns a new negated tensor with neg bit set to 0
Tensor resolve_neg(const Tensor& self) {
  if (!self.is_neg()) { return self; }
  // negation is materialized in `copy_()` that clone ultimately calls into
  return self.clone();
}

// No op if the conj bit is not set
// else returns a new conjugated tensor with conj bit set to 0
Tensor resolve_conj(const Tensor& self) {
  if (!self.is_conj()) { return self; }
  // conjugation is materialized in `copy_()` that clone ultimately calls into
  return self.clone();
}

// Returns an aliasing view of `self` with the conjugate bit toggled; the
// conjugation itself is lazy (see resolve_conj).
Tensor _conj(const Tensor& self) {
  Tensor self_ = self.alias();
  self_._set_conj(!self.is_conj());
  namedinference::propagate_names(self_, self);
  return self_;
}
658
Tensor conj(const Tensor& self) {
  // This might look like an infinite recursion but it's not.
  // This actually calls into `conj()` defined in the Tensor class.
  return self.conj();
}
664
665 // special_exp2, alias for exp2
special_exp2_out(const Tensor & self,Tensor & result)666 Tensor& special_exp2_out(const Tensor& self, Tensor& result) { return at::exp2_out(result, self); }
special_exp2(const Tensor & self)667 Tensor special_exp2(const Tensor& self) { return self.exp2(); }
668
669 // special_expm1, alias for expm1
special_expm1_out(const Tensor & self,Tensor & result)670 Tensor& special_expm1_out(const Tensor& self, Tensor& result) { return at::expm1_out(result, self); }
special_expm1(const Tensor & self)671 Tensor special_expm1(const Tensor& self) { return self.expm1(); }
672
673 // special_erf, alias for erf
special_erf_out(const Tensor & self,Tensor & result)674 Tensor& special_erf_out(const Tensor& self, Tensor& result) { return at::erf_out(result, self); }
special_erf(const Tensor & self)675 Tensor special_erf(const Tensor& self) { return self.erf(); }
676
677 // special_erfc, alias for erfc
special_erfc_out(const Tensor & self,Tensor & result)678 Tensor& special_erfc_out(const Tensor& self, Tensor& result) { return at::erfc_out(result, self); }
special_erfc(const Tensor & self)679 Tensor special_erfc(const Tensor& self) { return self.erfc(); }
680
681 // special_erfinv, alias for erfinv
special_erfinv_out(const Tensor & self,Tensor & result)682 Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
special_erfinv(const Tensor & self)683 Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }
684
685 // special_polygamma, alias for polygamma
special_polygamma_out(int64_t n,const Tensor & self,Tensor & result)686 Tensor& special_polygamma_out(int64_t n, const Tensor& self, Tensor& result) { return at::polygamma_out(result, n, self); }
special_polygamma(int64_t n,const Tensor & self)687 Tensor special_polygamma(int64_t n, const Tensor& self) { return self.polygamma(n); }
688
// special_psi, alias for digamma
Tensor& special_psi_out(const Tensor& self, Tensor& result) {
  return at::digamma_out(result, self);
}
Tensor special_psi(const Tensor& self) {
  return at::digamma(self);
}
// special_digamma, alias for digamma
Tensor& special_digamma_out(const Tensor& self, Tensor& result) {
  return at::digamma_out(result, self);
}
Tensor special_digamma(const Tensor& self) {
  return at::digamma(self);
}
695
// special_i0, alias for i0
Tensor& special_i0_out(const Tensor& self, Tensor& result) {
  return at::i0_out(result, self);
}
Tensor special_i0(const Tensor& self) {
  return at::i0(self);
}
699
// special_log1p, alias for log1p
Tensor& special_log1p_out(const Tensor& self, Tensor& result) {
  return at::log1p_out(result, self);
}
Tensor special_log1p(const Tensor& self) {
  return at::log1p(self);
}
703
// special_round, alias for round (with a decimals argument)
Tensor& special_round_out(const Tensor& self, int64_t decimals, Tensor& result) {
  return at::round_out(result, self, decimals);
}
Tensor special_round(const Tensor& self, int64_t decimals) {
  return at::round(self, decimals);
}
707
// special_sinc, alias for sinc
Tensor& special_sinc_out(const Tensor& self, Tensor& result) {
  return at::sinc_out(result, self);
}
Tensor special_sinc(const Tensor& self) {
  return at::sinc(self);
}
711
712 namespace {
713
calc_ndtr(const Tensor & self)714 inline Tensor calc_ndtr(const Tensor& self) {
715 auto x_sqrt_2 = self * M_SQRT1_2;
716 return (1 + at::erf(x_sqrt_2)) * 0.5;
717 }
718
719 } // namespace
720
721 // special_ndtr
special_ndtr_out(const Tensor & self,Tensor & result)722 Tensor& special_ndtr_out(const Tensor& self, Tensor& result) {
723 TORCH_CHECK(
724 self.device() == result.device(),
725 "Expected all tensors to be on the same device, but found at least two devices, ",
726 self.device(),
727 " and ",
728 result.device(),
729 "!");
730
731 auto ndtr = calc_ndtr(self);
732 TORCH_CHECK(
733 at::can_cast(ndtr.scalar_type(), result.scalar_type()),
734 "result type ",
735 ndtr.scalar_type(),
736 " can't be cast to the desired output type ",
737 result.scalar_type());
738
739 at::native::resize_output(result, ndtr.sizes());
740 return result.copy_(ndtr);
741 }
// Functional variant: returns a new tensor with the standard normal CDF of `self`.
Tensor special_ndtr(const Tensor& self) { return calc_ndtr(self); }
745
// FIXME: remove const_cast once unary_op_impl_out is updated
// Structured-kernel implementation of sgn: complex inputs use the true sign
// (z / |z|) via sgn_stub; real inputs use the -1/0/+1 sign via sign_stub.
TORCH_IMPL_FUNC(sgn_out) (const Tensor& self, const Tensor& result) {
  if (self.is_complex()) {
    sgn_stub(device_type(), *this);
  } else {
    sign_stub(device_type(), *this);
  }
}
754
// arccosh, alias for acosh
// Uses Tensor method calls for consistency with the sibling arcsinh/arctanh
// aliases below (previously mixed at::acosh free-function calls).
Tensor& arccosh_out(const Tensor& self, Tensor& result) { return at::acosh_out(result, self); }
Tensor arccosh(const Tensor& self) { return self.acosh(); }
Tensor& arccosh_(Tensor& self) { return self.acosh_(); }
759
// arcsinh, alias for asinh
Tensor& arcsinh_out(const Tensor& self, Tensor& result) {
  return at::asinh_out(result, self);
}
Tensor arcsinh(const Tensor& self) {
  return self.asinh();
}
Tensor& arcsinh_(Tensor& self) {
  return self.asinh_();
}
764
// arctanh, alias for atanh
Tensor& arctanh_out(const Tensor& self, Tensor& result) {
  return at::atanh_out(result, self);
}
Tensor arctanh(const Tensor& self) {
  return self.atanh();
}
Tensor& arctanh_(Tensor& self) {
  return self.atanh_();
}
769
// square: elementwise x^2, implemented as pow(self, 2)
Tensor& square_out(const Tensor& self, Tensor& result) {
  return at::pow_out(result, self, 2);
}
Tensor square(const Tensor& self) {
  return self.pow(2);
}
Tensor& square_(Tensor& self) {
  return self.pow_(2);
}
773
// logit: log(x / (1 - x)). When `eps` is absent a sentinel of -1.0 is passed
// to the stub, which disables input clamping.
Tensor& logit_out(const Tensor& self,
                  std::optional<double> eps,
                  Tensor& result) {
  const Scalar clamp_eps(eps.value_or(-1.0));
  return unary_op_impl_float_out(result, self, logit_stub, clamp_eps);
}
Tensor logit(const Tensor& self, std::optional<double> eps) {
  const Scalar clamp_eps(eps.value_or(-1.0));
  return unary_op_impl_float(self, logit_stub, clamp_eps);
}
Tensor& logit_(Tensor& self, std::optional<double> eps) {
  return at::logit_out(self, self, eps);
}
787
// special_logit, alias for logit
Tensor& special_logit_out(const Tensor& self, std::optional<double> eps, Tensor& result) {
  return at::logit_out(result, self, eps);
}
Tensor special_logit(const Tensor& self, std::optional<double> eps) {
  return at::logit(self, eps);
}
794
// special_expit, alias for sigmoid
Tensor& special_expit_out(const Tensor& self, Tensor& result) {
  return at::sigmoid_out(result, self);
}
Tensor special_expit(const Tensor& self) { return at::sigmoid(self); }
802
// Replaces NaN / +inf / -inf in `self` with the given substitutes (or the
// stub's defaults when an optional is absent), writing into `result`.
Tensor& nan_to_num_out(const Tensor& self,
                       std::optional<double> nan,
                       std::optional<double> pos_inf,
                       std::optional<double> neg_inf,
                       Tensor& result) {
  TORCH_CHECK(
      self.scalar_type() == result.scalar_type(),
      "nan_to_num: dtype of out: ",
      result.scalar_type(),
      " should be same as input: ",
      self.scalar_type());

  // Integral (and bool) tensors cannot contain NaN or inf: pass through.
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    at::native::resize_output(result, self.sizes());
    result.copy_(self);
    return result;
  }

  auto it = TensorIterator::unary_op(result, self);
  nan_to_num_stub(it.device_type(), it, nan, pos_inf, neg_inf);
  return result;
}
825
// Functional variant of nan_to_num: allocates the output then delegates.
Tensor nan_to_num(
    const Tensor& self,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf) {
  Tensor out = at::empty_like(self);
  return at::nan_to_num_out(out, self, nan, pos_inf, neg_inf);
}
834
nan_to_num_(Tensor & self,std::optional<double> nan,std::optional<double> pos_inf,std::optional<double> neg_inf)835 Tensor& nan_to_num_(
836 Tensor& self,
837 std::optional<double> nan,
838 std::optional<double> pos_inf,
839 std::optional<double> neg_inf) {
840 return at::nan_to_num_out(self, self, nan, pos_inf, neg_inf);
841 }
842
// Alias for trunc
Tensor& fix_out(const Tensor& self, Tensor& result) {
  return at::trunc_out(result, self);
}
Tensor fix(const Tensor& self) {
  return at::trunc(self);
}
Tensor& fix_(Tensor& self) {
  return at::trunc_(self);
}
847
// Implements unary `+`: returns `self` unchanged (no copy). Rejected for bool
// tensors to match the error Python raises for `+bool_tensor`.
Tensor positive(const Tensor& self) {
  TORCH_CHECK(self.scalar_type() != kBool, "The `+` operator, on a bool tensor is not supported.");
  return self;
}
852
// negative, alias for neg
Tensor& negative_out(const Tensor& self, Tensor& result) {
  return at::neg_out(result, self);
}
Tensor negative(const Tensor& self) {
  return at::neg(self);
}
Tensor& negative_(Tensor& self) {
  return at::neg_(self);
}
856
// Functional logical_not: always produces a bool tensor, regardless of input dtype.
Tensor logical_not(const Tensor& self) {
  auto out = at::empty({0}, self.options().dtype(kBool));
  return at::logical_not_out(out, self);
}
861
logical_not_(Tensor & self)862 Tensor& logical_not_(Tensor& self) {
863 return at::logical_not_out(self, self);
864 }
865
// out-variant of logical_not. Dtype checking is disabled so that the output
// (typically bool) may differ from the input dtype; the stub handles casting.
Tensor& logical_not_out(const Tensor& self, Tensor& result) {
  auto it = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .add_output(result)
      .add_const_input(self)
      .build();
  logical_not_stub(it.device_type(), it);
  return result;
}
875
namespace {
// Named fractions used by the multivariate lgamma implementations below.
constexpr double HALF = 0.5;
constexpr double QUARTER = 0.25;
}
880
// Validates arguments shared by mvlgamma / mvlgamma_ / mvlgamma_out:
// bool tensors are rejected, and the dimension p must be >= 1.
static inline void mvlgamma_check(const Tensor& self, int64_t p) {
  TORCH_CHECK(self.scalar_type() != kBool, "The input tensor may not be a boolean tensor.");
  TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");
}
885
// Multivariate log-gamma:
//   mvlgamma_p(x) = (p*(p-1)/4) * log(pi) + sum_{i=1..p} lgamma(x + (1-i)/2)
// computed by broadcasting the p half-integer offsets against self.
Tensor mvlgamma(const Tensor& self, int64_t p) {
  mvlgamma_check(self, p);
  auto dtype = c10::scalarTypeToTypeMeta(self.scalar_type());
  if (at::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
    // int -> float promotion
    dtype = c10::get_default_dtype();
  }
  // Offsets (1-p)/2, (2-p)/2, ..., 0 — one per term of the sum.
  Tensor args = native::arange(
      -p * HALF + HALF,
      HALF,
      HALF,
      optTypeMetaToScalarType(dtype),
      self.options().layout_opt(),
      self.options().device_opt(),
      self.options().pinned_memory_opt());
  // Broadcast: shape becomes self.shape x p, then lgamma and reduce over the
  // trailing offset dimension.
  args = args.add(self.unsqueeze(-1));
  const auto p2_sub_p = static_cast<double>(p * (p - 1));
  return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
}
905
// In-place multivariate log-gamma; same math as mvlgamma but the result is
// copied back into self (so self's dtype must already be floating point —
// the copy_ enforces this via its own casting rules).
Tensor& mvlgamma_(Tensor& self, int64_t p) {
  mvlgamma_check(self, p);
  // Offsets (1-p)/2, ..., 0 in self's own dtype.
  Tensor args = native::arange(
      -p * HALF + HALF,
      HALF,
      HALF,
      optTypeMetaToScalarType(self.options().dtype_opt()),
      self.options().layout_opt(),
      self.options().device_opt(),
      self.options().pinned_memory_opt());
  args = args.add(self.unsqueeze(-1));
  const auto p2_sub_p = static_cast<double>(p * (p - 1));
  return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
}
920
// out-variant of mvlgamma: computes into a temporary, verifies the computed
// dtype is castable to `result`'s dtype, then resizes and copies.
// Fix: the error message previously reported self's dtype as the "result
// type" and the computed dtype as the "desired output type" — swapped
// relative to the operands of the can_cast check. Report out/result dtypes.
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {
  auto out = self.mvlgamma(p);
  TORCH_CHECK(
      at::can_cast(out.scalar_type(), result.scalar_type()),
      "mvlgamma: result type ",
      out.scalar_type(),
      " can't be cast to the desired output type ",
      result.scalar_type());
  at::native::resize_output(result, out.sizes());
  return result.copy_(out);
}
932
// special_multigammaln, alias for mvlgamma
// Fix: removed stray semicolons after the function bodies (-Wextra-semi).
Tensor special_multigammaln(const Tensor& self, int64_t p) {
  return self.mvlgamma(p);
}

Tensor& special_multigammaln_out(const Tensor& self, int64_t p, Tensor& result) {
  return at::mvlgamma_out(result, self, p);
}
940
frexp(const Tensor & self)941 std::tuple<Tensor, Tensor> frexp(const Tensor& self) {
942 Tensor mantissa = at::empty_like(self);
943 Tensor exponent = at::empty_like(self, self.options().dtype(at::kInt));
944
945 at::frexp_out(mantissa, exponent, self);
946 return std::tuple<Tensor, Tensor>(mantissa, exponent);
947 }
948
frexp_out(const Tensor & self,Tensor & mantissa,Tensor & exponent)949 std::tuple<Tensor&, Tensor&> frexp_out(const Tensor& self,
950 Tensor& mantissa, Tensor& exponent) {
951 // torch.frexp is implemented for floating-point dtypes for now,
952 // should add support for integral dtypes in the future.
953 TORCH_CHECK(at::isFloatingType(self.scalar_type()),
954 "torch.frexp() only supports floating-point dtypes");
955
956 TORCH_CHECK(mantissa.dtype() == self.dtype(),
957 "torch.frexp() expects mantissa to have dtype ", self.dtype(),
958 " but got ", mantissa.dtype());
959 TORCH_CHECK(exponent.dtype() == at::kInt,
960 "torch.frexp() expects exponent to have int dtype "
961 "but got ", exponent.dtype());
962
963 auto iter = TensorIteratorConfig()
964 .add_output(mantissa)
965 .add_output(exponent)
966 .add_const_input(self)
967 .check_all_same_dtype(false)
968 .set_check_mem_overlap(true)
969 .build();
970 frexp_stub(iter.device_type(), iter);
971
972 return std::tuple<Tensor&, Tensor&>(mantissa, exponent);
973 }
974
// alias for lgamma, implements special.gammaln equivalent to
// scipy.special.gammaln
Tensor special_gammaln(const Tensor& self) {
  return at::lgamma(self);
}
Tensor& special_gammaln_out(const Tensor& self, Tensor& result) {
  return at::lgamma_out(result, self);
}
979
// Definitions of the per-device dispatch stubs declared in UnaryOps.h;
// CPU/CUDA kernels register their implementations against these.
DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(conj_physical_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(atanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(asin_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(atan_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(bitwise_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(ceil_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(cos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(cosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(digamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_entr_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_erfcx_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(erf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(erfc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(erfinv_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(exp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(exp2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(expm1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(floor_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(frac_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(frexp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_i0e_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_i1e_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log10_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logical_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_ndtri_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_log_ndtr_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(neg_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(nan_to_num_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(polygamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(reciprocal_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(round_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(round_decimals_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(rsqrt_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sigmoid_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logit_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sign_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(signbit_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sgn_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sin_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sinc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sqrt_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(tan_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_airy_ai_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_scaled_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_scaled_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_spherical_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
1049
1050 } // namespace at::native
1051