xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/zmath.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
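// Complex-aware math helpers. For non-complex dtypes the complex-specific
// ones (zabs, real_impl, conj_impl, ...) are no-ops or trivial, while the rest
// (ceil/floor/round/trunc, angle, min/max) fall back to real-valued logic.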
4 #include <c10/util/complex.h>
5 #include <c10/util/MathConstants.h>
6 #include<ATen/NumericUtils.h>
7 
8 namespace at::native {
9 inline namespace CPU_CAPABILITY {
10 
// Magnitude helper, primary template.
//
// For non-complex SCALAR_TYPE this is an identity: `z` is returned unchanged
// (converted to VALUE_TYPE). It does NOT take the absolute value of negative
// real numbers. Complex inputs are handled by the explicit specializations
// that follow.
template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
inline VALUE_TYPE zabs(SCALAR_TYPE z) {
  // Identity for real dtypes.
  return z;
}
15 
16 template<>
17 inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
18   return c10::complex<float>(std::abs(z));
19 }
20 
21 template<>
22 inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
23   return std::abs(z);
24 }
25 
26 template<>
27 inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
28   return c10::complex<double>(std::abs(z));
29 }
30 
31 template<>
32 inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
33   return std::abs(z);
34 }
35 
36 // This overload corresponds to non-complex dtypes.
37 // The function is consistent with its NumPy equivalent
38 // for non-complex dtypes where `pi` is returned for
39 // negative real numbers and `0` is returned for 0 or positive
40 // real numbers.
41 // Note: `nan` is propagated.
42 template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
angle_impl(SCALAR_TYPE z)43 inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
44   if (at::_isnan(z)) {
45     return z;
46   }
47   return z < 0 ? c10::pi<double> : 0;
48 }
49 
50 template<>
51 inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
52   return c10::complex<float>(std::arg(z), 0.0);
53 }
54 
55 template<>
56 inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
57   return std::arg(z);
58 }
59 
60 template<>
61 inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
62   return c10::complex<double>(std::arg(z), 0.0);
63 }
64 
65 template<>
66 inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
67   return std::arg(z);
68 }
69 
// Real-part helper, primary template.
//
// For non-complex SCALAR_TYPE the real part is the value itself, so `z` is
// returned unchanged (converted to VALUE_TYPE). Complex inputs are handled by
// the explicit specializations that follow.
template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
constexpr VALUE_TYPE real_impl(SCALAR_TYPE z) {
  // A real number is its own real part.
  return z;
}
74 
75 template<>
76 constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
77   return c10::complex<float>(z.real(), 0.0);
78 }
79 
80 template<>
81 constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
82   return z.real();
83 }
84 
85 template<>
86 constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
87   return c10::complex<double>(z.real(), 0.0);
88 }
89 
90 template<>
91 constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
92   return z.real();
93 }
94 
// Imaginary-part helper, primary template.
//
// Non-complex values have no imaginary component, so the result is always 0
// (converted to VALUE_TYPE) and the argument is ignored. Complex inputs are
// handled by the explicit specializations that follow.
template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
constexpr VALUE_TYPE imag_impl(SCALAR_TYPE /*z*/) {
  // Real numbers have a zero imaginary part.
  return 0;
}
99 
100 template<>
101 constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
102   return c10::complex<float>(z.imag(), 0.0);
103 }
104 
105 template<>
106 constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
107   return z.imag();
108 }
109 
110 template<>
111 constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
112   return c10::complex<double>(z.imag(), 0.0);
113 }
114 
115 template<>
116 constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
117   return z.imag();
118 }
119 
// Complex-conjugate helper, primary template.
//
// Conjugation is the identity for non-complex TYPE, so `z` is returned
// unchanged. Complex inputs are handled by the explicit specializations that
// follow.
template <typename TYPE>
inline TYPE conj_impl(TYPE z) {
  // conj(x) == x for real x.
  return z;
}
124 
125 template<>
126 inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
127   return c10::complex<at::Half>{z.real(), -z.imag()};
128 }
129 
130 template<>
131 inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
132   return c10::complex<float>(z.real(), -z.imag());
133 }
134 
135 template<>
136 inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
137   return c10::complex<double>(z.real(), -z.imag());
138 }
139 
// Ceiling helper, primary template: the smallest integral value not less than
// `z`, computed with std::ceil and converted back to TYPE. Complex inputs are
// handled by the explicit specializations that follow.
template <typename TYPE>
inline TYPE ceil_impl(TYPE z) {
  const TYPE rounded_up = std::ceil(z);
  return rounded_up;
}
144 
145 template <>
ceil_impl(c10::complex<float> z)146 inline c10::complex<float> ceil_impl (c10::complex<float> z) {
147   return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
148 }
149 
150 template <>
ceil_impl(c10::complex<double> z)151 inline c10::complex<double> ceil_impl (c10::complex<double> z) {
152   return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
153 }
154 
155 template<typename T>
sgn_impl(c10::complex<T> z)156 inline c10::complex<T> sgn_impl (c10::complex<T> z) {
157   if (z == c10::complex<T>(0, 0)) {
158     return c10::complex<T>(0, 0);
159   } else {
160     return z / zabs(z);
161   }
162 }
163 
// Floor helper, primary template: the largest integral value not greater than
// `z`, computed with std::floor and converted back to TYPE. Complex inputs are
// handled by the explicit specializations that follow.
template <typename TYPE>
inline TYPE floor_impl(TYPE z) {
  const TYPE rounded_down = std::floor(z);
  return rounded_down;
}
168 
169 template <>
floor_impl(c10::complex<float> z)170 inline c10::complex<float> floor_impl (c10::complex<float> z) {
171   return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
172 }
173 
174 template <>
floor_impl(c10::complex<double> z)175 inline c10::complex<double> floor_impl (c10::complex<double> z) {
176   return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
177 }
178 
// Rounding helper, primary template.
//
// Rounds `z` to an integral value with std::nearbyint, i.e. according to the
// current floating-point rounding mode. Under the default round-to-nearest
// mode, ties go to the even neighbour (2.5 -> 2, 3.5 -> 4). Complex inputs
// are handled by the explicit specializations that follow.
template <typename TYPE>
inline TYPE round_impl(TYPE z) {
  const TYPE rounded = std::nearbyint(z);
  return rounded;
}
183 
184 template <>
round_impl(c10::complex<float> z)185 inline c10::complex<float> round_impl (c10::complex<float> z) {
186   return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
187 }
188 
189 template <>
round_impl(c10::complex<double> z)190 inline c10::complex<double> round_impl (c10::complex<double> z) {
191   return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
192 }
193 
// Truncation helper, primary template: rounds `z` toward zero with std::trunc
// and converts back to TYPE. Complex inputs are handled by the explicit
// specializations that follow.
template <typename TYPE>
inline TYPE trunc_impl(TYPE z) {
  const TYPE toward_zero = std::trunc(z);
  return toward_zero;
}
198 
199 template <>
trunc_impl(c10::complex<float> z)200 inline c10::complex<float> trunc_impl (c10::complex<float> z) {
201   return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
202 }
203 
204 template <>
trunc_impl(c10::complex<double> z)205 inline c10::complex<double> trunc_impl (c10::complex<double> z) {
206   return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
207 }
208 
209 template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
max_impl(TYPE a,TYPE b)210 inline TYPE max_impl (TYPE a, TYPE b) {
211   if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
212     return std::numeric_limits<TYPE>::quiet_NaN();
213   } else {
214     return std::max(a, b);
215   }
216 }
217 
218 template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
max_impl(TYPE a,TYPE b)219 inline TYPE max_impl (TYPE a, TYPE b) {
220   if (_isnan<TYPE>(a)) {
221     return a;
222   } else if (_isnan<TYPE>(b)) {
223     return b;
224   } else {
225     return std::abs(a) > std::abs(b) ? a : b;
226   }
227 }
228 
229 template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
min_impl(TYPE a,TYPE b)230 inline TYPE min_impl (TYPE a, TYPE b) {
231   if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
232     return std::numeric_limits<TYPE>::quiet_NaN();
233   } else {
234     return std::min(a, b);
235   }
236 }
237 
238 template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
min_impl(TYPE a,TYPE b)239 inline TYPE min_impl (TYPE a, TYPE b) {
240   if (_isnan<TYPE>(a)) {
241     return a;
242   } else if (_isnan<TYPE>(b)) {
243     return b;
244   } else {
245     return std::abs(a) < std::abs(b) ? a : b;
246   }
247 }
248 
249 } // end namespace
250 } //end at::native
251