#pragma once

#include <ATen/core/TensorBase.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>
#include <c10/util/TypeSafeSignMath.h>


namespace at {
struct TensorIterator;
struct TensorIteratorBase;
} // namespace at

namespace at::native {

// Check that `alpha` is compatible with the result dtype: a boolean
// alpha is only allowed for boolean results, a floating-point alpha is
// not allowed for integral results, and a complex alpha is not allowed
// for non-complex results.
inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
  TORCH_CHECK(!alpha.isBoolean() || dtype == ScalarType::Bool,
              "Boolean alpha only supported for Boolean results.");
  TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
              || alpha.isIntegral(true),
              "For integral input tensors, argument alpha must not be a floating point number.");
  TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
              "For non-complex input tensors, argument alpha must not be a complex number.");
}
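
// A minimal usage sketch (illustrative, not from this header): a
// structured meta function for `add` would typically validate alpha
// against the common dtype it just computed, e.g.
//
//   native::alpha_check(iter.common_dtype(), alpha);
//   // rejects e.g. add(int_tensor, int_tensor, /*alpha=*/0.5)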

// Basic checking for all sub functions. Note the order: the two-bool
// case must be tested first so it gets the more specific error message.
inline void sub_check(const TensorBase& self, const TensorBase& other) {
  TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
              "Subtraction, the `-` operator, with two bool tensors is not supported. "
              "Use the `^` or `logical_xor()` operator instead.");
  TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
              "Subtraction, the `-` operator, with a bool tensor is not supported. "
              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}

inline void sub_check(const TensorBase& self, const Scalar& scalar) {
  TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
              "Subtraction, the `-` operator, with a bool tensor and a bool scalar is not supported. "
              "Use the `^` or `logical_xor()` operator instead.");
  TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
              "Subtraction, the `-` operator, with a bool tensor or a bool scalar is not supported. "
              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}
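
// Illustrative examples (not part of this header) of what the two
// overloads reject; the tensor names are hypothetical:
//
//   sub_check(bool_a, bool_b);        // error: use `^` / logical_xor() instead
//   sub_check(bool_a, Scalar(true));  // error: same restriction for a bool scalar
//   sub_check(float_a, float_b);      // passes: checks are silent no-ops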

using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
using structured_binary_fn = void(*)(TensorIteratorBase&);

using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
using binary_fn_double = void(*)(TensorIterator&, double);
using binary_fn = void(*)(TensorIterator&);
using binary_clamp_fn_alpha =
    void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
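
// A minimal sketch of how a backend supplies one of these stubs,
// assuming the usual DispatchStub registration pattern; the kernel
// name, the AT_DISPATCH_ALL_TYPES choice, and the cpu_kernel lambda
// below are illustrative, not code from this header:
//
//   void add_kernel(TensorIteratorBase& iter, const Scalar& alpha) {
//     AT_DISPATCH_ALL_TYPES(iter.common_dtype(), "add_cpu", [&]() {
//       const auto a = alpha.to<scalar_t>();
//       // cpu_kernel applies the lambda elementwise over the iterator
//       cpu_kernel(iter, [=](scalar_t self, scalar_t other) -> scalar_t {
//         return self + a * other;
//       });
//     });
//   }
//   REGISTER_DISPATCH(add_stub, &add_kernel);
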
// NB: codegenned
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);

DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
DECLARE_DISPATCH(structured_binary_fn, mul_stub);
DECLARE_DISPATCH(structured_binary_fn, div_true_stub);
DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
DECLARE_DISPATCH(binary_fn, logical_and_stub);
DECLARE_DISPATCH(binary_fn, logical_or_stub);
DECLARE_DISPATCH(structured_binary_fn, lt_stub);
DECLARE_DISPATCH(structured_binary_fn, le_stub);
DECLARE_DISPATCH(structured_binary_fn, gt_stub);
DECLARE_DISPATCH(structured_binary_fn, ge_stub);
DECLARE_DISPATCH(structured_binary_fn, eq_stub);
DECLARE_DISPATCH(structured_binary_fn, ne_stub);
DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
DECLARE_DISPATCH(structured_binary_fn, minimum_stub);
DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub);
DECLARE_DISPATCH(binary_fn_double, huber_stub);
DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
DECLARE_DISPATCH(structured_binary_fn, mse_stub);
DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub);
DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
DECLARE_DISPATCH(structured_binary_fn, lcm_stub);
DECLARE_DISPATCH(structured_binary_fn, hypot_stub);
DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub);
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub);
DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub);
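
// At the call site a stub is invoked with the device type as its first
// argument. A hedged sketch based on the usual structured-kernel
// pattern (the impl shown lives in a .cpp file, not this header):
//
//   TORCH_IMPL_FUNC(add_out)
//   (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result) {
//     add_stub(device_type(), *this, alpha);
//   }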

} // namespace at::native