1 #pragma once
2
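// Scalar math helpers with special handling for complex dtypes. For non-complex
// dtypes, several of them (zabs, real_impl, conj_impl) are no-ops, and the
// others reduce to their ordinary real-valued counterparts.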
#include <algorithm>
#include <cmath>
#include <limits>
#include <type_traits>

#include <c10/util/complex.h>
#include <c10/util/MathConstants.h>
#include <ATen/NumericUtils.h>
7
8 namespace at::native {
9 inline namespace CPU_CAPABILITY {
10
11 template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
zabs(SCALAR_TYPE z)12 inline VALUE_TYPE zabs (SCALAR_TYPE z) {
13 return z;
14 }
15
16 template<>
17 inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
18 return c10::complex<float>(std::abs(z));
19 }
20
21 template<>
22 inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
23 return std::abs(z);
24 }
25
26 template<>
27 inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
28 return c10::complex<double>(std::abs(z));
29 }
30
31 template<>
32 inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
33 return std::abs(z);
34 }
35
36 // This overload corresponds to non-complex dtypes.
37 // The function is consistent with its NumPy equivalent
38 // for non-complex dtypes where `pi` is returned for
39 // negative real numbers and `0` is returned for 0 or positive
40 // real numbers.
41 // Note: `nan` is propagated.
42 template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
angle_impl(SCALAR_TYPE z)43 inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
44 if (at::_isnan(z)) {
45 return z;
46 }
47 return z < 0 ? c10::pi<double> : 0;
48 }
49
50 template<>
51 inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
52 return c10::complex<float>(std::arg(z), 0.0);
53 }
54
55 template<>
56 inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
57 return std::arg(z);
58 }
59
60 template<>
61 inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
62 return c10::complex<double>(std::arg(z), 0.0);
63 }
64
65 template<>
66 inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
67 return std::arg(z);
68 }
69
70 template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
real_impl(SCALAR_TYPE z)71 constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) {
72 return z; //No-Op
73 }
74
75 template<>
76 constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
77 return c10::complex<float>(z.real(), 0.0);
78 }
79
80 template<>
81 constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
82 return z.real();
83 }
84
85 template<>
86 constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
87 return c10::complex<double>(z.real(), 0.0);
88 }
89
90 template<>
91 constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
92 return z.real();
93 }
94
95 template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
imag_impl(SCALAR_TYPE)96 constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) {
97 return 0;
98 }
99
100 template<>
101 constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
102 return c10::complex<float>(z.imag(), 0.0);
103 }
104
105 template<>
106 constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
107 return z.imag();
108 }
109
110 template<>
111 constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
112 return c10::complex<double>(z.imag(), 0.0);
113 }
114
115 template<>
116 constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
117 return z.imag();
118 }
119
120 template <typename TYPE>
conj_impl(TYPE z)121 inline TYPE conj_impl (TYPE z) {
122 return z; //No-Op
123 }
124
125 template<>
126 inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
127 return c10::complex<at::Half>{z.real(), -z.imag()};
128 }
129
130 template<>
131 inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
132 return c10::complex<float>(z.real(), -z.imag());
133 }
134
135 template<>
136 inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
137 return c10::complex<double>(z.real(), -z.imag());
138 }
139
140 template <typename TYPE>
ceil_impl(TYPE z)141 inline TYPE ceil_impl (TYPE z) {
142 return std::ceil(z);
143 }
144
145 template <>
ceil_impl(c10::complex<float> z)146 inline c10::complex<float> ceil_impl (c10::complex<float> z) {
147 return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
148 }
149
150 template <>
ceil_impl(c10::complex<double> z)151 inline c10::complex<double> ceil_impl (c10::complex<double> z) {
152 return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
153 }
154
155 template<typename T>
sgn_impl(c10::complex<T> z)156 inline c10::complex<T> sgn_impl (c10::complex<T> z) {
157 if (z == c10::complex<T>(0, 0)) {
158 return c10::complex<T>(0, 0);
159 } else {
160 return z / zabs(z);
161 }
162 }
163
164 template <typename TYPE>
floor_impl(TYPE z)165 inline TYPE floor_impl (TYPE z) {
166 return std::floor(z);
167 }
168
169 template <>
floor_impl(c10::complex<float> z)170 inline c10::complex<float> floor_impl (c10::complex<float> z) {
171 return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
172 }
173
174 template <>
floor_impl(c10::complex<double> z)175 inline c10::complex<double> floor_impl (c10::complex<double> z) {
176 return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
177 }
178
179 template <typename TYPE>
round_impl(TYPE z)180 inline TYPE round_impl (TYPE z) {
181 return std::nearbyint(z);
182 }
183
184 template <>
round_impl(c10::complex<float> z)185 inline c10::complex<float> round_impl (c10::complex<float> z) {
186 return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
187 }
188
189 template <>
round_impl(c10::complex<double> z)190 inline c10::complex<double> round_impl (c10::complex<double> z) {
191 return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
192 }
193
194 template <typename TYPE>
trunc_impl(TYPE z)195 inline TYPE trunc_impl (TYPE z) {
196 return std::trunc(z);
197 }
198
199 template <>
trunc_impl(c10::complex<float> z)200 inline c10::complex<float> trunc_impl (c10::complex<float> z) {
201 return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
202 }
203
204 template <>
trunc_impl(c10::complex<double> z)205 inline c10::complex<double> trunc_impl (c10::complex<double> z) {
206 return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
207 }
208
209 template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
max_impl(TYPE a,TYPE b)210 inline TYPE max_impl (TYPE a, TYPE b) {
211 if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
212 return std::numeric_limits<TYPE>::quiet_NaN();
213 } else {
214 return std::max(a, b);
215 }
216 }
217
218 template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
max_impl(TYPE a,TYPE b)219 inline TYPE max_impl (TYPE a, TYPE b) {
220 if (_isnan<TYPE>(a)) {
221 return a;
222 } else if (_isnan<TYPE>(b)) {
223 return b;
224 } else {
225 return std::abs(a) > std::abs(b) ? a : b;
226 }
227 }
228
229 template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
min_impl(TYPE a,TYPE b)230 inline TYPE min_impl (TYPE a, TYPE b) {
231 if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
232 return std::numeric_limits<TYPE>::quiet_NaN();
233 } else {
234 return std::min(a, b);
235 }
236 }
237
238 template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
min_impl(TYPE a,TYPE b)239 inline TYPE min_impl (TYPE a, TYPE b) {
240 if (_isnan<TYPE>(a)) {
241 return a;
242 } else if (_isnan<TYPE>(b)) {
243 return b;
244 } else {
245 return std::abs(a) < std::abs(b) ? a : b;
246 }
247 }
248
249 } // end namespace
250 } //end at::native
251