// xref: /aosp_15_r20/external/pytorch/c10/util/complex_math.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
2 #error \
3     "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead."
4 #endif
5 
6 namespace c10_complex_math {
7 
// Exponential functions
9 
10 template <typename T>
exp(const c10::complex<T> & x)11 C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) {
12 #if defined(__CUDACC__) || defined(__HIPCC__)
13   return static_cast<c10::complex<T>>(
14       thrust::exp(static_cast<thrust::complex<T>>(x)));
15 #else
16   return static_cast<c10::complex<T>>(
17       std::exp(static_cast<std::complex<T>>(x)));
18 #endif
19 }
20 
21 template <typename T>
log(const c10::complex<T> & x)22 C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) {
23 #if defined(__CUDACC__) || defined(__HIPCC__)
24   return static_cast<c10::complex<T>>(
25       thrust::log(static_cast<thrust::complex<T>>(x)));
26 #else
27   return static_cast<c10::complex<T>>(
28       std::log(static_cast<std::complex<T>>(x)));
29 #endif
30 }
31 
32 template <typename T>
log10(const c10::complex<T> & x)33 C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) {
34 #if defined(__CUDACC__) || defined(__HIPCC__)
35   return static_cast<c10::complex<T>>(
36       thrust::log10(static_cast<thrust::complex<T>>(x)));
37 #else
38   return static_cast<c10::complex<T>>(
39       std::log10(static_cast<std::complex<T>>(x)));
40 #endif
41 }
42 
43 template <typename T>
log2(const c10::complex<T> & x)44 C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) {
45   const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0);
46   return c10_complex_math::log(x) / log2;
47 }
48 
// Power functions
//
// Out-of-line float/double overloads of sqrt and acos. They are declared only
// for libc++ builds and for libstdc++ builds that do not define
// _GLIBCXX11_USE_C99_COMPLEX.
// NOTE(review): the definitions are not visible in this header — presumably
// they live in a companion translation unit; confirm.
#if defined(_LIBCPP_VERSION) || \
    (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
namespace _detail {
// Single precision
C10_API c10::complex<float> sqrt(const c10::complex<float>& in);
C10_API c10::complex<float> acos(const c10::complex<float>& in);
// Double precision
C10_API c10::complex<double> sqrt(const c10::complex<double>& in);
C10_API c10::complex<double> acos(const c10::complex<double>& in);
} // namespace _detail
#endif
60 
61 template <typename T>
sqrt(const c10::complex<T> & x)62 C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) {
63 #if defined(__CUDACC__) || defined(__HIPCC__)
64   return static_cast<c10::complex<T>>(
65       thrust::sqrt(static_cast<thrust::complex<T>>(x)));
66 #elif !(                        \
67     defined(_LIBCPP_VERSION) || \
68     (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)))
69   return static_cast<c10::complex<T>>(
70       std::sqrt(static_cast<std::complex<T>>(x)));
71 #else
72   return _detail::sqrt(x);
73 #endif
74 }
75 
76 template <typename T>
pow(const c10::complex<T> & x,const c10::complex<T> & y)77 C10_HOST_DEVICE inline c10::complex<T> pow(
78     const c10::complex<T>& x,
79     const c10::complex<T>& y) {
80 #if defined(__CUDACC__) || defined(__HIPCC__)
81   return static_cast<c10::complex<T>>(thrust::pow(
82       static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
83 #else
84   return static_cast<c10::complex<T>>(std::pow(
85       static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
86 #endif
87 }
88 
89 template <typename T>
pow(const c10::complex<T> & x,const T & y)90 C10_HOST_DEVICE inline c10::complex<T> pow(
91     const c10::complex<T>& x,
92     const T& y) {
93 #if defined(__CUDACC__) || defined(__HIPCC__)
94   return static_cast<c10::complex<T>>(
95       thrust::pow(static_cast<thrust::complex<T>>(x), y));
96 #else
97   return static_cast<c10::complex<T>>(
98       std::pow(static_cast<std::complex<T>>(x), y));
99 #endif
100 }
101 
102 template <typename T>
pow(const T & x,const c10::complex<T> & y)103 C10_HOST_DEVICE inline c10::complex<T> pow(
104     const T& x,
105     const c10::complex<T>& y) {
106 #if defined(__CUDACC__) || defined(__HIPCC__)
107   return static_cast<c10::complex<T>>(
108       thrust::pow(x, static_cast<thrust::complex<T>>(y)));
109 #else
110   return static_cast<c10::complex<T>>(
111       std::pow(x, static_cast<std::complex<T>>(y)));
112 #endif
113 }
114 
115 template <typename T, typename U>
pow(const c10::complex<T> & x,const c10::complex<U> & y)116 C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
117     const c10::complex<T>& x,
118     const c10::complex<U>& y) {
119 #if defined(__CUDACC__) || defined(__HIPCC__)
120   return static_cast<c10::complex<T>>(thrust::pow(
121       static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
122 #else
123   return static_cast<c10::complex<T>>(std::pow(
124       static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
125 #endif
126 }
127 
128 template <typename T, typename U>
pow(const c10::complex<T> & x,const U & y)129 C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
130     const c10::complex<T>& x,
131     const U& y) {
132 #if defined(__CUDACC__) || defined(__HIPCC__)
133   return static_cast<c10::complex<T>>(
134       thrust::pow(static_cast<thrust::complex<T>>(x), y));
135 #else
136   return static_cast<c10::complex<T>>(
137       std::pow(static_cast<std::complex<T>>(x), y));
138 #endif
139 }
140 
141 template <typename T, typename U>
pow(const T & x,const c10::complex<U> & y)142 C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
143     const T& x,
144     const c10::complex<U>& y) {
145 #if defined(__CUDACC__) || defined(__HIPCC__)
146   return static_cast<c10::complex<T>>(
147       thrust::pow(x, static_cast<thrust::complex<T>>(y)));
148 #else
149   return static_cast<c10::complex<T>>(
150       std::pow(x, static_cast<std::complex<T>>(y)));
151 #endif
152 }
153 
// Trigonometric functions
155 
156 template <typename T>
sin(const c10::complex<T> & x)157 C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) {
158 #if defined(__CUDACC__) || defined(__HIPCC__)
159   return static_cast<c10::complex<T>>(
160       thrust::sin(static_cast<thrust::complex<T>>(x)));
161 #else
162   return static_cast<c10::complex<T>>(
163       std::sin(static_cast<std::complex<T>>(x)));
164 #endif
165 }
166 
167 template <typename T>
cos(const c10::complex<T> & x)168 C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) {
169 #if defined(__CUDACC__) || defined(__HIPCC__)
170   return static_cast<c10::complex<T>>(
171       thrust::cos(static_cast<thrust::complex<T>>(x)));
172 #else
173   return static_cast<c10::complex<T>>(
174       std::cos(static_cast<std::complex<T>>(x)));
175 #endif
176 }
177 
178 template <typename T>
tan(const c10::complex<T> & x)179 C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) {
180 #if defined(__CUDACC__) || defined(__HIPCC__)
181   return static_cast<c10::complex<T>>(
182       thrust::tan(static_cast<thrust::complex<T>>(x)));
183 #else
184   return static_cast<c10::complex<T>>(
185       std::tan(static_cast<std::complex<T>>(x)));
186 #endif
187 }
188 
189 template <typename T>
asin(const c10::complex<T> & x)190 C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) {
191 #if defined(__CUDACC__) || defined(__HIPCC__)
192   return static_cast<c10::complex<T>>(
193       thrust::asin(static_cast<thrust::complex<T>>(x)));
194 #else
195   return static_cast<c10::complex<T>>(
196       std::asin(static_cast<std::complex<T>>(x)));
197 #endif
198 }
199 
200 template <typename T>
acos(const c10::complex<T> & x)201 C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) {
202 #if defined(__CUDACC__) || defined(__HIPCC__)
203   return static_cast<c10::complex<T>>(
204       thrust::acos(static_cast<thrust::complex<T>>(x)));
205 #elif !defined(_LIBCPP_VERSION)
206   return static_cast<c10::complex<T>>(
207       std::acos(static_cast<std::complex<T>>(x)));
208 #else
209   return _detail::acos(x);
210 #endif
211 }
212 
213 template <typename T>
atan(const c10::complex<T> & x)214 C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) {
215 #if defined(__CUDACC__) || defined(__HIPCC__)
216   return static_cast<c10::complex<T>>(
217       thrust::atan(static_cast<thrust::complex<T>>(x)));
218 #else
219   return static_cast<c10::complex<T>>(
220       std::atan(static_cast<std::complex<T>>(x)));
221 #endif
222 }
223 
// Hyperbolic functions
225 
226 template <typename T>
sinh(const c10::complex<T> & x)227 C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) {
228 #if defined(__CUDACC__) || defined(__HIPCC__)
229   return static_cast<c10::complex<T>>(
230       thrust::sinh(static_cast<thrust::complex<T>>(x)));
231 #else
232   return static_cast<c10::complex<T>>(
233       std::sinh(static_cast<std::complex<T>>(x)));
234 #endif
235 }
236 
237 template <typename T>
cosh(const c10::complex<T> & x)238 C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) {
239 #if defined(__CUDACC__) || defined(__HIPCC__)
240   return static_cast<c10::complex<T>>(
241       thrust::cosh(static_cast<thrust::complex<T>>(x)));
242 #else
243   return static_cast<c10::complex<T>>(
244       std::cosh(static_cast<std::complex<T>>(x)));
245 #endif
246 }
247 
248 template <typename T>
tanh(const c10::complex<T> & x)249 C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) {
250 #if defined(__CUDACC__) || defined(__HIPCC__)
251   return static_cast<c10::complex<T>>(
252       thrust::tanh(static_cast<thrust::complex<T>>(x)));
253 #else
254   return static_cast<c10::complex<T>>(
255       std::tanh(static_cast<std::complex<T>>(x)));
256 #endif
257 }
258 
259 template <typename T>
asinh(const c10::complex<T> & x)260 C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) {
261 #if defined(__CUDACC__) || defined(__HIPCC__)
262   return static_cast<c10::complex<T>>(
263       thrust::asinh(static_cast<thrust::complex<T>>(x)));
264 #else
265   return static_cast<c10::complex<T>>(
266       std::asinh(static_cast<std::complex<T>>(x)));
267 #endif
268 }
269 
270 template <typename T>
acosh(const c10::complex<T> & x)271 C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) {
272 #if defined(__CUDACC__) || defined(__HIPCC__)
273   return static_cast<c10::complex<T>>(
274       thrust::acosh(static_cast<thrust::complex<T>>(x)));
275 #else
276   return static_cast<c10::complex<T>>(
277       std::acosh(static_cast<std::complex<T>>(x)));
278 #endif
279 }
280 
281 template <typename T>
atanh(const c10::complex<T> & x)282 C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
283 #if defined(__CUDACC__) || defined(__HIPCC__)
284   return static_cast<c10::complex<T>>(
285       thrust::atanh(static_cast<thrust::complex<T>>(x)));
286 #else
287   return static_cast<c10::complex<T>>(
288       std::atanh(static_cast<std::complex<T>>(x)));
289 #endif
290 }
291 
292 template <typename T>
log1p(const c10::complex<T> & z)293 C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
294 #if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \
295     defined(__HIPCC__)
296   // For Mac, the new implementation yielded a high relative error. Falling back
297   // to the old version for now.
298   // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
299   // For CUDA we also use this one, as thrust::log(thrust::complex) takes
300   // *forever* to compile
301 
302   // log1p(z) = log(1 + z)
303   // Let's define 1 + z = r * e ^ (i * a), then we have
304   // log(r * e ^ (i * a)) = log(r) + i * a
305   // With z = x + iy, the term r can be written as
306   // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
307   //   = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
308   // So, log(r) is
309   // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
310   //        = 0.5 * log1p(x * (x + 2) + y ^ 2)
311   // we need to use the expression only on certain condition to avoid overflow
312   // and underflow from `(x * (x + 2) + y ^ 2)`
313   T x = z.real();
314   T y = z.imag();
315   T zabs = std::abs(z);
316   T theta = std::atan2(y, x + T(1));
317   if (zabs < 0.5) {
318     T r = x * (T(2) + x) + y * y;
319     if (r == 0) { // handle underflow
320       return {x, theta};
321     }
322     return {T(0.5) * std::log1p(r), theta};
323   } else {
324     T z0 = std::hypot(x + 1, y);
325     return {std::log(z0), theta};
326   }
327 #else
328   // CPU path
329   // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
330   c10::complex<T> u = z + T(1);
331   if (u == T(1)) {
332     return z;
333   } else {
334     auto log_u = log(u);
335     if (u - T(1) == z) {
336       return log_u;
337     }
338     return log_u * (z / (u - T(1)));
339   }
340 #endif
341 }
342 
343 template <typename T>
expm1(const c10::complex<T> & z)344 C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
345   // expm1(z) = exp(z) - 1
346   // Define z = x + i * y
347   // f = e ^ (x + i * y) - 1
348   //   = e ^ x * e ^ (i * y) - 1
349   //   = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
350   //   = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
351   //   = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
352   T x = z.real();
353   T y = z.imag();
354   T a = std::sin(y / 2);
355   T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
356   T ei = std::exp(x) * std::sin(y);
357   return {er, ei};
358 }
359 
360 } // namespace c10_complex_math
361 
362 using c10_complex_math::acos;
363 using c10_complex_math::acosh;
364 using c10_complex_math::asin;
365 using c10_complex_math::asinh;
366 using c10_complex_math::atan;
367 using c10_complex_math::atanh;
368 using c10_complex_math::cos;
369 using c10_complex_math::cosh;
370 using c10_complex_math::exp;
371 using c10_complex_math::expm1;
372 using c10_complex_math::log;
373 using c10_complex_math::log10;
374 using c10_complex_math::log1p;
375 using c10_complex_math::log2;
376 using c10_complex_math::pow;
377 using c10_complex_math::sin;
378 using c10_complex_math::sinh;
379 using c10_complex_math::sqrt;
380 using c10_complex_math::tan;
381 using c10_complex_math::tanh;
382 
383 namespace std {
384 
385 using c10_complex_math::acos;
386 using c10_complex_math::acosh;
387 using c10_complex_math::asin;
388 using c10_complex_math::asinh;
389 using c10_complex_math::atan;
390 using c10_complex_math::atanh;
391 using c10_complex_math::cos;
392 using c10_complex_math::cosh;
393 using c10_complex_math::exp;
394 using c10_complex_math::expm1;
395 using c10_complex_math::log;
396 using c10_complex_math::log10;
397 using c10_complex_math::log1p;
398 using c10_complex_math::log2;
399 using c10_complex_math::pow;
400 using c10_complex_math::sin;
401 using c10_complex_math::sinh;
402 using c10_complex_math::sqrt;
403 using c10_complex_math::tan;
404 using c10_complex_math::tanh;
405 
406 } // namespace std
407