xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/llvm_complex.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // This is copy-pasted (with modification) from the following llvm file:
2 // - https://github.com/llvm/llvm-project/blob/main/libcxx/include/complex
3 //
4 // -*- C++ -*-
5 //===----------------------------------------------------------------------===//
6 //
7 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
8 // See https://llvm.org/LICENSE.txt for license information.
9 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <string>
14 #include <ATen/cuda/llvm_jit_strings.h>
15 
16 
17 namespace at::cuda {
18 
19 const std::string complex_body = R"ESCAPE(
20 
21 namespace std {
22 
23 template<class _Tp> class complex;
24 
25 template<class _Tp> complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w);
26 template<class _Tp> complex<_Tp> operator/(const complex<_Tp>& __x, const complex<_Tp>& __y);
27 
28 template<class _Tp>
29 class complex
30 {
31 public:
32     typedef _Tp value_type;
33 private:
34     value_type __re_;
35     value_type __im_;
36 public:
37     constexpr
38     complex(const value_type& __re = value_type(), const value_type& __im = value_type())
39         : __re_(__re), __im_(__im) {}
40     template<class _Xp> constexpr
41     complex(const complex<_Xp>& __c)
42         : __re_(__c.real()), __im_(__c.imag()) {}
43 
44     constexpr value_type real() const {return __re_;}
45     constexpr value_type imag() const {return __im_;}
46 
47     void real(value_type __re) {__re_ = __re;}
48     void imag(value_type __im) {__im_ = __im;}
49 
50     constexpr operator bool() const {
51         return real() || imag();
52     }
53 
54     complex& operator= (const value_type& __re)
55         {__re_ = __re; __im_ = value_type(); return *this;}
56     complex& operator+=(const value_type& __re) {__re_ += __re; return *this;}
57     complex& operator-=(const value_type& __re) {__re_ -= __re; return *this;}
58     complex& operator*=(const value_type& __re) {__re_ *= __re; __im_ *= __re; return *this;}
59     complex& operator/=(const value_type& __re) {__re_ /= __re; __im_ /= __re; return *this;}
60 
61     template<class _Xp> complex& operator= (const complex<_Xp>& __c)
62         {
63             __re_ = __c.real();
64             __im_ = __c.imag();
65             return *this;
66         }
67     template<class _Xp> complex& operator+=(const complex<_Xp>& __c)
68         {
69             __re_ += __c.real();
70             __im_ += __c.imag();
71             return *this;
72         }
73     template<class _Xp> complex& operator-=(const complex<_Xp>& __c)
74         {
75             __re_ -= __c.real();
76             __im_ -= __c.imag();
77             return *this;
78         }
79     template<class _Xp> complex& operator*=(const complex<_Xp>& __c)
80         {
81             *this = *this * complex(__c.real(), __c.imag());
82             return *this;
83         }
84     template<class _Xp> complex& operator/=(const complex<_Xp>& __c)
85         {
86             *this = *this / complex(__c.real(), __c.imag());
87             return *this;
88         }
89 };
90 
91 template<> class complex<double>;
92 
93 template<>
94 class complex<float>
95 {
96     float __re_;
97     float __im_;
98 public:
99     typedef float value_type;
100 
101     constexpr complex(float __re = 0.0f, float __im = 0.0f)
102         : __re_(__re), __im_(__im) {}
103 
104     explicit constexpr complex(const complex<double>& __c);
105 
106     constexpr float real() const {return __re_;}
107     constexpr float imag() const {return __im_;}
108 
109     void real(value_type __re) {__re_ = __re;}
110     void imag(value_type __im) {__im_ = __im;}
111 
112     constexpr operator bool() const {
113         return real() || imag();
114     }
115 
116     complex& operator= (float __re)
117         {__re_ = __re; __im_ = value_type(); return *this;}
118     complex& operator+=(float __re) {__re_ += __re; return *this;}
119     complex& operator-=(float __re) {__re_ -= __re; return *this;}
120     complex& operator*=(float __re) {__re_ *= __re; __im_ *= __re; return *this;}
121     complex& operator/=(float __re) {__re_ /= __re; __im_ /= __re; return *this;}
122 
123     template<class _Xp> complex& operator= (const complex<_Xp>& __c)
124         {
125             __re_ = __c.real();
126             __im_ = __c.imag();
127             return *this;
128         }
129     template<class _Xp> complex& operator+=(const complex<_Xp>& __c)
130         {
131             __re_ += __c.real();
132             __im_ += __c.imag();
133             return *this;
134         }
135     template<class _Xp> complex& operator-=(const complex<_Xp>& __c)
136         {
137             __re_ -= __c.real();
138             __im_ -= __c.imag();
139             return *this;
140         }
141     template<class _Xp> complex& operator*=(const complex<_Xp>& __c)
142         {
143             *this = *this * complex(__c.real(), __c.imag());
144             return *this;
145         }
146     template<class _Xp> complex& operator/=(const complex<_Xp>& __c)
147         {
148             *this = *this / complex(__c.real(), __c.imag());
149             return *this;
150         }
151 };
152 
153 template<>
154 class complex<double>
155 {
156     double __re_;
157     double __im_;
158 public:
159     typedef double value_type;
160 
161     constexpr complex(double __re = 0.0, double __im = 0.0)
162         : __re_(__re), __im_(__im) {}
163 
164     constexpr complex(const complex<float>& __c);
165 
166     constexpr double real() const {return __re_;}
167     constexpr double imag() const {return __im_;}
168 
169     void real(value_type __re) {__re_ = __re;}
170     void imag(value_type __im) {__im_ = __im;}
171 
172     constexpr operator bool() const {
173         return real() || imag();
174     }
175 
176     complex& operator= (double __re)
177         {__re_ = __re; __im_ = value_type(); return *this;}
178     complex& operator+=(double __re) {__re_ += __re; return *this;}
179     complex& operator-=(double __re) {__re_ -= __re; return *this;}
180     complex& operator*=(double __re) {__re_ *= __re; __im_ *= __re; return *this;}
181     complex& operator/=(double __re) {__re_ /= __re; __im_ /= __re; return *this;}
182 
183     template<class _Xp> complex& operator= (const complex<_Xp>& __c)
184         {
185             __re_ = __c.real();
186             __im_ = __c.imag();
187             return *this;
188         }
189     template<class _Xp> complex& operator+=(const complex<_Xp>& __c)
190         {
191             __re_ += __c.real();
192             __im_ += __c.imag();
193             return *this;
194         }
195     template<class _Xp> complex& operator-=(const complex<_Xp>& __c)
196         {
197             __re_ -= __c.real();
198             __im_ -= __c.imag();
199             return *this;
200         }
201     template<class _Xp> complex& operator*=(const complex<_Xp>& __c)
202         {
203             *this = *this * complex(__c.real(), __c.imag());
204             return *this;
205         }
206     template<class _Xp> complex& operator/=(const complex<_Xp>& __c)
207         {
208             *this = *this / complex(__c.real(), __c.imag());
209             return *this;
210         }
211 };
212 
213 inline
214 constexpr
215 complex<float>::complex(const complex<double>& __c)
216     : __re_(__c.real()), __im_(__c.imag()) {}
217 
218 inline
219 constexpr
220 complex<double>::complex(const complex<float>& __c)
221     : __re_(__c.real()), __im_(__c.imag()) {}
222 
223 
224 // 26.3.6 operators:
225 
226 template<class _Tp>
227 inline
228 complex<_Tp>
229 operator+(const complex<_Tp>& __x, const complex<_Tp>& __y)
230 {
231     complex<_Tp> __t(__x);
232     __t += __y;
233     return __t;
234 }
235 
236 template<class _Tp>
237 inline
238 complex<_Tp>
239 operator+(const complex<_Tp>& __x, const _Tp& __y)
240 {
241     complex<_Tp> __t(__x);
242     __t += __y;
243     return __t;
244 }
245 
246 template<class _Tp>
247 inline
248 complex<_Tp>
249 operator+(const _Tp& __x, const complex<_Tp>& __y)
250 {
251     complex<_Tp> __t(__y);
252     __t += __x;
253     return __t;
254 }
255 
256 template<class _Tp>
257 inline
258 complex<_Tp>
259 operator-(const complex<_Tp>& __x, const complex<_Tp>& __y)
260 {
261     complex<_Tp> __t(__x);
262     __t -= __y;
263     return __t;
264 }
265 
266 template<class _Tp>
267 inline
268 complex<_Tp>
269 operator-(const complex<_Tp>& __x, const _Tp& __y)
270 {
271     complex<_Tp> __t(__x);
272     __t -= __y;
273     return __t;
274 }
275 
276 template<class _Tp>
277 inline
278 complex<_Tp>
279 operator-(const _Tp& __x, const complex<_Tp>& __y)
280 {
281     complex<_Tp> __t(-__y);
282     __t += __x;
283     return __t;
284 }
285 
286 template<class _Tp>
287 complex<_Tp>
288 operator*(const complex<_Tp>& __z, const complex<_Tp>& __w)
289 {
290     _Tp __a = __z.real();
291     _Tp __b = __z.imag();
292     _Tp __c = __w.real();
293     _Tp __d = __w.imag();
294     _Tp __ac = __a * __c;
295     _Tp __bd = __b * __d;
296     _Tp __ad = __a * __d;
297     _Tp __bc = __b * __c;
298     _Tp __x = __ac - __bd;
299     _Tp __y = __ad + __bc;
300     if (isnan(__x) && isnan(__y))
301     {
302         bool __recalc = false;
303         if (isinf(__a) || isinf(__b))
304         {
305             __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a);
306             __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b);
307             if (isnan(__c))
308                 __c = copysign(_Tp(0), __c);
309             if (isnan(__d))
310                 __d = copysign(_Tp(0), __d);
311             __recalc = true;
312         }
313         if (isinf(__c) || isinf(__d))
314         {
315             __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c);
316             __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d);
317             if (isnan(__a))
318                 __a = copysign(_Tp(0), __a);
319             if (isnan(__b))
320                 __b = copysign(_Tp(0), __b);
321             __recalc = true;
322         }
323         if (!__recalc && (isinf(__ac) || isinf(__bd) ||
324                           isinf(__ad) || isinf(__bc)))
325         {
326             if (isnan(__a))
327                 __a = copysign(_Tp(0), __a);
328             if (isnan(__b))
329                 __b = copysign(_Tp(0), __b);
330             if (isnan(__c))
331                 __c = copysign(_Tp(0), __c);
332             if (isnan(__d))
333                 __d = copysign(_Tp(0), __d);
334             __recalc = true;
335         }
336         if (__recalc)
337         {
338             __x = _Tp(INFINITY) * (__a * __c - __b * __d);
339             __y = _Tp(INFINITY) * (__a * __d + __b * __c);
340         }
341     }
342     return complex<_Tp>(__x, __y);
343 }
344 
345 template<class _Tp>
346 inline
347 complex<_Tp>
348 operator*(const complex<_Tp>& __x, const _Tp& __y)
349 {
350     complex<_Tp> __t(__x);
351     __t *= __y;
352     return __t;
353 }
354 
355 template<class _Tp>
356 inline
357 complex<_Tp>
358 operator*(const _Tp& __x, const complex<_Tp>& __y)
359 {
360     complex<_Tp> __t(__y);
361     __t *= __x;
362     return __t;
363 }
364 
365 template<class _Tp>
366 complex<_Tp>
367 operator/(const complex<_Tp>& __z, const complex<_Tp>& __w)
368 {
369     int __ilogbw = 0;
370     _Tp __a = __z.real();
371     _Tp __b = __z.imag();
372     _Tp __c = __w.real();
373     _Tp __d = __w.imag();
374     _Tp __logbw = logb(fmax(fabs(__c), fabs(__d)));
375     if (isfinite(__logbw))
376     {
377         __ilogbw = static_cast<int>(__logbw);
378         __c = scalbn(__c, -__ilogbw);
379         __d = scalbn(__d, -__ilogbw);
380     }
381     _Tp __denom = __c * __c + __d * __d;
382     _Tp __x = scalbn((__a * __c + __b * __d) / __denom, -__ilogbw);
383     _Tp __y = scalbn((__b * __c - __a * __d) / __denom, -__ilogbw);
384     if (isnan(__x) && isnan(__y))
385     {
386         if ((__denom == _Tp(0)) && (!isnan(__a) || !isnan(__b)))
387         {
388             __x = copysign(_Tp(INFINITY), __c) * __a;
389             __y = copysign(_Tp(INFINITY), __c) * __b;
390         }
391         else if ((isinf(__a) || isinf(__b)) && isfinite(__c) && isfinite(__d))
392         {
393             __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a);
394             __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b);
395             __x = _Tp(INFINITY) * (__a * __c + __b * __d);
396             __y = _Tp(INFINITY) * (__b * __c - __a * __d);
397         }
398         else if (isinf(__logbw) && __logbw > _Tp(0) && isfinite(__a) && isfinite(__b))
399         {
400             __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c);
401             __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d);
402             __x = _Tp(0) * (__a * __c + __b * __d);
403             __y = _Tp(0) * (__b * __c - __a * __d);
404         }
405     }
406     return complex<_Tp>(__x, __y);
407 }
408 
409 template<class _Tp>
410 inline
411 complex<_Tp>
412 operator/(const complex<_Tp>& __x, const _Tp& __y)
413 {
414     return complex<_Tp>(__x.real() / __y, __x.imag() / __y);
415 }
416 
417 template<class _Tp>
418 inline
419 complex<_Tp>
420 operator/(const _Tp& __x, const complex<_Tp>& __y)
421 {
422     complex<_Tp> __t(__x);
423     __t /= __y;
424     return __t;
425 }
426 
427 template<class _Tp>
428 inline
429 complex<_Tp>
430 operator+(const complex<_Tp>& __x)
431 {
432     return __x;
433 }
434 
435 template<class _Tp>
436 inline
437 complex<_Tp>
438 operator-(const complex<_Tp>& __x)
439 {
440     return complex<_Tp>(-__x.real(), -__x.imag());
441 }
442 
443 template<class _Tp>
444 inline constexpr
445 bool
446 operator==(const complex<_Tp>& __x, const complex<_Tp>& __y)
447 {
448     return __x.real() == __y.real() && __x.imag() == __y.imag();
449 }
450 
451 template<class _Tp>
452 inline constexpr
453 bool
454 operator==(const complex<_Tp>& __x, const _Tp& __y)
455 {
456     return __x.real() == __y && __x.imag() == 0;
457 }
458 
459 template<class _Tp>
460 inline constexpr
461 bool
462 operator==(const _Tp& __x, const complex<_Tp>& __y)
463 {
464     return __x == __y.real() && 0 == __y.imag();
465 }
466 
467 template<class _Tp>
468 inline constexpr
469 bool
470 operator!=(const complex<_Tp>& __x, const complex<_Tp>& __y)
471 {
472     return !(__x == __y);
473 }
474 
475 template<class _Tp>
476 inline constexpr
477 bool
478 operator!=(const complex<_Tp>& __x, const _Tp& __y)
479 {
480     return !(__x == __y);
481 }
482 
483 template<class _Tp>
484 inline constexpr
485 bool
486 operator!=(const _Tp& __x, const complex<_Tp>& __y)
487 {
488     return !(__x == __y);
489 }
490 
491 template<class _Tp>
492 inline constexpr
493 bool
494 operator&&(const complex<_Tp>& __x, const complex<_Tp>& __y)
495 {
496     return bool(__x) && bool(__y);
497 }
498 
499 template<class _Tp>
500 inline constexpr
501 bool
502 isnan(const complex<_Tp>& __x)
503 {
504     return isnan(__x.real()) || isnan(__x.imag());
505 }
506 
507 template<class _Tp>
508 inline constexpr
509 bool
510 operator||(const complex<_Tp>& __x, const complex<_Tp>& __y)
511 {
512     return bool(__x) || bool(__y);
513 }
514 
515 // 26.3.7 values:
516 
517 template <class _Tp, bool = is_integral<_Tp>::value,
518                      bool = is_floating_point<_Tp>::value
519                      >
520 struct __libcpp_complex_overload_traits {};
521 
522 // Integral Types
523 template <class _Tp>
524 struct __libcpp_complex_overload_traits<_Tp, true, false>
525 {
526     typedef double _ValueType;
527     typedef complex<double> _ComplexType;
528 };
529 
530 // Floating point types
531 template <class _Tp>
532 struct __libcpp_complex_overload_traits<_Tp, false, true>
533 {
534     typedef _Tp _ValueType;
535     typedef complex<_Tp> _ComplexType;
536 };
537 
538 // real
539 
540 template<class _Tp>
541 inline constexpr
542 _Tp
543 real(const complex<_Tp>& __c)
544 {
545     return __c.real();
546 }
547 
548 template <class _Tp>
549 inline constexpr
550 typename __libcpp_complex_overload_traits<_Tp>::_ValueType
551 real(_Tp __re)
552 {
553     return __re;
554 }
555 
556 // imag
557 
558 template<class _Tp>
559 inline constexpr
560 _Tp
561 imag(const complex<_Tp>& __c)
562 {
563     return __c.imag();
564 }
565 
566 template <class _Tp>
567 inline constexpr
568 typename __libcpp_complex_overload_traits<_Tp>::_ValueType
569 imag(_Tp)
570 {
571     return 0;
572 }
573 
574 // abs
575 
576 template<class _Tp>
577 inline
578 _Tp
579 abs(const complex<_Tp>& __c)
580 {
581     return hypot(__c.real(), __c.imag());
582 }
583 
584 // arg
585 
586 template<class _Tp>
587 inline
588 _Tp
589 arg(const complex<_Tp>& __c)
590 {
591     return atan2(__c.imag(), __c.real());
592 }
593 
594 template<class _Tp>
595 inline
596 typename enable_if
597 <
598     is_integral<_Tp>::value || is_same<_Tp, double>::value,
599     double
600 >::type
601 arg(_Tp __re)
602 {
603     return atan2(0., __re);
604 }
605 
606 template <class _Tp>
607 inline
608 typename enable_if<
609     is_same<_Tp, float>::value,
610     float
611 >::type
612 arg(_Tp __re)
613 {
614     return atan2f(0.F, __re);
615 }
616 
617 }
618 
619 )ESCAPE";
620 
621 const std::string complex_half_body = R"ESCAPE(
622 namespace std {
623 template <>
624 struct alignas(2) complex<at::Half> {
625   at::Half real_;
626   at::Half imag_;
627 
628   // Constructors
629   complex() = default;
630 
631   // implicit casting to and from `complex<float>`.
632   // NOTE: computation of `complex<Half>` will occur in `complex<float>`
633   __host__ __device__ inline complex(const std::complex<float>& value)
634       : real_(value.real()), imag_(value.imag()) {}
635 
636   inline __host__ __device__ operator std::complex<float>() const {
637     return {real_, imag_};
638   }
639 
640   at::Half real() const {return real_;}
641   at::Half imag() const {return imag_;}
642 
643 };
644 }
645 )ESCAPE";
646 
647 
get_complex_body_string()648 const std::string &get_complex_body_string() {
649   return complex_body;
650 }
651 
get_complex_half_body_string()652 const std::string &get_complex_half_body_string() {
653   return complex_half_body;
654 }
655 
656 const std::string complex_math = R"ESCAPE(
657 
658 namespace std {
659 
660 // norm
661 
662 template<class _Tp>
663 inline
664 _Tp
665 norm(const complex<_Tp>& __c)
666 {
667     if (isinf(__c.real()))
668         return abs(__c.real());
669     if (isinf(__c.imag()))
670         return abs(__c.imag());
671     return __c.real() * __c.real() + __c.imag() * __c.imag();
672 }
673 
674 template <class _Tp>
675 inline
676 typename __libcpp_complex_overload_traits<_Tp>::_ValueType
677 norm(_Tp __re)
678 {
679     typedef typename __libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType;
680     return static_cast<_ValueType>(__re) * __re;
681 }
682 
683 // conj
684 
685 template<class _Tp>
686 inline
687 complex<_Tp>
688 conj(const complex<_Tp>& __c)
689 {
690     return complex<_Tp>(__c.real(), -__c.imag());
691 }
692 
693 template <class _Tp>
694 inline
695 typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
696 conj(_Tp __re)
697 {
698     typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType;
699     return _ComplexType(__re);
700 }
701 
702 
703 
704 // proj
705 
706 template<class _Tp>
707 inline
708 complex<_Tp>
709 proj(const complex<_Tp>& __c)
710 {
711     complex<_Tp> __r = __c;
712     if (isinf(__c.real()) || isinf(__c.imag()))
713         __r = complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag()));
714     return __r;
715 }
716 
717 template <class _Tp>
718 inline
719 typename enable_if
720 <
721     is_floating_point<_Tp>::value,
722     typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
723 >::type
724 proj(_Tp __re)
725 {
726     if (isinf(__re))
727         __re = abs(__re);
728     return complex<_Tp>(__re);
729 }
730 
731 template <class _Tp>
732 inline
733 typename enable_if
734 <
735     is_integral<_Tp>::value,
736     typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
737 >::type
738 proj(_Tp __re)
739 {
740     typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType;
741     return _ComplexType(__re);
742 }
743 
744 // polar
745 
746 template<class _Tp>
747 complex<_Tp>
748 polar(const _Tp& __rho, const _Tp& __theta = _Tp())
749 {
750     if (isnan(__rho) || signbit(__rho))
751         return complex<_Tp>(_Tp(NAN), _Tp(NAN));
752     if (isnan(__theta))
753     {
754         if (isinf(__rho))
755             return complex<_Tp>(__rho, __theta);
756         return complex<_Tp>(__theta, __theta);
757     }
758     if (isinf(__theta))
759     {
760         if (isinf(__rho))
761             return complex<_Tp>(__rho, _Tp(NAN));
762         return complex<_Tp>(_Tp(NAN), _Tp(NAN));
763     }
764     _Tp __x = __rho * cos(__theta);
765     if (isnan(__x))
766         __x = 0;
767     _Tp __y = __rho * sin(__theta);
768     if (isnan(__y))
769         __y = 0;
770     return complex<_Tp>(__x, __y);
771 }
772 
773 // log
774 
775 template<class _Tp>
776 inline
777 complex<_Tp>
778 log(const complex<_Tp>& __x)
779 {
780     return complex<_Tp>(log(abs(__x)), arg(__x));
781 }
782 
783 // log10
784 
785 template<class _Tp>
786 inline
787 complex<_Tp>
788 log10(const complex<_Tp>& __x)
789 {
790     return log(__x) / log(_Tp(10));
791 }
792 
793 // log2
794 
795 template<class _Tp>
796 inline
797 complex<_Tp>
798 log2(const complex<_Tp>& __x)
799 {
800     return log(__x) / log(_Tp(2));
801 }
802 
803 // sqrt
804 
805 template<class _Tp>
806 complex<_Tp>
807 sqrt(const complex<_Tp>& __x)
808 {
809     if (isinf(__x.imag()))
810         return complex<_Tp>(_Tp(INFINITY), __x.imag());
811     if (isinf(__x.real()))
812     {
813         if (__x.real() > _Tp(0))
814             return complex<_Tp>(__x.real(), isnan(__x.imag()) ? __x.imag() : copysign(_Tp(0), __x.imag()));
815         return complex<_Tp>(isnan(__x.imag()) ? __x.imag() : _Tp(0), copysign(__x.real(), __x.imag()));
816     }
817     return polar(sqrt(abs(__x)), arg(__x) / _Tp(2));
818 }
819 
820 // exp
821 
822 template<class _Tp>
823 complex<_Tp>
824 exp(const complex<_Tp>& __x)
825 {
826     _Tp __i = __x.imag();
827     if (__i == 0) {
828         return complex<_Tp>(exp(__x.real()), copysign(_Tp(0), __x.imag()));
829     }
830     if (isinf(__x.real()))
831     {
832         if (__x.real() < _Tp(0))
833         {
834             if (!isfinite(__i))
835                 __i = _Tp(1);
836         }
837         else if (__i == 0 || !isfinite(__i))
838         {
839             if (isinf(__i))
840                 __i = _Tp(NAN);
841             return complex<_Tp>(__x.real(), __i);
842         }
843     }
844     _Tp __e = exp(__x.real());
845     return complex<_Tp>(__e * cos(__i), __e * sin(__i));
846 }
847 
848 // pow
849 
850 template<class _Tp>
851 inline
852 complex<_Tp>
853 pow(const complex<_Tp>& __x, const complex<_Tp>& __y)
854 {
855     return exp(__y * log(__x));
856 }
857 
858 template<class _Tp, class _Up>
859 inline
860 complex<typename __promote<_Tp, _Up>::type>
861 pow(const complex<_Tp>& __x, const complex<_Up>& __y)
862 {
863     typedef complex<typename __promote<_Tp, _Up>::type> result_type;
864     return std::pow(result_type(__x), result_type(__y));
865 }
866 
867 template<class _Tp, class _Up>
868 inline
869 typename enable_if
870 <
871     is_arithmetic<_Up>::value,
872     complex<typename __promote<_Tp, _Up>::type>
873 >::type
874 pow(const complex<_Tp>& __x, const _Up& __y)
875 {
876     typedef complex<typename __promote<_Tp, _Up>::type> result_type;
877     return std::pow(result_type(__x), result_type(__y));
878 }
879 
880 template<class _Tp, class _Up>
881 inline
882 typename enable_if
883 <
884     is_arithmetic<_Tp>::value,
885     complex<typename __promote<_Tp, _Up>::type>
886 >::type
887 pow(const _Tp& __x, const complex<_Up>& __y)
888 {
889     typedef complex<typename __promote<_Tp, _Up>::type> result_type;
890     return std::pow(result_type(__x), result_type(__y));
891 }
892 
893 // __sqr, computes pow(x, 2)
894 
895 template<class _Tp>
896 inline
897 complex<_Tp>
898 __sqr(const complex<_Tp>& __x)
899 {
900     return complex<_Tp>((__x.real() - __x.imag()) * (__x.real() + __x.imag()),
901                         _Tp(2) * __x.real() * __x.imag());
902 }
903 
904 // asinh
905 
906 template<class _Tp>
907 complex<_Tp>
908 asinh(const complex<_Tp>& __x)
909 {
910     const _Tp __pi(atan2(+0., -0.));
911     if (isinf(__x.real()))
912     {
913         if (isnan(__x.imag()))
914             return __x;
915         if (isinf(__x.imag()))
916             return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag()));
917         return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
918     }
919     if (isnan(__x.real()))
920     {
921         if (isinf(__x.imag()))
922             return complex<_Tp>(__x.imag(), __x.real());
923         if (__x.imag() == 0)
924             return __x;
925         return complex<_Tp>(__x.real(), __x.real());
926     }
927     if (isinf(__x.imag()))
928         return complex<_Tp>(copysign(__x.imag(), __x.real()), copysign(__pi/_Tp(2), __x.imag()));
929     complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1)));
930     return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag()));
931 }
932 
933 // acosh
934 
935 template<class _Tp>
936 complex<_Tp>
937 acosh(const complex<_Tp>& __x)
938 {
939     const _Tp __pi(atan2(+0., -0.));
940     if (isinf(__x.real()))
941     {
942         if (isnan(__x.imag()))
943             return complex<_Tp>(abs(__x.real()), __x.imag());
944         if (isinf(__x.imag()))
945         {
946             if (__x.real() > 0)
947                 return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag()));
948             else
949                 return complex<_Tp>(-__x.real(), copysign(__pi * _Tp(0.75), __x.imag()));
950         }
951         if (__x.real() < 0)
952             return complex<_Tp>(-__x.real(), copysign(__pi, __x.imag()));
953         return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
954     }
955     if (isnan(__x.real()))
956     {
957         if (isinf(__x.imag()))
958             return complex<_Tp>(abs(__x.imag()), __x.real());
959         return complex<_Tp>(__x.real(), __x.real());
960     }
961     if (isinf(__x.imag()))
962         return complex<_Tp>(abs(__x.imag()), copysign(__pi/_Tp(2), __x.imag()));
963     complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1)));
964     return complex<_Tp>(copysign(__z.real(), _Tp(0)), copysign(__z.imag(), __x.imag()));
965 }
966 
967 // atanh
968 
969 template<class _Tp>
970 complex<_Tp>
971 atanh(const complex<_Tp>& __x)
972 {
973     const _Tp __pi(atan2(+0., -0.));
974     if (isinf(__x.imag()))
975     {
976         return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag()));
977     }
978     if (isnan(__x.imag()))
979     {
980         if (isinf(__x.real()) || __x.real() == 0)
981             return complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag());
982         return complex<_Tp>(__x.imag(), __x.imag());
983     }
984     if (isnan(__x.real()))
985     {
986         return complex<_Tp>(__x.real(), __x.real());
987     }
988     if (isinf(__x.real()))
989     {
990         return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag()));
991     }
992     if (abs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0))
993     {
994         return complex<_Tp>(copysign(_Tp(INFINITY), __x.real()), copysign(_Tp(0), __x.imag()));
995     }
996     complex<_Tp> __z = log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2);
997     return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag()));
998 }
999 
1000 // sinh
1001 
1002 template<class _Tp>
1003 complex<_Tp>
1004 sinh(const complex<_Tp>& __x)
1005 {
1006     if (isinf(__x.real()) && !isfinite(__x.imag()))
1007         return complex<_Tp>(__x.real(), _Tp(NAN));
1008     if (__x.real() == 0 && !isfinite(__x.imag()))
1009         return complex<_Tp>(__x.real(), _Tp(NAN));
1010     if (__x.imag() == 0 && !isfinite(__x.real()))
1011         return __x;
1012     return complex<_Tp>(sinh(__x.real()) * cos(__x.imag()), cosh(__x.real()) * sin(__x.imag()));
1013 }
1014 
1015 // cosh
1016 
1017 template<class _Tp>
1018 complex<_Tp>
1019 cosh(const complex<_Tp>& __x)
1020 {
1021     if (isinf(__x.real()) && !isfinite(__x.imag()))
1022         return complex<_Tp>(abs(__x.real()), _Tp(NAN));
1023     if (__x.real() == 0 && !isfinite(__x.imag()))
1024         return complex<_Tp>(_Tp(NAN), __x.real());
1025     if (__x.real() == 0 && __x.imag() == 0)
1026         return complex<_Tp>(_Tp(1), __x.imag());
1027     if (__x.imag() == 0 && !isfinite(__x.real()))
1028         return complex<_Tp>(abs(__x.real()), __x.imag());
1029     return complex<_Tp>(cosh(__x.real()) * cos(__x.imag()), sinh(__x.real()) * sin(__x.imag()));
1030 }
1031 
1032 // tanh
1033 
1034 template<class _Tp>
1035 complex<_Tp>
1036 tanh(const complex<_Tp>& __x)
1037 {
1038     if (isinf(__x.real()))
1039     {
1040         if (!isfinite(__x.imag()))
1041             return complex<_Tp>(copysign(_Tp(1), __x.real()), _Tp(0));
1042         return complex<_Tp>(copysign(_Tp(1), __x.real()), copysign(_Tp(0), sin(_Tp(2) * __x.imag())));
1043     }
1044     if (isnan(__x.real()) && __x.imag() == 0)
1045         return __x;
1046     _Tp __2r(_Tp(2) * __x.real());
1047     _Tp __2i(_Tp(2) * __x.imag());
1048     _Tp __d(cosh(__2r) + cos(__2i));
1049     _Tp __2rsh(sinh(__2r));
1050     if (isinf(__2rsh) && isinf(__d))
1051         return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1),
1052                             __2i > _Tp(0) ? _Tp(0) : _Tp(-0.));
1053     return  complex<_Tp>(__2rsh/__d, sin(__2i)/__d);
1054 }
1055 
1056 // asin
1057 
1058 template<class _Tp>
1059 complex<_Tp>
1060 asin(const complex<_Tp>& __x)
1061 {
1062     complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real()));
1063     return complex<_Tp>(__z.imag(), -__z.real());
1064 }
1065 
1066 // acos
1067 
1068 template<class _Tp>
1069 complex<_Tp>
1070 acos(const complex<_Tp>& __x)
1071 {
1072     const _Tp __pi(atan2(+0., -0.));
1073     if (isinf(__x.real()))
1074     {
1075         if (isnan(__x.imag()))
1076             return complex<_Tp>(__x.imag(), __x.real());
1077         if (isinf(__x.imag()))
1078         {
1079             if (__x.real() < _Tp(0))
1080                 return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag());
1081             return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag());
1082         }
1083         if (__x.real() < _Tp(0))
1084             return complex<_Tp>(__pi, signbit(__x.imag()) ? -__x.real() : __x.real());
1085         return complex<_Tp>(_Tp(0), signbit(__x.imag()) ? __x.real() : -__x.real());
1086     }
1087     if (isnan(__x.real()))
1088     {
1089         if (isinf(__x.imag()))
1090             return complex<_Tp>(__x.real(), -__x.imag());
1091         return complex<_Tp>(__x.real(), __x.real());
1092     }
1093     if (isinf(__x.imag()))
1094         return complex<_Tp>(__pi/_Tp(2), -__x.imag());
1095     if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag())))
1096         return complex<_Tp>(__pi/_Tp(2), -__x.imag());
1097     complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1)));
1098     if (signbit(__x.imag()))
1099         return complex<_Tp>(abs(__z.imag()), abs(__z.real()));
1100     return complex<_Tp>(abs(__z.imag()), -abs(__z.real()));
1101 }
1102 
1103 // atan
1104 
1105 template<class _Tp>
1106 complex<_Tp>
1107 atan(const complex<_Tp>& __x)
1108 {
1109     complex<_Tp> __z = atanh(complex<_Tp>(-__x.imag(), __x.real()));
1110     return complex<_Tp>(__z.imag(), -__z.real());
1111 }
1112 
1113 // sin
1114 
1115 template<class _Tp>
1116 complex<_Tp>
1117 sin(const complex<_Tp>& __x)
1118 {
1119     complex<_Tp> __z = sinh(complex<_Tp>(-__x.imag(), __x.real()));
1120     return complex<_Tp>(__z.imag(), -__z.real());
1121 }
1122 
1123 // cos
1124 
1125 template<class _Tp>
1126 inline
1127 complex<_Tp>
1128 cos(const complex<_Tp>& __x)
1129 {
1130     return cosh(complex<_Tp>(-__x.imag(), __x.real()));
1131 }
1132 
1133 // tan
1134 
1135 template<class _Tp>
1136 complex<_Tp>
1137 tan(const complex<_Tp>& __x)
1138 {
1139     complex<_Tp> __z = tanh(complex<_Tp>(-__x.imag(), __x.real()));
1140     return complex<_Tp>(__z.imag(), -__z.real());
1141 }
1142 
1143 // Literal suffix for complex number literals [complex.literals]
1144 inline namespace literals
1145 {
1146   inline namespace complex_literals
1147   {
1148     constexpr complex<double> operator""i(long double __im)
1149     {
1150         return { 0.0, static_cast<double>(__im) };
1151     }
1152 
1153     constexpr complex<double> operator""i(unsigned long long __im)
1154     {
1155         return { 0.0, static_cast<double>(__im) };
1156     }
1157 
1158 
1159     constexpr complex<float> operator""if(long double __im)
1160     {
1161         return { 0.0f, static_cast<float>(__im) };
1162     }
1163 
1164     constexpr complex<float> operator""if(unsigned long long __im)
1165     {
1166         return { 0.0f, static_cast<float>(__im) };
1167     }
1168   } // namespace complex_literals
1169 } // namespace literals
1170 
1171 } // namespace std
1172 
1173 )ESCAPE";
1174 
get_complex_math_string()1175 const std::string &get_complex_math_string() {
1176   return complex_math;
1177 }
1178 
1179 } // namespace at::cuda
1180