1 #if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
2 #error \
3 "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead."
4 #endif
5
6 namespace c10_complex_math {
7
8 // Exponential functions
9
10 template <typename T>
exp(const c10::complex<T> & x)11 C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) {
12 #if defined(__CUDACC__) || defined(__HIPCC__)
13 return static_cast<c10::complex<T>>(
14 thrust::exp(static_cast<thrust::complex<T>>(x)));
15 #else
16 return static_cast<c10::complex<T>>(
17 std::exp(static_cast<std::complex<T>>(x)));
18 #endif
19 }
20
21 template <typename T>
log(const c10::complex<T> & x)22 C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) {
23 #if defined(__CUDACC__) || defined(__HIPCC__)
24 return static_cast<c10::complex<T>>(
25 thrust::log(static_cast<thrust::complex<T>>(x)));
26 #else
27 return static_cast<c10::complex<T>>(
28 std::log(static_cast<std::complex<T>>(x)));
29 #endif
30 }
31
32 template <typename T>
log10(const c10::complex<T> & x)33 C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) {
34 #if defined(__CUDACC__) || defined(__HIPCC__)
35 return static_cast<c10::complex<T>>(
36 thrust::log10(static_cast<thrust::complex<T>>(x)));
37 #else
38 return static_cast<c10::complex<T>>(
39 std::log10(static_cast<std::complex<T>>(x)));
40 #endif
41 }
42
43 template <typename T>
log2(const c10::complex<T> & x)44 C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) {
45 const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0);
46 return c10_complex_math::log(x) / log2;
47 }
48
49 // Power functions
50 //
#if defined(_LIBCPP_VERSION) || \
    (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
// Out-of-line complex sqrt/acos overloads, declared only when building
// against libc++ or against libstdc++ without C99 complex support.
// Only float and double are provided. They are exported via C10_API and
// defined in another translation unit (presumably complex_math.cpp -- confirm).
namespace _detail {
// float overloads
C10_API c10::complex<float> sqrt(const c10::complex<float>& in);
C10_API c10::complex<float> acos(const c10::complex<float>& in);
// double overloads
C10_API c10::complex<double> sqrt(const c10::complex<double>& in);
C10_API c10::complex<double> acos(const c10::complex<double>& in);
} // namespace _detail
#endif
60
61 template <typename T>
sqrt(const c10::complex<T> & x)62 C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) {
63 #if defined(__CUDACC__) || defined(__HIPCC__)
64 return static_cast<c10::complex<T>>(
65 thrust::sqrt(static_cast<thrust::complex<T>>(x)));
66 #elif !( \
67 defined(_LIBCPP_VERSION) || \
68 (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)))
69 return static_cast<c10::complex<T>>(
70 std::sqrt(static_cast<std::complex<T>>(x)));
71 #else
72 return _detail::sqrt(x);
73 #endif
74 }
75
76 template <typename T>
pow(const c10::complex<T> & x,const c10::complex<T> & y)77 C10_HOST_DEVICE inline c10::complex<T> pow(
78 const c10::complex<T>& x,
79 const c10::complex<T>& y) {
80 #if defined(__CUDACC__) || defined(__HIPCC__)
81 return static_cast<c10::complex<T>>(thrust::pow(
82 static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
83 #else
84 return static_cast<c10::complex<T>>(std::pow(
85 static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
86 #endif
87 }
88
89 template <typename T>
pow(const c10::complex<T> & x,const T & y)90 C10_HOST_DEVICE inline c10::complex<T> pow(
91 const c10::complex<T>& x,
92 const T& y) {
93 #if defined(__CUDACC__) || defined(__HIPCC__)
94 return static_cast<c10::complex<T>>(
95 thrust::pow(static_cast<thrust::complex<T>>(x), y));
96 #else
97 return static_cast<c10::complex<T>>(
98 std::pow(static_cast<std::complex<T>>(x), y));
99 #endif
100 }
101
102 template <typename T>
pow(const T & x,const c10::complex<T> & y)103 C10_HOST_DEVICE inline c10::complex<T> pow(
104 const T& x,
105 const c10::complex<T>& y) {
106 #if defined(__CUDACC__) || defined(__HIPCC__)
107 return static_cast<c10::complex<T>>(
108 thrust::pow(x, static_cast<thrust::complex<T>>(y)));
109 #else
110 return static_cast<c10::complex<T>>(
111 std::pow(x, static_cast<std::complex<T>>(y)));
112 #endif
113 }
114
115 template <typename T, typename U>
pow(const c10::complex<T> & x,const c10::complex<U> & y)116 C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
117 const c10::complex<T>& x,
118 const c10::complex<U>& y) {
119 #if defined(__CUDACC__) || defined(__HIPCC__)
120 return static_cast<c10::complex<T>>(thrust::pow(
121 static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
122 #else
123 return static_cast<c10::complex<T>>(std::pow(
124 static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
125 #endif
126 }
127
128 template <typename T, typename U>
pow(const c10::complex<T> & x,const U & y)129 C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
130 const c10::complex<T>& x,
131 const U& y) {
132 #if defined(__CUDACC__) || defined(__HIPCC__)
133 return static_cast<c10::complex<T>>(
134 thrust::pow(static_cast<thrust::complex<T>>(x), y));
135 #else
136 return static_cast<c10::complex<T>>(
137 std::pow(static_cast<std::complex<T>>(x), y));
138 #endif
139 }
140
141 template <typename T, typename U>
pow(const T & x,const c10::complex<U> & y)142 C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
143 const T& x,
144 const c10::complex<U>& y) {
145 #if defined(__CUDACC__) || defined(__HIPCC__)
146 return static_cast<c10::complex<T>>(
147 thrust::pow(x, static_cast<thrust::complex<T>>(y)));
148 #else
149 return static_cast<c10::complex<T>>(
150 std::pow(x, static_cast<std::complex<T>>(y)));
151 #endif
152 }
153
154 // Trigonometric functions
155
156 template <typename T>
sin(const c10::complex<T> & x)157 C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) {
158 #if defined(__CUDACC__) || defined(__HIPCC__)
159 return static_cast<c10::complex<T>>(
160 thrust::sin(static_cast<thrust::complex<T>>(x)));
161 #else
162 return static_cast<c10::complex<T>>(
163 std::sin(static_cast<std::complex<T>>(x)));
164 #endif
165 }
166
167 template <typename T>
cos(const c10::complex<T> & x)168 C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) {
169 #if defined(__CUDACC__) || defined(__HIPCC__)
170 return static_cast<c10::complex<T>>(
171 thrust::cos(static_cast<thrust::complex<T>>(x)));
172 #else
173 return static_cast<c10::complex<T>>(
174 std::cos(static_cast<std::complex<T>>(x)));
175 #endif
176 }
177
178 template <typename T>
tan(const c10::complex<T> & x)179 C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) {
180 #if defined(__CUDACC__) || defined(__HIPCC__)
181 return static_cast<c10::complex<T>>(
182 thrust::tan(static_cast<thrust::complex<T>>(x)));
183 #else
184 return static_cast<c10::complex<T>>(
185 std::tan(static_cast<std::complex<T>>(x)));
186 #endif
187 }
188
189 template <typename T>
asin(const c10::complex<T> & x)190 C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) {
191 #if defined(__CUDACC__) || defined(__HIPCC__)
192 return static_cast<c10::complex<T>>(
193 thrust::asin(static_cast<thrust::complex<T>>(x)));
194 #else
195 return static_cast<c10::complex<T>>(
196 std::asin(static_cast<std::complex<T>>(x)));
197 #endif
198 }
199
200 template <typename T>
acos(const c10::complex<T> & x)201 C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) {
202 #if defined(__CUDACC__) || defined(__HIPCC__)
203 return static_cast<c10::complex<T>>(
204 thrust::acos(static_cast<thrust::complex<T>>(x)));
205 #elif !defined(_LIBCPP_VERSION)
206 return static_cast<c10::complex<T>>(
207 std::acos(static_cast<std::complex<T>>(x)));
208 #else
209 return _detail::acos(x);
210 #endif
211 }
212
213 template <typename T>
atan(const c10::complex<T> & x)214 C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) {
215 #if defined(__CUDACC__) || defined(__HIPCC__)
216 return static_cast<c10::complex<T>>(
217 thrust::atan(static_cast<thrust::complex<T>>(x)));
218 #else
219 return static_cast<c10::complex<T>>(
220 std::atan(static_cast<std::complex<T>>(x)));
221 #endif
222 }
223
224 // Hyperbolic functions
225
226 template <typename T>
sinh(const c10::complex<T> & x)227 C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) {
228 #if defined(__CUDACC__) || defined(__HIPCC__)
229 return static_cast<c10::complex<T>>(
230 thrust::sinh(static_cast<thrust::complex<T>>(x)));
231 #else
232 return static_cast<c10::complex<T>>(
233 std::sinh(static_cast<std::complex<T>>(x)));
234 #endif
235 }
236
237 template <typename T>
cosh(const c10::complex<T> & x)238 C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) {
239 #if defined(__CUDACC__) || defined(__HIPCC__)
240 return static_cast<c10::complex<T>>(
241 thrust::cosh(static_cast<thrust::complex<T>>(x)));
242 #else
243 return static_cast<c10::complex<T>>(
244 std::cosh(static_cast<std::complex<T>>(x)));
245 #endif
246 }
247
248 template <typename T>
tanh(const c10::complex<T> & x)249 C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) {
250 #if defined(__CUDACC__) || defined(__HIPCC__)
251 return static_cast<c10::complex<T>>(
252 thrust::tanh(static_cast<thrust::complex<T>>(x)));
253 #else
254 return static_cast<c10::complex<T>>(
255 std::tanh(static_cast<std::complex<T>>(x)));
256 #endif
257 }
258
259 template <typename T>
asinh(const c10::complex<T> & x)260 C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) {
261 #if defined(__CUDACC__) || defined(__HIPCC__)
262 return static_cast<c10::complex<T>>(
263 thrust::asinh(static_cast<thrust::complex<T>>(x)));
264 #else
265 return static_cast<c10::complex<T>>(
266 std::asinh(static_cast<std::complex<T>>(x)));
267 #endif
268 }
269
270 template <typename T>
acosh(const c10::complex<T> & x)271 C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) {
272 #if defined(__CUDACC__) || defined(__HIPCC__)
273 return static_cast<c10::complex<T>>(
274 thrust::acosh(static_cast<thrust::complex<T>>(x)));
275 #else
276 return static_cast<c10::complex<T>>(
277 std::acosh(static_cast<std::complex<T>>(x)));
278 #endif
279 }
280
281 template <typename T>
atanh(const c10::complex<T> & x)282 C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
283 #if defined(__CUDACC__) || defined(__HIPCC__)
284 return static_cast<c10::complex<T>>(
285 thrust::atanh(static_cast<thrust::complex<T>>(x)));
286 #else
287 return static_cast<c10::complex<T>>(
288 std::atanh(static_cast<std::complex<T>>(x)));
289 #endif
290 }
291
// log1p(z) = log(1 + z), evaluated so that small |z| does not lose precision
// to the rounding of 1 + z.
// Two implementations are selected at compile time: a real-arithmetic one
// for Apple and CUDA/HIP builds, and a rounding-compensated one built on
// c10_complex_math::log for all other (CPU) builds.
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \
    defined(__HIPCC__)
  // For Mac, the new implementation yielded a high relative error. Falling back
  // to the old version for now.
  // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
  // For CUDA we also use this one, as thrust::log(thrust::complex) takes
  // *forever* to compile

  // log1p(z) = log(1 + z)
  // Let's define 1 + z = r * e ^ (i * a), then we have
  // log(r * e ^ (i * a)) = log(r) + i * a
  // With z = x + iy, the term r can be written as
  // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
  //   = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
  // So, log(r) is
  // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
  //        = 0.5 * log1p(x * (x + 2) + y ^ 2)
  // we need to use the expression only on certain condition to avoid overflow
  // and underflow from `(x * (x + 2) + y ^ 2)`
  T x = z.real();
  T y = z.imag();
  T zabs = std::abs(z);
  // Imaginary part of the result: the angle a = arg(1 + z).
  T theta = std::atan2(y, x + T(1));
  if (zabs < 0.5) {
    // Small |z|: the local `r` below is |1 + z|^2 - 1 = x * (2 + x) + y^2,
    // formed without explicitly computing 1 + z, so log1p keeps the small
    // magnitude's precision.
    T r = x * (T(2) + x) + y * y;
    if (r == 0) { // handle underflow
      // r evaluated to exactly zero; the code falls back to x as the real
      // part (0.5 * log1p(r) ~ x to first order when y is negligible).
      return {x, theta};
    }
    return {T(0.5) * std::log1p(r), theta};
  } else {
    // Larger |z|: take log |1 + z| directly; std::hypot computes the modulus
    // without squaring the components explicitly.
    T z0 = std::hypot(x + 1, y);
    return {std::log(z0), theta};
  }
#else
  // CPU path
  // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
  c10::complex<T> u = z + T(1);
  if (u == T(1)) {
    // Adding z to 1 was entirely lost to rounding, so z itself is returned
    // as the approximation of log1p(z).
    return z;
  } else {
    auto log_u = log(u);
    if (u - T(1) == z) {
      // 1 + z was formed without rounding error; log(u) is the answer.
      return log_u;
    }
    // Compensate for the rounding in u: scale log(u) by z / (u - 1), the
    // ratio of the intended increment to the increment actually stored in u.
    return log_u * (z / (u - T(1)));
  }
#endif
}
342
343 template <typename T>
expm1(const c10::complex<T> & z)344 C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
345 // expm1(z) = exp(z) - 1
346 // Define z = x + i * y
347 // f = e ^ (x + i * y) - 1
348 // = e ^ x * e ^ (i * y) - 1
349 // = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
350 // = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
351 // = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
352 T x = z.real();
353 T y = z.imag();
354 T a = std::sin(y / 2);
355 T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
356 T ei = std::exp(x) * std::sin(y);
357 return {er, ei};
358 }
359
360 } // namespace c10_complex_math
361
362 using c10_complex_math::acos;
363 using c10_complex_math::acosh;
364 using c10_complex_math::asin;
365 using c10_complex_math::asinh;
366 using c10_complex_math::atan;
367 using c10_complex_math::atanh;
368 using c10_complex_math::cos;
369 using c10_complex_math::cosh;
370 using c10_complex_math::exp;
371 using c10_complex_math::expm1;
372 using c10_complex_math::log;
373 using c10_complex_math::log10;
374 using c10_complex_math::log1p;
375 using c10_complex_math::log2;
376 using c10_complex_math::pow;
377 using c10_complex_math::sin;
378 using c10_complex_math::sinh;
379 using c10_complex_math::sqrt;
380 using c10_complex_math::tan;
381 using c10_complex_math::tanh;
382
383 namespace std {
384
385 using c10_complex_math::acos;
386 using c10_complex_math::acosh;
387 using c10_complex_math::asin;
388 using c10_complex_math::asinh;
389 using c10_complex_math::atan;
390 using c10_complex_math::atanh;
391 using c10_complex_math::cos;
392 using c10_complex_math::cosh;
393 using c10_complex_math::exp;
394 using c10_complex_math::expm1;
395 using c10_complex_math::log;
396 using c10_complex_math::log10;
397 using c10_complex_math::log1p;
398 using c10_complex_math::log2;
399 using c10_complex_math::pow;
400 using c10_complex_math::sin;
401 using c10_complex_math::sinh;
402 using c10_complex_math::sqrt;
403 using c10_complex_math::tan;
404 using c10_complex_math::tanh;
405
406 } // namespace std
407