#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/BinaryOps.h>

#include <type_traits>
#include <utility>

#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorMeta.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_add_relu_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_test_serialization_subcmul_native.h>
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/add.h>
#include <ATen/ops/add_native.h>
#include <ATen/ops/add_ops.h>
#include <ATen/ops/and_native.h>
#include <ATen/ops/arctan2_native.h>
#include <ATen/ops/atan2.h>
#include <ATen/ops/atan2_native.h>
#include <ATen/ops/bitwise_and.h>
#include <ATen/ops/bitwise_and_native.h>
#include <ATen/ops/bitwise_left_shift.h>
#include <ATen/ops/bitwise_left_shift_native.h>
#include <ATen/ops/bitwise_or.h>
#include <ATen/ops/bitwise_or_native.h>
#include <ATen/ops/bitwise_right_shift.h>
#include <ATen/ops/bitwise_right_shift_native.h>
#include <ATen/ops/bitwise_xor.h>
#include <ATen/ops/bitwise_xor_native.h>
#include <ATen/ops/copysign.h>
#include <ATen/ops/copysign_native.h>
#include <ATen/ops/div.h>
#include <ATen/ops/div_native.h>
#include <ATen/ops/div_ops.h>
#include <ATen/ops/divide_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/eq_native.h>
#include <ATen/ops/floor_divide.h>
#include <ATen/ops/floor_divide_native.h>
#include <ATen/ops/fmax_native.h>
#include <ATen/ops/fmin_native.h>
#include <ATen/ops/fmod.h>
#include <ATen/ops/fmod_native.h>
#include <ATen/ops/full.h>
#include <ATen/ops/gcd_native.h>
#include <ATen/ops/ge.h>
#include <ATen/ops/ge_native.h>
#include <ATen/ops/greater_equal_native.h>
#include <ATen/ops/greater_native.h>
#include <ATen/ops/gt.h>
#include <ATen/ops/gt_native.h>
#include <ATen/ops/heaviside_native.h>
#include <ATen/ops/hypot_native.h>
#include <ATen/ops/igamma.h>
#include <ATen/ops/igamma_native.h>
#include <ATen/ops/igammac.h>
#include <ATen/ops/igammac_native.h>
#include <ATen/ops/lcm_native.h>
#include <ATen/ops/ldexp.h>
#include <ATen/ops/ldexp_native.h>
#include <ATen/ops/le.h>
#include <ATen/ops/le_native.h>
#include <ATen/ops/less_equal_native.h>
#include <ATen/ops/less_native.h>
#include <ATen/ops/linalg_cross_native.h>
#include <ATen/ops/linalg_cross_ops.h>
#include <ATen/ops/logaddexp2_native.h>
#include <ATen/ops/logaddexp_native.h>
#include <ATen/ops/logical_and.h>
#include <ATen/ops/logical_and_native.h>
#include <ATen/ops/logical_or.h>
#include <ATen/ops/logical_or_native.h>
#include <ATen/ops/logical_xor.h>
#include <ATen/ops/logical_xor_native.h>
#include <ATen/ops/logit_backward_native.h>
#include <ATen/ops/lshift_native.h>
#include <ATen/ops/lt.h>
#include <ATen/ops/lt_native.h>
#include <ATen/ops/max_native.h>
#include <ATen/ops/maximum.h>
#include <ATen/ops/maximum_native.h>
#include <ATen/ops/min_native.h>
#include <ATen/ops/minimum.h>
#include <ATen/ops/minimum_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/mul_native.h>
#include <ATen/ops/mul_ops.h>
#include <ATen/ops/multiply_native.h>
#include <ATen/ops/ne.h>
#include <ATen/ops/ne_native.h>
#include <ATen/ops/nextafter_native.h>
#include <ATen/ops/not_equal_native.h>
#include <ATen/ops/or_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/remainder.h>
#include <ATen/ops/remainder_native.h>
#include <ATen/ops/rshift_native.h>
#include <ATen/ops/rsub_native.h>
#include <ATen/ops/sigmoid_backward_native.h>
#include <ATen/ops/special_chebyshev_polynomial_t.h>
#include <ATen/ops/special_chebyshev_polynomial_t_native.h>
#include <ATen/ops/special_chebyshev_polynomial_u.h>
#include <ATen/ops/special_chebyshev_polynomial_u_native.h>
#include <ATen/ops/special_chebyshev_polynomial_v.h>
#include <ATen/ops/special_chebyshev_polynomial_v_native.h>
#include <ATen/ops/special_chebyshev_polynomial_w.h>
#include <ATen/ops/special_chebyshev_polynomial_w_native.h>
#include <ATen/ops/special_gammainc_native.h>
#include <ATen/ops/special_gammaincc_native.h>
#include <ATen/ops/special_hermite_polynomial_h.h>
#include <ATen/ops/special_hermite_polynomial_h_native.h>
#include <ATen/ops/special_hermite_polynomial_he.h>
#include <ATen/ops/special_hermite_polynomial_he_native.h>
#include <ATen/ops/special_laguerre_polynomial_l.h>
#include <ATen/ops/special_laguerre_polynomial_l_native.h>
#include <ATen/ops/special_legendre_polynomial_p.h>
#include <ATen/ops/special_legendre_polynomial_p_native.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_t.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_native.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_u.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_native.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_v.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_native.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_w.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_native.h>
#include <ATen/ops/special_xlog1py.h>
#include <ATen/ops/special_xlog1py_native.h>
#include <ATen/ops/special_xlogy_native.h>
#include <ATen/ops/special_zeta.h>
#include <ATen/ops/special_zeta_native.h>
#include <ATen/ops/sub.h>
#include <ATen/ops/sub_native.h>
#include <ATen/ops/subtract_native.h>
#include <ATen/ops/tanh_backward_native.h>
#include <ATen/ops/true_divide_native.h>
#include <ATen/ops/xlogy.h>
#include <ATen/ops/xlogy_native.h>
#include <ATen/ops/xor_native.h>
#endif

namespace at::meta {

TORCH_META_FUNC2(add, Tensor) (
  const Tensor& self, const Tensor& other, const Scalar& alpha
) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
  native::alpha_check(dtype(), alpha);
}

TORCH_META_FUNC2(sub, Tensor) (
  const Tensor& self, const Tensor& other, const Scalar& alpha
) {
  native::sub_check(self, other);
  build_borrowing_binary_op(maybe_get_output(), self, other);
  native::alpha_check(dtype(), alpha);
}

TORCH_META_FUNC2(mul, Tensor) (
  const Tensor& self, const Tensor& other
) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(div, Tensor) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

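// Note on dtype behavior of the rounding_mode overload below: a missing
// rounding_mode means true division, which promotes integer inputs to a
// floating-point result dtype, while "trunc" and "floor" keep the regular
// promoted dtype of the inputs (so integer // integer stays integral).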
TORCH_META_FUNC2(div, Tensor_mode) (const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode) {
  if (!rounding_mode.has_value()) {
    build_borrowing_binary_float_op(maybe_get_output(), self, other);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  } else if (*rounding_mode == "trunc") {
    build_borrowing_binary_op(maybe_get_output(), self, other);
  } else if (*rounding_mode == "floor") {
    build_borrowing_binary_op(maybe_get_output(), self, other);
  } else {
    TORCH_CHECK(false,
        "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
        "but found '", *rounding_mode, "'");
  }
}

TORCH_META_FUNC(special_xlog1py) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(special_zeta) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(special_chebyshev_polynomial_t) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_chebyshev_polynomial_u) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_chebyshev_polynomial_v) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_chebyshev_polynomial_w) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_hermite_polynomial_h) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_hermite_polynomial_he) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_laguerre_polynomial_l) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_legendre_polynomial_p) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_shifted_chebyshev_polynomial_t) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_shifted_chebyshev_polynomial_u) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_shifted_chebyshev_polynomial_v) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC(special_shifted_chebyshev_polynomial_w) (const Tensor& self, const Tensor& n) {
  build_borrowing_binary_float_op(maybe_get_output(), self, n);
}

TORCH_META_FUNC2(copysign, Tensor) (
  const Tensor& self, const Tensor& other
) {
  build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(heaviside) (
  const Tensor& self, const Tensor& other
) {
  TORCH_CHECK(!self.is_complex() && !other.is_complex() &&
              (maybe_get_output().defined() ? !maybe_get_output().is_complex() : true),
              "heaviside is not yet implemented for complex tensors.");
  TORCH_CHECK(self.dtype() == other.dtype() &&
              (maybe_get_output().defined() ? maybe_get_output().dtype() == self.dtype() : true),
              "heaviside is not yet implemented for tensors with different dtypes.");

  build_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(atan2) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(remainder, Tensor)(const Tensor& self, const Tensor& other) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_left_shift, Tensor) (
  const Tensor& self, const Tensor& other
) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_right_shift, Tensor) (
  const Tensor& self, const Tensor& other
) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_and, Tensor) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_or, Tensor) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_xor, Tensor) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(fmod, Tensor) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(xlogy, Tensor) (const Tensor& self, const Tensor& other) {
  build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(logit_backward) (const Tensor& grad_output, const Tensor& input, std::optional<double> eps) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, input);
}

TORCH_META_FUNC(sigmoid_backward) (const Tensor& grad_output, const Tensor& output) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, output);
}

TORCH_META_FUNC(tanh_backward) (const Tensor& grad_output, const Tensor& output) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, output);
}

// These are normal binary ops that preserve dtype
#define CREATE_BINARY_META_FUNC(func)                                 \
  TORCH_META_FUNC(func) (const Tensor& self, const Tensor& other) {   \
    build_borrowing_binary_op(maybe_get_output(), self, other);       \
  }
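// For reference, CREATE_BINARY_META_FUNC(logaddexp) expands to a structured
// meta function of the form:
//
//   TORCH_META_FUNC(logaddexp) (const Tensor& self, const Tensor& other) {
//     build_borrowing_binary_op(maybe_get_output(), self, other);
//   }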

CREATE_BINARY_META_FUNC(logaddexp);
CREATE_BINARY_META_FUNC(logaddexp2);
CREATE_BINARY_META_FUNC(gcd);
CREATE_BINARY_META_FUNC(lcm);
CREATE_BINARY_META_FUNC(hypot);
CREATE_BINARY_META_FUNC(igamma);
CREATE_BINARY_META_FUNC(igammac);
CREATE_BINARY_META_FUNC(nextafter);

TORCH_META_FUNC(maximum) (const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum not implemented for complex tensors.");
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(minimum) (const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum not implemented for complex tensors.");
  build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(fmax) (const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmax not implemented for complex tensors.");
  build_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(fmin) (const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmin not implemented for complex tensors.");
  build_binary_op(maybe_get_output(), self, other);
}

#define CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(func)                       \
  TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) {   \
    const Tensor& result = maybe_get_output();                                \
    build_borrowing_comparison_op(result, self, other);                       \
  }                                                                           \
                                                                              \
  TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) {   \
    auto other_tensor =                                                        \
        native::wrapped_scalar_tensor(other);                                  \
    build_borrowing_except_last_argument_comparison_op(maybe_get_output(), self, other_tensor); \
  }
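// Each use of the macro above declares two meta overloads for a comparison
// op: a Tensor-Tensor overload, and a Tensor-Scalar overload in which the
// scalar is wrapped into a 0-dim tensor before the comparison iterator is
// built (so the scalar does not participate in the usual borrowing).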

CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(eq);
CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(ne);
CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(lt);
CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(le);
CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(gt);
CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(ge);

} // namespace at::meta


namespace at::native {

DEFINE_DISPATCH(add_clamp_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(div_true_stub);
DEFINE_DISPATCH(div_floor_stub);
DEFINE_DISPATCH(div_trunc_stub);
DEFINE_DISPATCH(remainder_stub);
DEFINE_DISPATCH(atan2_stub);
DEFINE_DISPATCH(bitwise_and_stub);
DEFINE_DISPATCH(bitwise_or_stub);
DEFINE_DISPATCH(bitwise_xor_stub);
DEFINE_DISPATCH(lshift_stub);
DEFINE_DISPATCH(rshift_stub);
DEFINE_DISPATCH(logical_and_stub);
DEFINE_DISPATCH(logical_or_stub);
DEFINE_DISPATCH(logical_xor_stub);
DEFINE_DISPATCH(lt_stub);
DEFINE_DISPATCH(le_stub);
DEFINE_DISPATCH(gt_stub);
DEFINE_DISPATCH(ge_stub);
DEFINE_DISPATCH(eq_stub);
DEFINE_DISPATCH(ne_stub);
DEFINE_DISPATCH(sigmoid_backward_stub);
DEFINE_DISPATCH(logit_backward_stub);
DEFINE_DISPATCH(tanh_backward_stub);
DEFINE_DISPATCH(maximum_stub);
DEFINE_DISPATCH(minimum_stub);
DEFINE_DISPATCH(fmax_stub);
DEFINE_DISPATCH(fmin_stub);
DEFINE_DISPATCH(fmod_stub);
DEFINE_DISPATCH(logaddexp_stub);
DEFINE_DISPATCH(logaddexp2_stub);
DEFINE_DISPATCH(gcd_stub);
DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(igamma_stub);
DEFINE_DISPATCH(igammac_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);
DEFINE_DISPATCH(xlogy_stub);
DEFINE_DISPATCH(xlog1py_stub);
DEFINE_DISPATCH(zeta_stub);
DEFINE_DISPATCH(chebyshev_polynomial_t_stub);
DEFINE_DISPATCH(chebyshev_polynomial_u_stub);
DEFINE_DISPATCH(chebyshev_polynomial_v_stub);
DEFINE_DISPATCH(chebyshev_polynomial_w_stub);
DEFINE_DISPATCH(hermite_polynomial_h_stub);
DEFINE_DISPATCH(hermite_polynomial_he_stub);
DEFINE_DISPATCH(laguerre_polynomial_l_stub);
DEFINE_DISPATCH(legendre_polynomial_p_stub);
DEFINE_DISPATCH(shifted_chebyshev_polynomial_t_stub);
DEFINE_DISPATCH(shifted_chebyshev_polynomial_u_stub);
DEFINE_DISPATCH(shifted_chebyshev_polynomial_v_stub);
DEFINE_DISPATCH(shifted_chebyshev_polynomial_w_stub);

TORCH_IMPL_FUNC(sub_out) (
  const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
) {
  add_stub(device_type(), *this, -alpha);
  TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype());
}

TORCH_IMPL_FUNC(mul_out) (
  const Tensor& self, const Tensor& other, const Tensor& result
) {
  mul_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(div_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
  div_true_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(div_out_mode) (
  const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode, const Tensor& result
) {
  if (!rounding_mode.has_value()) {
    div_true_stub(device_type(), *this);
  } else if (*rounding_mode == "trunc") {
    div_trunc_stub(device_type(), *this);
  } else if (*rounding_mode == "floor") {
    div_floor_stub(device_type(), *this);
  }
}

TORCH_IMPL_FUNC(logit_backward_out) (const Tensor& grad_output, const Tensor& input, std::optional<double> eps, const Tensor& result) {
  logit_backward_stub(device_type(), *this, Scalar(eps ? eps.value() : -1.0));
}

TORCH_IMPL_FUNC(sigmoid_backward_out) (const Tensor& grad_output, const Tensor& output, const Tensor& result) {
  sigmoid_backward_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_xlog1py_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
  xlog1py_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_zeta_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
  zeta_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_chebyshev_polynomial_t_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  chebyshev_polynomial_t_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_chebyshev_polynomial_u_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  chebyshev_polynomial_u_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_chebyshev_polynomial_v_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  chebyshev_polynomial_v_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_chebyshev_polynomial_w_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  chebyshev_polynomial_w_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_hermite_polynomial_h_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  hermite_polynomial_h_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_hermite_polynomial_he_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  hermite_polynomial_he_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_laguerre_polynomial_l_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  laguerre_polynomial_l_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_legendre_polynomial_p_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  legendre_polynomial_p_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_shifted_chebyshev_polynomial_t_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  shifted_chebyshev_polynomial_t_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_shifted_chebyshev_polynomial_u_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  shifted_chebyshev_polynomial_u_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_shifted_chebyshev_polynomial_v_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  shifted_chebyshev_polynomial_v_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_shifted_chebyshev_polynomial_w_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
  shifted_chebyshev_polynomial_w_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(tanh_backward_out) (const Tensor& grad_output, const Tensor& output, const Tensor& result) {
  tanh_backward_stub(device_type(), *this);
}

#define CREATE_BINARY_TORCH_IMPL_FUNC(func_out, func_stub)                                    \
  TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& other, const Tensor& result) { \
    func_stub(device_type(), *this);                                                          \
  }
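// For reference, CREATE_BINARY_TORCH_IMPL_FUNC(maximum_out, maximum_stub)
// expands to a structured impl function of the form:
//
//   TORCH_IMPL_FUNC(maximum_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
//     maximum_stub(device_type(), *this);
//   }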

CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_and_out, bitwise_and_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_or_out, bitwise_or_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_xor_out, bitwise_xor_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(maximum_out, maximum_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(minimum_out, minimum_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(fmax_out, fmax_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(fmin_out, fmin_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(fmod_out, fmod_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(logaddexp_out, logaddexp_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(logaddexp2_out, logaddexp2_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(gcd_out, gcd_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(lcm_out, lcm_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(hypot_out, hypot_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(igamma_out, igamma_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(igammac_out, igammac_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(nextafter_out, nextafter_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(remainder_out, remainder_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(xlogy_out, xlogy_stub);

Tensor special_xlog1py(const Scalar& x, const Tensor& y) {
  return at::special_xlog1py(wrapped_scalar_tensor(x), y);
}

Tensor special_xlog1py(const Tensor& x, const Scalar& y) {
  return at::special_xlog1py(x, wrapped_scalar_tensor(y));
}

Tensor& special_xlog1py_out(const Scalar& self, const Tensor& other, Tensor& result) {
  return at::special_xlog1py_out(result, wrapped_scalar_tensor(self), other);
}

Tensor& special_xlog1py_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::special_xlog1py_out(result, self, wrapped_scalar_tensor(other));
}

Tensor special_zeta(const Scalar& x, const Tensor& y) {
  return at::special_zeta(wrapped_scalar_tensor(x), y);
}

Tensor special_zeta(const Tensor& x, const Scalar& y) {
  return at::special_zeta(x, wrapped_scalar_tensor(y));
}

Tensor& special_zeta_out(const Scalar& self, const Tensor& other, Tensor& result) {
  return at::special_zeta_out(result, wrapped_scalar_tensor(self), other);
}

Tensor& special_zeta_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::special_zeta_out(result, self, wrapped_scalar_tensor(other));
}

Tensor special_chebyshev_polynomial_t(const Scalar& x, const Tensor& n) {
  return at::special_chebyshev_polynomial_t(wrapped_scalar_tensor(x), n);
}

Tensor special_chebyshev_polynomial_t(const Tensor& x, const Scalar& n) {
  return at::special_chebyshev_polynomial_t(x, wrapped_scalar_tensor(n));
}

Tensor& special_chebyshev_polynomial_t_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_chebyshev_polynomial_t_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_chebyshev_polynomial_t_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_chebyshev_polynomial_t_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_chebyshev_polynomial_u(const Scalar& x, const Tensor& n) {
  return at::special_chebyshev_polynomial_u(wrapped_scalar_tensor(x), n);
}

Tensor special_chebyshev_polynomial_u(const Tensor& x, const Scalar& n) {
  return at::special_chebyshev_polynomial_u(x, wrapped_scalar_tensor(n));
}

Tensor& special_chebyshev_polynomial_u_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_chebyshev_polynomial_u_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_chebyshev_polynomial_u_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_chebyshev_polynomial_u_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_chebyshev_polynomial_v(const Scalar& x, const Tensor& n) {
  return at::special_chebyshev_polynomial_v(wrapped_scalar_tensor(x), n);
}

Tensor special_chebyshev_polynomial_v(const Tensor& x, const Scalar& n) {
  return at::special_chebyshev_polynomial_v(x, wrapped_scalar_tensor(n));
}

Tensor& special_chebyshev_polynomial_v_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_chebyshev_polynomial_v_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_chebyshev_polynomial_v_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_chebyshev_polynomial_v_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_chebyshev_polynomial_w(const Scalar& x, const Tensor& n) {
  return at::special_chebyshev_polynomial_w(wrapped_scalar_tensor(x), n);
}

Tensor special_chebyshev_polynomial_w(const Tensor& x, const Scalar& n) {
  return at::special_chebyshev_polynomial_w(x, wrapped_scalar_tensor(n));
}

Tensor& special_chebyshev_polynomial_w_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_chebyshev_polynomial_w_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_chebyshev_polynomial_w_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_chebyshev_polynomial_w_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_hermite_polynomial_h(const Scalar& x, const Tensor& n) {
  return at::special_hermite_polynomial_h(wrapped_scalar_tensor(x), n);
}

Tensor special_hermite_polynomial_h(const Tensor& x, const Scalar& n) {
  return at::special_hermite_polynomial_h(x, wrapped_scalar_tensor(n));
}

Tensor& special_hermite_polynomial_h_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_hermite_polynomial_h_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_hermite_polynomial_h_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_hermite_polynomial_h_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_hermite_polynomial_he(const Scalar& x, const Tensor& n) {
  return at::special_hermite_polynomial_he(wrapped_scalar_tensor(x), n);
}

Tensor special_hermite_polynomial_he(const Tensor& x, const Scalar& n) {
  return at::special_hermite_polynomial_he(x, wrapped_scalar_tensor(n));
}

Tensor& special_hermite_polynomial_he_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_hermite_polynomial_he_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_hermite_polynomial_he_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_hermite_polynomial_he_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_laguerre_polynomial_l(const Scalar& x, const Tensor& n) {
  return at::special_laguerre_polynomial_l(wrapped_scalar_tensor(x), n);
}

Tensor special_laguerre_polynomial_l(const Tensor& x, const Scalar& n) {
  return at::special_laguerre_polynomial_l(x, wrapped_scalar_tensor(n));
}

Tensor& special_laguerre_polynomial_l_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_laguerre_polynomial_l_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_laguerre_polynomial_l_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_laguerre_polynomial_l_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_legendre_polynomial_p(const Scalar& x, const Tensor& n) {
  return at::special_legendre_polynomial_p(wrapped_scalar_tensor(x), n);
}

Tensor special_legendre_polynomial_p(const Tensor& x, const Scalar& n) {
  return at::special_legendre_polynomial_p(x, wrapped_scalar_tensor(n));
}

Tensor& special_legendre_polynomial_p_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_legendre_polynomial_p_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_legendre_polynomial_p_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_legendre_polynomial_p_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_shifted_chebyshev_polynomial_t(const Scalar& x, const Tensor& n) {
  return at::special_shifted_chebyshev_polynomial_t(wrapped_scalar_tensor(x), n);
}

Tensor special_shifted_chebyshev_polynomial_t(const Tensor& x, const Scalar& n) {
  return at::special_shifted_chebyshev_polynomial_t(x, wrapped_scalar_tensor(n));
}

Tensor& special_shifted_chebyshev_polynomial_t_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_t_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_shifted_chebyshev_polynomial_t_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_t_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_shifted_chebyshev_polynomial_u(const Scalar& x, const Tensor& n) {
  return at::special_shifted_chebyshev_polynomial_u(wrapped_scalar_tensor(x), n);
}

Tensor special_shifted_chebyshev_polynomial_u(const Tensor& x, const Scalar& n) {
  return at::special_shifted_chebyshev_polynomial_u(x, wrapped_scalar_tensor(n));
}

Tensor& special_shifted_chebyshev_polynomial_u_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_u_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_shifted_chebyshev_polynomial_u_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_u_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_shifted_chebyshev_polynomial_v(const Scalar& x, const Tensor& n) {
  return at::special_shifted_chebyshev_polynomial_v(wrapped_scalar_tensor(x), n);
}

Tensor special_shifted_chebyshev_polynomial_v(const Tensor& x, const Scalar& n) {
  return at::special_shifted_chebyshev_polynomial_v(x, wrapped_scalar_tensor(n));
}

Tensor& special_shifted_chebyshev_polynomial_v_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_v_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_shifted_chebyshev_polynomial_v_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_v_out(result, self, wrapped_scalar_tensor(n));
}

Tensor special_shifted_chebyshev_polynomial_w(const Scalar& x, const Tensor& n) {
  return at::special_shifted_chebyshev_polynomial_w(wrapped_scalar_tensor(x), n);
}

Tensor special_shifted_chebyshev_polynomial_w(const Tensor& x, const Scalar& n) {
  return at::special_shifted_chebyshev_polynomial_w(x, wrapped_scalar_tensor(n));
}

Tensor& special_shifted_chebyshev_polynomial_w_out(const Scalar& self, const Tensor& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_w_out(result, wrapped_scalar_tensor(self), n);
}

Tensor& special_shifted_chebyshev_polynomial_w_out(const Tensor& self, const Scalar& n, Tensor& result) {
  return at::special_shifted_chebyshev_polynomial_w_out(result, self, wrapped_scalar_tensor(n));
}

Tensor& special_gammainc_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::igamma_out(result, self, other);
}

Tensor special_gammainc(const Tensor& self, const Tensor& other) {
  return at::igamma(self, other);
}

Tensor& special_gammaincc_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::igammac_out(result, self, other);
}

Tensor special_gammaincc(const Tensor& self, const Tensor& other) {
  return at::igammac(self, other);
}

TORCH_IMPL_FUNC(atan2_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
  atan2_stub(device_type(), *this);
}

Tensor arctan2(const Tensor& self, const Tensor& other) {
  return at::atan2(self, other);
}

Tensor& arctan2_(Tensor& self, const Tensor& other) {
  return self.atan2_(other);
}

Tensor& arctan2_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::atan2_out(result, self, other);
}

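// Fused add + relu: computes self + alpha * other and clamps the result to
// [0, max of the dtype] in a single TensorIterator pass via add_clamp_stub,
// which is equivalent to relu(self + alpha * other) without materializing
// the intermediate sum.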
static Tensor& add_relu_impl(
    Tensor& result, const Tensor& self, const Tensor& other, const Scalar& alpha) {
  auto iter = TensorIterator::binary_op(result, self, other);
  Scalar min_val;
  Scalar max_val;
  if (self.dtype() == at::kInt) {
    min_val = 0;
    max_val = std::numeric_limits<int32_t>::max();
  } else if (self.dtype() == at::kLong) {
    min_val = 0;
    max_val = std::numeric_limits<int64_t>::max();
  } else if (self.dtype() == at::kShort) {
    min_val = 0;
    max_val = std::numeric_limits<int16_t>::max();
  } else if (self.dtype() == at::kChar) {
    min_val = 0;
    max_val = std::numeric_limits<int8_t>::max();
  } else if (self.dtype() == at::kFloat) {
    min_val = 0.0;
    max_val = std::numeric_limits<float>::max();
  } else if (self.dtype() == at::kDouble) {
    min_val = 0.0;
    max_val = std::numeric_limits<double>::max();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Unsupported datatype for add_relu:", self.dtype().name());
  }

  result = iter.output();
  add_clamp_stub(iter.device_type(), iter, alpha, min_val, max_val);
  return result;
}

Tensor& add_relu_out(const Tensor& self, const Tensor& other, const Scalar& alpha, Tensor& result) {
  return add_relu_impl(result, self, other, alpha);
}

Tensor add_relu(const Tensor& self, const Tensor& other, const Scalar& alpha) {
  Tensor result;
  return add_relu_impl(result, self, other, alpha);
}

Tensor add_relu(const Tensor& self, const Scalar& other, const Scalar& alpha) {
  return add_relu(self, wrapped_scalar_tensor(other), alpha);
}

Tensor& add_relu_(Tensor& self, const Tensor& other, const Scalar& alpha) {
  return add_relu_impl(self, self, other, alpha);
}

Tensor& add_relu_(Tensor& self, const Scalar& other, const Scalar& alpha) {
  return add_relu_(self, wrapped_scalar_tensor(other), alpha);
}

TORCH_IMPL_FUNC(copysign_out) (
  const Tensor& self, const Tensor& other, const Tensor& result
) {
  copysign_stub(device_type(), *this);
}

Tensor copysign(const Tensor& self, const Scalar& other) {
  // redispatch!
  return at::copysign(self, wrapped_scalar_tensor(other));
}

Tensor& copysign_(Tensor& self, const Scalar& other) {
  // redispatch!
  return self.copysign_(wrapped_scalar_tensor(other));
}

Tensor& copysign_out(const Tensor& self, const Scalar& other, Tensor& result) {
  // redispatch!
  return at::copysign_out(result, self, wrapped_scalar_tensor(other));
}

// WARNING: There doesn't appear to be any testing for this function
// with sparse self input.
Tensor div(const Tensor& self, const Scalar& other) {
  return self.div(wrapped_scalar_tensor(other)); // redispatch!
}

// WARNING: This function, with a sparse self, is currently only
// exercised by DistributedDataParallelTest.test_sparse_gradients
// (you need to exercise it from C++, because this overload is never
// used for Python)
Tensor& div_(Tensor& self, const Scalar& other) {
  return self.div_(wrapped_scalar_tensor(other)); // redispatch!
}

Tensor div(const Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
  return self.div(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
}

Tensor& div_(Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
  return self.div_(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
}

// divide, alias for div
Tensor& divide_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::div_out(result, self, other);
}

Tensor divide(const Tensor& self, const Tensor& other) {
  return self.div(other);
}

Tensor& divide_(Tensor& self, const Tensor& other) {
  return self.div_(other);
}

Tensor divide(const Tensor& self, const Scalar& other) {
  return self.div(other);
}

Tensor& divide_(Tensor& self, const Scalar& other) {
  return self.div_(other);
}

Tensor& divide_out(const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode, Tensor& result) {
  return at::div_out(result, self, other, std::move(rounding_mode));
}

Tensor divide(const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode) {
  return self.div(other, std::move(rounding_mode));
}

Tensor& divide_(Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode) {
  return self.div_(other, std::move(rounding_mode));
}

Tensor divide(const Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
  return self.div(other, std::move(rounding_mode));
}

Tensor& divide_(Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
  return self.div_(other, std::move(rounding_mode));
}

// true_divide, an alias for div
Tensor& true_divide_out(const Tensor& self, const Tensor& divisor, Tensor& result) {
  return at::div_out(result, self, divisor);
}

Tensor true_divide(const Tensor& self, const Tensor& divisor) {
  return self.div(divisor);
}

Tensor& true_divide_(Tensor& self, const Tensor& divisor) {
  return self.div_(divisor);
}

Tensor true_divide(const Tensor& self, const Scalar& divisor) {
  return self.div(divisor);
}

Tensor& true_divide_(Tensor& self, const Scalar& divisor) {
  return self.div_(divisor);
}

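// floor_divide is not a structured op; it builds its own TensorIterator and
// calls div_floor_stub, which rounds the quotient toward negative infinity.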
Tensor& floor_divide_out(const Tensor& self, const Tensor& other, Tensor& result) {
  auto iter = TensorIterator::binary_op(result, self, other);
  div_floor_stub(iter.device_type(), iter);
  if (!result.defined()) {
    result = iter.output();
  }
  return result;
}

Tensor floor_divide(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  div_floor_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& floor_divide_(Tensor& self, const Tensor& other) {
  return native::floor_divide_out(self, other, self);
}

// TODO: Make this structured to undo the perf regression from native:: removal
// in call here
Tensor mul(const Tensor& self, const Scalar& other) {
  return at::mul(self, wrapped_scalar_tensor(other)); // redispatch!
}

Tensor& mul_(Tensor& self, const Scalar& other) {
  return at::mul_out(self, wrapped_scalar_tensor(other), self); // redispatch!
}

Tensor& mul__scalar_sparse_csr(Tensor& self, const Scalar& other) {
  self.values().mul_(other);
  return self;
}

static Device correct_out_device(const Tensor& self, const Tensor& other) {
  if (self.device() == at::kCPU){
    return other.device();
  } else {
    return self.device();
  }
}

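// ZeroTensor fast paths: redispatching on the Meta key performs only shape
// inference and type promotion (no data is touched), so a multiply involving
// a zero tensor can be materialized directly as an _efficientzerotensor of
// the inferred size and dtype.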
Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
  auto out_device = correct_out_device(self, other);
  // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
  auto device_ = Device(DeviceType::Meta);
  constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
  auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
  return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
}

Tensor div_zerotensor(const Tensor& self, const Tensor& other) {
  auto out_device = correct_out_device(self, other);
  // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
  auto device_ = Device(DeviceType::Meta);
  constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
  auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));

  if (self._is_zerotensor()) {
    if (other._is_zerotensor()) {
      // 0/0, return full NAN
      return at::full(meta_out.sizes(), std::numeric_limits<float>::quiet_NaN(), meta_out.options().device(out_device));
    }
    else {
      // 0/x, return zero tensor
      return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
    }
  }
  else {
    if (other._is_zerotensor()) {
      // x/0, return full INF
      return at::full(meta_out.sizes(), std::numeric_limits<float>::infinity(), meta_out.options().device(out_device));
    }
    else {
      // x/y -- unreachable, see TORCH_INTERNAL_ASSERT above
      return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
    }
  }
}

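// Shared helper for add/sub when at least one operand is a ZeroTensor (the
// only case in which these kernels are dispatched): if self is zero, the
// result is alpha * other (or a zero tensor when both operands are zero);
// otherwise other must be the zero operand and the result is just self,
// broadcast and copied to the promoted dtype/device.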
maybe_add_maybe_sub(const Tensor & self,const Tensor & other,const Scalar & alpha)1050 static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const Scalar& alpha) {
1051 auto out_device = correct_out_device(self, other);
1052 // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
1053 auto device_ = Device(DeviceType::Meta);
1054 constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
1055 auto meta_out = at::_ops::add_Tensor::redispatch(
1056 meta_dks, self.to(device_), other.to(device_), alpha);
1057
1058 auto get_out_like = [&] (const Tensor& tensor)
1059 {
1060 auto sizes = meta_out.sizes();
1061 return at::_to_copy(tensor.expand(sizes), meta_out.options().device(out_device));
1062 };
1063
1064 if (self._is_zerotensor()) {
1065 if (other._is_zerotensor()) {
1066 return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
1067 }
1068 auto res = get_out_like(other);
1069 return alpha.equal(1) ? std::move(res) : res.mul(alpha);
1070 } else {
1071 return get_out_like(self);
1072 }
1073 }
add_zerotensor(const Tensor & self,const Tensor & other,const Scalar & alpha)1074 Tensor add_zerotensor(const Tensor& self, const Tensor& other, const Scalar& alpha) {
1075 return maybe_add_maybe_sub(self, other, alpha);
1076 }
1077
sub_zerotensor(const Tensor & self,const Tensor & other,const Scalar & alpha)1078 Tensor sub_zerotensor(const Tensor& self, const Tensor& other, const Scalar& alpha) {
1079 return maybe_add_maybe_sub(self, other, -alpha);
1080 }
1081
linalg_cross_zerotensor(const Tensor & input,const Tensor & other,const int64_t dim)1082 Tensor linalg_cross_zerotensor(
1083 const Tensor& input,
1084 const Tensor& other,
1085 const int64_t dim)
1086 {
1087 auto out_device = correct_out_device(input, other);
1088 // hack to use the TensorIterator to get the correct broadcasting and type
1089 // promotion logic (see add_zerotensor)
1090 auto device = Device(DeviceType::Meta);
1091 auto meta_out = at::_ops::linalg_cross::redispatch(
1092 c10::DispatchKeySet(at::DispatchKey::Meta),
1093 input.to(device),
1094 other.to(device),
1095 dim);
1096
1097 return at::_efficientzerotensor(
1098 meta_out.sizes(),
1099 meta_out.options().device(out_device));
1100 }

// multiply, alias for mul
Tensor& multiply_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::mul_out(result, self, other);
}

Tensor multiply(const Tensor& self, const Tensor& other) {
  return self.mul(other);
}

Tensor& multiply_(Tensor& self, const Tensor& other) {
  return self.mul_(other);
}

Tensor multiply(const Tensor& self, const Scalar& other) {
  return self.mul(other);
}

Tensor& multiply_(Tensor& self, const Scalar& other) {
  return self.mul_(other);
}

Tensor sub(const Tensor& self, const Scalar& other, const Scalar& alpha) {
  return at::sub(self, wrapped_scalar_tensor(other), alpha); // redispatch!
}

Tensor& sub_(Tensor& self, const Scalar& other, const Scalar& alpha) {
  return self.sub_(wrapped_scalar_tensor(other), alpha); // redispatch!
}

// subtract, alias for sub
Tensor& subtract_out(const Tensor& self, const Tensor& other, const Scalar& alpha, Tensor& result) {
  return at::sub_out(result, self, other, alpha);
}

Tensor subtract(const Tensor& self, const Tensor& other, const Scalar& alpha) {
  return self.sub(other, alpha);
}

Tensor& subtract_(Tensor& self, const Tensor& other, const Scalar& alpha) {
  return self.sub_(other, alpha);
}

Tensor subtract(const Tensor& self, const Scalar& other, const Scalar& alpha) {
  return self.sub(other, alpha);
}

Tensor& subtract_(Tensor& self, const Scalar& other, const Scalar& alpha) {
  return self.sub_(other, alpha);
}

Tensor rsub(const Tensor& self, const Tensor& other, const Scalar& alpha) {
  return at::sub(other, self, alpha); // redispatch!
}

// TODO: Make this structured to undo the perf regression from native:: removal
// in call here

Tensor add(const Tensor& self, const Scalar& other, const Scalar& alpha) {
  return at::add(self, wrapped_scalar_tensor(other), alpha);
}

Tensor& add_(Tensor& self, const Scalar& other, const Scalar& alpha) {
  return self.add_(wrapped_scalar_tensor(other), alpha);
}

Tensor remainder(const Tensor& self, const Scalar& other) {
  // redispatch
  return at::remainder(self, wrapped_scalar_tensor(other));
}

Tensor& remainder_(Tensor& self, const Scalar& other) {
  // redispatch
  return self.remainder_(wrapped_scalar_tensor(other));
}

Tensor& remainder_out(const Tensor& self, const Scalar& other, Tensor& result) {
  // redispatch
  return at::remainder_out(result, self, wrapped_scalar_tensor(other));
}

Tensor remainder(const Scalar& self, const Tensor& other) {
  return at::remainder(wrapped_scalar_tensor(self), other);
}

Tensor rsub(const Tensor& self, const Scalar& other, const Scalar& alpha) {
  return native::rsub(self, wrapped_scalar_tensor(other), alpha);
}

Tensor& bitwise_and_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_and(const Tensor& self, const Scalar& other) {
  return at::bitwise_and(self, wrapped_scalar_tensor(other));
}

Tensor bitwise_and(const Scalar& self, const Tensor& other) {
  return at::bitwise_and(wrapped_scalar_tensor(self), other);
}

Tensor& bitwise_and_(Tensor& self, const Scalar& other) {
  return self.bitwise_and_(wrapped_scalar_tensor(other));
}

// Legacy and interfaces. They are aliased to bitwise_and* functions
Tensor __and__(const Tensor& self, const Tensor& other) {
  return at::bitwise_and(self, other);
}

Tensor __and__(const Tensor& self, const Scalar& other) {
  return at::bitwise_and(self, other);
}

Tensor& __iand__(Tensor& self, const Tensor& other) {
  return self.bitwise_and_(other);
}

Tensor& __iand__(Tensor& self, const Scalar& other) {
  return self.bitwise_and_(other);
}

Tensor& bitwise_or_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::bitwise_or_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_or(const Tensor& self, const Scalar& other) {
  return at::bitwise_or(self, wrapped_scalar_tensor(other));
}

Tensor bitwise_or(const Scalar& self, const Tensor& other) {
  return at::bitwise_or(wrapped_scalar_tensor(self), other);
}

Tensor& bitwise_or_(Tensor& self, const Scalar& other) {
  return self.bitwise_or_(wrapped_scalar_tensor(other));
}

// Legacy or interfaces. They are aliased to bitwise_or* functions
Tensor __or__(const Tensor& self, const Tensor& other) {
  return at::bitwise_or(self, other);
}

Tensor __or__(const Tensor& self, const Scalar& other) {
  return at::bitwise_or(self, other);
}

Tensor& __ior__(Tensor& self, const Tensor& other) {
  return self.bitwise_or_(other);
}

Tensor& __ior__(Tensor& self, const Scalar& other) {
  return self.bitwise_or_(other);
}

Tensor& bitwise_xor_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::bitwise_xor_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_xor(const Tensor& self, const Scalar& other) {
  return at::bitwise_xor(self, wrapped_scalar_tensor(other));
}

Tensor bitwise_xor(const Scalar& self, const Tensor& other) {
  return at::bitwise_xor(wrapped_scalar_tensor(self), other);
}

Tensor& bitwise_xor_(Tensor& self, const Scalar& other) {
  return self.bitwise_xor_(wrapped_scalar_tensor(other));
}

// Legacy xor interfaces. They are aliased to bitwise_xor* functions
Tensor __xor__(const Tensor& self, const Tensor& other) {
  return at::bitwise_xor(self, other);
}

Tensor __xor__(const Tensor& self, const Scalar& other) {
  return at::bitwise_xor(self, other);
}

Tensor& __ixor__(Tensor& self, const Tensor& other) {
  return self.bitwise_xor_(other);
}

Tensor& __ixor__(Tensor& self, const Scalar& other) {
  return self.bitwise_xor_(other);
}

Tensor __lshift__(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  lshift_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor __lshift__(const Tensor& self, const Scalar& other) {
  Tensor result;
  auto wrapper = wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(result, self, wrapper);
  lshift_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& __ilshift__(Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(self, self, other);
  lshift_stub(iter.device_type(), iter);
  return self;
}

Tensor& __ilshift__(Tensor& self, const Scalar& other) {
  auto wrapper = wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(self, self, wrapper);
  lshift_stub(iter.device_type(), iter);
  return self;
}

TORCH_IMPL_FUNC(bitwise_left_shift_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
  lshift_stub(device_type(), *this);
}

Tensor& bitwise_left_shift_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::bitwise_left_shift_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_left_shift(const Tensor& self, const Scalar& other) {
  return at::bitwise_left_shift(self, wrapped_scalar_tensor(other));
}

Tensor& bitwise_left_shift_(Tensor& self, const Scalar& other) {
  return at::bitwise_left_shift_out(self, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_left_shift(const Scalar& self, const Tensor& other) {
  return at::bitwise_left_shift(wrapped_scalar_tensor(self), other);
}

Tensor __rshift__(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  rshift_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor __rshift__(const Tensor& self, const Scalar& other) {
  Tensor result;
  auto wrapper = wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(result, self, wrapper);
  rshift_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& __irshift__(Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(self, self, other);
  rshift_stub(iter.device_type(), iter);
  return self;
}

Tensor& __irshift__(Tensor& self, const Scalar& other) {
  auto wrapper = wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(self, self, wrapper);
  rshift_stub(iter.device_type(), iter);
  return self;
}

TORCH_IMPL_FUNC(bitwise_right_shift_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
  rshift_stub(device_type(), *this);
}

Tensor& bitwise_right_shift_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::bitwise_right_shift_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_right_shift(const Tensor& self, const Scalar& other) {
  return at::bitwise_right_shift(self, wrapped_scalar_tensor(other));
}

Tensor& bitwise_right_shift_(Tensor& self, const Scalar& other) {
  return at::bitwise_right_shift_out(self, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_right_shift(const Scalar& self, const Tensor& other) {
  return at::bitwise_right_shift(wrapped_scalar_tensor(self), other);
}

template <typename Stub>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
  auto iter = TensorIterator::comparison_op(result, self, other);
  stub(iter.device_type(), iter);
  return result;
}

template <typename OutImpl>
Tensor comparison_op(const Tensor& self, const Tensor& other, OutImpl& out_impl) {
  Tensor result = at::empty({0}, self.options().dtype(kBool));
  return out_impl(result, self, other);
}

template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
  return out_impl(self, self, other);
}

template <typename OutImpl>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Scalar& other, OutImpl& out_impl) {
  return out_impl(result, self, wrapped_scalar_tensor(other));
}

template <typename OutImpl>
Tensor comparison_op(const Tensor& self, const Scalar& other, OutImpl& out_impl) {
  return comparison_op(self, wrapped_scalar_tensor(other), out_impl);
}

template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Scalar& other, OutImpl& out_impl) {
  return out_impl(self, self, wrapped_scalar_tensor(other));
}

// We need an explicit cast to OutFunc because each *_out func is overloaded twice. Without an explicit cast,
// merely referring to a *_out function is ambiguous.
using OutFunc = std::add_const_t<Tensor&(&)(Tensor&, const Tensor&, const Tensor&)>;
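// For example, `static_cast<OutFunc>(at::logical_and_out)` below selects the
// (Tensor&, const Tensor&, const Tensor&) overload, so the name can be passed
// to the comparison_op helpers as a plain function reference.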

// less, alias for torch.lt
Tensor& less_out(const Tensor& self, const Tensor& other, Tensor& result) { return at::lt_out(result, self, other); }
Tensor less(const Tensor& self, const Tensor& other) { return self.lt(other); }
Tensor& less_(Tensor& self, const Tensor& other) { return self.lt_(other); }
Tensor& less_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::lt_out(result, self, other); }
Tensor less(const Tensor& self, const Scalar& other) { return self.lt(other); }
Tensor& less_(Tensor& self, const Scalar& other) { return self.lt_(other); }

// less_equal, alias for torch.le
Tensor& less_equal_out(const Tensor& self, const Tensor& other, Tensor& result) { return at::le_out(result, self, other); }
Tensor less_equal(const Tensor& self, const Tensor& other) { return self.le(other); }
Tensor& less_equal_(Tensor& self, const Tensor& other) { return self.le_(other); }
Tensor& less_equal_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::le_out(result, self, other); }
Tensor less_equal(const Tensor& self, const Scalar& other) { return self.le(other); }
Tensor& less_equal_(Tensor& self, const Scalar& other) { return self.le_(other); }

// greater, alias for torch.gt
Tensor& greater_out(const Tensor& self, const Tensor& other, Tensor& result) { return at::gt_out(result, self, other); }
Tensor greater(const Tensor& self, const Tensor& other) { return self.gt(other); }
Tensor& greater_(Tensor& self, const Tensor& other) { return self.gt_(other); }
Tensor& greater_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::gt_out(result, self, other); }
Tensor greater(const Tensor& self, const Scalar& other) { return self.gt(other); }
Tensor& greater_(Tensor& self, const Scalar& other) { return self.gt_(other); }

// greater_equal, alias for torch.ge
Tensor& greater_equal_out(const Tensor& self, const Tensor& other, Tensor& result) { return at::ge_out(result, self, other); }
Tensor greater_equal(const Tensor& self, const Tensor& other) { return self.ge(other); }
Tensor& greater_equal_(Tensor& self, const Tensor& other) { return self.ge_(other); }
Tensor& greater_equal_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::ge_out(result, self, other); }
Tensor greater_equal(const Tensor& self, const Scalar& other) { return self.ge(other); }
Tensor& greater_equal_(Tensor& self, const Scalar& other) { return self.ge_(other); }

#define CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(func)             \
  TORCH_IMPL_FUNC(func##_Tensor_out)                                \
  (const Tensor& self, const Tensor& other, const Tensor& result) { \
    func##_stub(device_type(), *this);                              \
  }                                                                 \
                                                                    \
  TORCH_IMPL_FUNC(func##_Scalar_out)                                \
  (const Tensor& self, const Scalar& other, const Tensor& result) { \
    func##_stub(device_type(), *this);                              \
  }
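
// For reference, each instantiation below expands (roughly) to a pair of
// structured-kernel impls; e.g. CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(eq)
// produces:
//
//   TORCH_IMPL_FUNC(eq_Tensor_out)
//   (const Tensor& self, const Tensor& other, const Tensor& result) {
//     eq_stub(device_type(), *this);
//   }
//
//   TORCH_IMPL_FUNC(eq_Scalar_out)
//   (const Tensor& self, const Scalar& other, const Tensor& result) {
//     eq_stub(device_type(), *this);
//   }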

CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(eq);
CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(ne);
CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(gt);
CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(ge);
CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(lt);
CREATE_COMPARISON_SCALAR_TENSOR_IMPL_FUNC(le);

// not_equal, alias for torch.ne
Tensor& not_equal_out(const Tensor& self, const Tensor& other, Tensor& result) { return at::ne_out(result, self, other); }
Tensor not_equal(const Tensor& self, const Tensor& other) { return self.ne(other); }
Tensor& not_equal_(Tensor& self, const Tensor& other) { return self.ne_(other); }
Tensor& not_equal_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::ne_out(result, self, other); }
Tensor not_equal(const Tensor& self, const Scalar& other) { return self.ne(other); }
Tensor& not_equal_(Tensor& self, const Scalar& other) { return self.ne_(other); }

Tensor& logical_and_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_and_stub); }
Tensor logical_and(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out)); }
Tensor& logical_and_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out)); }

Tensor& logical_or_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_or_stub); }
Tensor logical_or(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out)); }
Tensor& logical_or_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out)); }

Tensor& logical_xor_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_xor_stub); }
Tensor logical_xor(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out)); }
Tensor& logical_xor_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); }

// binary max, alias for maximum
Tensor& max_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::maximum_out(result, self, other);
}

Tensor max(const Tensor& self, const Tensor& other) {
  return at::maximum(self, other);
}

// binary min, alias for minimum
Tensor& min_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::minimum_out(result, self, other);
}

Tensor min(const Tensor& self, const Tensor& other) {
  return at::minimum(self, other);
}

Tensor floor_divide(const Tensor& self, const Scalar& other) {
  return at::floor_divide(self, wrapped_scalar_tensor(other));
}

Tensor& floor_divide_(Tensor& self, const Scalar& other) {
  return at::floor_divide_out(self, self, wrapped_scalar_tensor(other));
}

Tensor& fmod_out(const Tensor& self, const Scalar& other, Tensor& result) {
  // redispatch
  return at::fmod_out(result, self, wrapped_scalar_tensor(other));
}

Tensor fmod(const Tensor& self, const Scalar& other) {
  // redispatch
  return at::fmod(self, wrapped_scalar_tensor(other));
}

Tensor& fmod_(Tensor& self, const Scalar& other) {
  // redispatch
  return self.fmod_(wrapped_scalar_tensor(other));
}

// Note: this function is only for testing.
// It is undocumented and should not be used outside of tests.
Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, const Scalar& alpha) {
  return self - (other * alpha);
}

TORCH_IMPL_FUNC(heaviside_out) (
  const Tensor& self, const Tensor& other, const Tensor& result
) {
  heaviside_stub(device_type(), *this);
}

static inline Tensor _pow2(const Tensor& self, const Tensor& other) {
  const auto self_dtype = self.scalar_type();
  // All integral types are promoted to float32
  if (isIntegralType(self_dtype, true) || self_dtype == kFloat) {
    return at::pow(2.0, other);
  }
  // For double and reduced floating types do regular type promotion
  return at::full({}, 2.0, self.options()).pow(other);
}
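
// ldexp(self, other) below is realized as at::mul(self, _pow2(self, other)),
// i.e. element-wise self * 2^other; _pow2 only controls the dtype of the
// 2^other factor. A minimal usage sketch (assuming a float tensor `x` and an
// exponent tensor `n` of matching shape, both hypothetical):
//
//   auto y = at::ldexp(x, n);   // y[i] == x[i] * std::pow(2.0, n[i])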

Tensor& ldexp_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::mul_out(result, self, _pow2(self, other));
}

Tensor ldexp(const Tensor& self, const Tensor& other) {
  return at::mul(self, _pow2(self, other));
}

Tensor& ldexp_(Tensor& self, const Tensor& other) {
  return at::ldexp_out(self, self, other);
}

Tensor& xlogy_out(const Scalar& self, const Tensor& other, Tensor& result) {
  return at::xlogy_out(result, wrapped_scalar_tensor(self), other);
}

Tensor& xlogy_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::xlogy_out(result, self, wrapped_scalar_tensor(other));
}

Tensor xlogy(const Scalar& x, const Tensor& y) {
  return at::xlogy(wrapped_scalar_tensor(x), y);
}

Tensor xlogy(const Tensor& x, const Scalar& y) {
  return at::xlogy(x, wrapped_scalar_tensor(y));
}

Tensor& xlogy_(Tensor& x, const Scalar& y) {
  return at::xlogy_(x, wrapped_scalar_tensor(y));
}

Tensor& special_xlogy_out(const Tensor& self, const Tensor& other, Tensor& result) {
  return at::xlogy_out(result, self, other);
}

Tensor& special_xlogy_out(const Scalar& self, const Tensor& other, Tensor& result) {
  return at::xlogy_out(result, self, other);
}

Tensor& special_xlogy_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::xlogy_out(result, self, other);
}

Tensor special_xlogy(const Tensor& x, const Tensor& y) {
  return at::xlogy(x, y);
}

Tensor special_xlogy(const Scalar& x, const Tensor& y) {
  return at::xlogy(x, y);
}

Tensor special_xlogy(const Tensor& x, const Scalar& y) {
  return at::xlogy(x, y);
}

} // namespace at::native