1 // This is copy-pasted (with modification) from the following llvm file:
2 // - https://github.com/llvm/llvm-project/blob/main/libcxx/include/complex
3 //
4 // -*- C++ -*-
5 //===----------------------------------------------------------------------===//
6 //
7 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
8 // See https://llvm.org/LICENSE.txt for license information.
9 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include <string>
14 #include <ATen/cuda/llvm_jit_strings.h>
15
16
17 namespace at::cuda {
18
// Source fragment kept as data: C++ text adapted from libc++ <complex> (see
// the file header) and handed out by get_complex_body_string().
// Inside `namespace std` the embedded text provides:
//   - the generic `complex<_Tp>` template and its `float` / `double`
//     specializations (each variant also has a non-explicit `operator bool`);
//   - binary and unary arithmetic operators, `==` / `!=`, `&&` / `||`, and an
//     `isnan` overload for complex values;
//   - `__libcpp_complex_overload_traits` plus `real`, `imag`, `abs`, `arg`.
// The text uses `is_integral`, `is_floating_point`, `is_same`, `enable_if`,
// `INFINITY` and the scalar math functions (`isnan`, `isinf`, `copysign`,
// `logb`, `scalbn`, `hypot`, `atan2`, ...) without declaring them, so whatever
// compiles this fragment has to make those names visible first.
// NOTE(review): presumably consumed by the CUDA JIT path (this file includes
// llvm_jit_strings.h) -- confirm against the callers.
// Everything between R"ESCAPE( and )ESCAPE" is string content, not host code:
// editing it, even whitespace or comments, changes the emitted source text.
19 const std::string complex_body = R"ESCAPE(
20
21 namespace std {
22
23 template<class _Tp> class complex;
24
25 template<class _Tp> complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w);
26 template<class _Tp> complex<_Tp> operator/(const complex<_Tp>& __x, const complex<_Tp>& __y);
27
28 template<class _Tp>
29 class complex
30 {
31 public:
32 typedef _Tp value_type;
33 private:
34 value_type __re_;
35 value_type __im_;
36 public:
37 constexpr
38 complex(const value_type& __re = value_type(), const value_type& __im = value_type())
39 : __re_(__re), __im_(__im) {}
40 template<class _Xp> constexpr
41 complex(const complex<_Xp>& __c)
42 : __re_(__c.real()), __im_(__c.imag()) {}
43
44 constexpr value_type real() const {return __re_;}
45 constexpr value_type imag() const {return __im_;}
46
47 void real(value_type __re) {__re_ = __re;}
48 void imag(value_type __im) {__im_ = __im;}
49
50 constexpr operator bool() const {
51 return real() || imag();
52 }
53
54 complex& operator= (const value_type& __re)
55 {__re_ = __re; __im_ = value_type(); return *this;}
56 complex& operator+=(const value_type& __re) {__re_ += __re; return *this;}
57 complex& operator-=(const value_type& __re) {__re_ -= __re; return *this;}
58 complex& operator*=(const value_type& __re) {__re_ *= __re; __im_ *= __re; return *this;}
59 complex& operator/=(const value_type& __re) {__re_ /= __re; __im_ /= __re; return *this;}
60
61 template<class _Xp> complex& operator= (const complex<_Xp>& __c)
62 {
63 __re_ = __c.real();
64 __im_ = __c.imag();
65 return *this;
66 }
67 template<class _Xp> complex& operator+=(const complex<_Xp>& __c)
68 {
69 __re_ += __c.real();
70 __im_ += __c.imag();
71 return *this;
72 }
73 template<class _Xp> complex& operator-=(const complex<_Xp>& __c)
74 {
75 __re_ -= __c.real();
76 __im_ -= __c.imag();
77 return *this;
78 }
79 template<class _Xp> complex& operator*=(const complex<_Xp>& __c)
80 {
81 *this = *this * complex(__c.real(), __c.imag());
82 return *this;
83 }
84 template<class _Xp> complex& operator/=(const complex<_Xp>& __c)
85 {
86 *this = *this / complex(__c.real(), __c.imag());
87 return *this;
88 }
89 };
90
91 template<> class complex<double>;
92
93 template<>
94 class complex<float>
95 {
96 float __re_;
97 float __im_;
98 public:
99 typedef float value_type;
100
101 constexpr complex(float __re = 0.0f, float __im = 0.0f)
102 : __re_(__re), __im_(__im) {}
103
104 explicit constexpr complex(const complex<double>& __c);
105
106 constexpr float real() const {return __re_;}
107 constexpr float imag() const {return __im_;}
108
109 void real(value_type __re) {__re_ = __re;}
110 void imag(value_type __im) {__im_ = __im;}
111
112 constexpr operator bool() const {
113 return real() || imag();
114 }
115
116 complex& operator= (float __re)
117 {__re_ = __re; __im_ = value_type(); return *this;}
118 complex& operator+=(float __re) {__re_ += __re; return *this;}
119 complex& operator-=(float __re) {__re_ -= __re; return *this;}
120 complex& operator*=(float __re) {__re_ *= __re; __im_ *= __re; return *this;}
121 complex& operator/=(float __re) {__re_ /= __re; __im_ /= __re; return *this;}
122
123 template<class _Xp> complex& operator= (const complex<_Xp>& __c)
124 {
125 __re_ = __c.real();
126 __im_ = __c.imag();
127 return *this;
128 }
129 template<class _Xp> complex& operator+=(const complex<_Xp>& __c)
130 {
131 __re_ += __c.real();
132 __im_ += __c.imag();
133 return *this;
134 }
135 template<class _Xp> complex& operator-=(const complex<_Xp>& __c)
136 {
137 __re_ -= __c.real();
138 __im_ -= __c.imag();
139 return *this;
140 }
141 template<class _Xp> complex& operator*=(const complex<_Xp>& __c)
142 {
143 *this = *this * complex(__c.real(), __c.imag());
144 return *this;
145 }
146 template<class _Xp> complex& operator/=(const complex<_Xp>& __c)
147 {
148 *this = *this / complex(__c.real(), __c.imag());
149 return *this;
150 }
151 };
152
153 template<>
154 class complex<double>
155 {
156 double __re_;
157 double __im_;
158 public:
159 typedef double value_type;
160
161 constexpr complex(double __re = 0.0, double __im = 0.0)
162 : __re_(__re), __im_(__im) {}
163
164 constexpr complex(const complex<float>& __c);
165
166 constexpr double real() const {return __re_;}
167 constexpr double imag() const {return __im_;}
168
169 void real(value_type __re) {__re_ = __re;}
170 void imag(value_type __im) {__im_ = __im;}
171
172 constexpr operator bool() const {
173 return real() || imag();
174 }
175
176 complex& operator= (double __re)
177 {__re_ = __re; __im_ = value_type(); return *this;}
178 complex& operator+=(double __re) {__re_ += __re; return *this;}
179 complex& operator-=(double __re) {__re_ -= __re; return *this;}
180 complex& operator*=(double __re) {__re_ *= __re; __im_ *= __re; return *this;}
181 complex& operator/=(double __re) {__re_ /= __re; __im_ /= __re; return *this;}
182
183 template<class _Xp> complex& operator= (const complex<_Xp>& __c)
184 {
185 __re_ = __c.real();
186 __im_ = __c.imag();
187 return *this;
188 }
189 template<class _Xp> complex& operator+=(const complex<_Xp>& __c)
190 {
191 __re_ += __c.real();
192 __im_ += __c.imag();
193 return *this;
194 }
195 template<class _Xp> complex& operator-=(const complex<_Xp>& __c)
196 {
197 __re_ -= __c.real();
198 __im_ -= __c.imag();
199 return *this;
200 }
201 template<class _Xp> complex& operator*=(const complex<_Xp>& __c)
202 {
203 *this = *this * complex(__c.real(), __c.imag());
204 return *this;
205 }
206 template<class _Xp> complex& operator/=(const complex<_Xp>& __c)
207 {
208 *this = *this / complex(__c.real(), __c.imag());
209 return *this;
210 }
211 };
212
213 inline
214 constexpr
215 complex<float>::complex(const complex<double>& __c)
216 : __re_(__c.real()), __im_(__c.imag()) {}
217
218 inline
219 constexpr
220 complex<double>::complex(const complex<float>& __c)
221 : __re_(__c.real()), __im_(__c.imag()) {}
222
223
224 // 26.3.6 operators:
225
226 template<class _Tp>
227 inline
228 complex<_Tp>
229 operator+(const complex<_Tp>& __x, const complex<_Tp>& __y)
230 {
231 complex<_Tp> __t(__x);
232 __t += __y;
233 return __t;
234 }
235
236 template<class _Tp>
237 inline
238 complex<_Tp>
239 operator+(const complex<_Tp>& __x, const _Tp& __y)
240 {
241 complex<_Tp> __t(__x);
242 __t += __y;
243 return __t;
244 }
245
246 template<class _Tp>
247 inline
248 complex<_Tp>
249 operator+(const _Tp& __x, const complex<_Tp>& __y)
250 {
251 complex<_Tp> __t(__y);
252 __t += __x;
253 return __t;
254 }
255
256 template<class _Tp>
257 inline
258 complex<_Tp>
259 operator-(const complex<_Tp>& __x, const complex<_Tp>& __y)
260 {
261 complex<_Tp> __t(__x);
262 __t -= __y;
263 return __t;
264 }
265
266 template<class _Tp>
267 inline
268 complex<_Tp>
269 operator-(const complex<_Tp>& __x, const _Tp& __y)
270 {
271 complex<_Tp> __t(__x);
272 __t -= __y;
273 return __t;
274 }
275
276 template<class _Tp>
277 inline
278 complex<_Tp>
279 operator-(const _Tp& __x, const complex<_Tp>& __y)
280 {
281 complex<_Tp> __t(-__y);
282 __t += __x;
283 return __t;
284 }
285
286 template<class _Tp>
287 complex<_Tp>
288 operator*(const complex<_Tp>& __z, const complex<_Tp>& __w)
289 {
290 _Tp __a = __z.real();
291 _Tp __b = __z.imag();
292 _Tp __c = __w.real();
293 _Tp __d = __w.imag();
294 _Tp __ac = __a * __c;
295 _Tp __bd = __b * __d;
296 _Tp __ad = __a * __d;
297 _Tp __bc = __b * __c;
298 _Tp __x = __ac - __bd;
299 _Tp __y = __ad + __bc;
300 if (isnan(__x) && isnan(__y))
301 {
302 bool __recalc = false;
303 if (isinf(__a) || isinf(__b))
304 {
305 __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a);
306 __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b);
307 if (isnan(__c))
308 __c = copysign(_Tp(0), __c);
309 if (isnan(__d))
310 __d = copysign(_Tp(0), __d);
311 __recalc = true;
312 }
313 if (isinf(__c) || isinf(__d))
314 {
315 __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c);
316 __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d);
317 if (isnan(__a))
318 __a = copysign(_Tp(0), __a);
319 if (isnan(__b))
320 __b = copysign(_Tp(0), __b);
321 __recalc = true;
322 }
323 if (!__recalc && (isinf(__ac) || isinf(__bd) ||
324 isinf(__ad) || isinf(__bc)))
325 {
326 if (isnan(__a))
327 __a = copysign(_Tp(0), __a);
328 if (isnan(__b))
329 __b = copysign(_Tp(0), __b);
330 if (isnan(__c))
331 __c = copysign(_Tp(0), __c);
332 if (isnan(__d))
333 __d = copysign(_Tp(0), __d);
334 __recalc = true;
335 }
336 if (__recalc)
337 {
338 __x = _Tp(INFINITY) * (__a * __c - __b * __d);
339 __y = _Tp(INFINITY) * (__a * __d + __b * __c);
340 }
341 }
342 return complex<_Tp>(__x, __y);
343 }
344
345 template<class _Tp>
346 inline
347 complex<_Tp>
348 operator*(const complex<_Tp>& __x, const _Tp& __y)
349 {
350 complex<_Tp> __t(__x);
351 __t *= __y;
352 return __t;
353 }
354
355 template<class _Tp>
356 inline
357 complex<_Tp>
358 operator*(const _Tp& __x, const complex<_Tp>& __y)
359 {
360 complex<_Tp> __t(__y);
361 __t *= __x;
362 return __t;
363 }
364
365 template<class _Tp>
366 complex<_Tp>
367 operator/(const complex<_Tp>& __z, const complex<_Tp>& __w)
368 {
369 int __ilogbw = 0;
370 _Tp __a = __z.real();
371 _Tp __b = __z.imag();
372 _Tp __c = __w.real();
373 _Tp __d = __w.imag();
374 _Tp __logbw = logb(fmax(fabs(__c), fabs(__d)));
375 if (isfinite(__logbw))
376 {
377 __ilogbw = static_cast<int>(__logbw);
378 __c = scalbn(__c, -__ilogbw);
379 __d = scalbn(__d, -__ilogbw);
380 }
381 _Tp __denom = __c * __c + __d * __d;
382 _Tp __x = scalbn((__a * __c + __b * __d) / __denom, -__ilogbw);
383 _Tp __y = scalbn((__b * __c - __a * __d) / __denom, -__ilogbw);
384 if (isnan(__x) && isnan(__y))
385 {
386 if ((__denom == _Tp(0)) && (!isnan(__a) || !isnan(__b)))
387 {
388 __x = copysign(_Tp(INFINITY), __c) * __a;
389 __y = copysign(_Tp(INFINITY), __c) * __b;
390 }
391 else if ((isinf(__a) || isinf(__b)) && isfinite(__c) && isfinite(__d))
392 {
393 __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a);
394 __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b);
395 __x = _Tp(INFINITY) * (__a * __c + __b * __d);
396 __y = _Tp(INFINITY) * (__b * __c - __a * __d);
397 }
398 else if (isinf(__logbw) && __logbw > _Tp(0) && isfinite(__a) && isfinite(__b))
399 {
400 __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c);
401 __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d);
402 __x = _Tp(0) * (__a * __c + __b * __d);
403 __y = _Tp(0) * (__b * __c - __a * __d);
404 }
405 }
406 return complex<_Tp>(__x, __y);
407 }
408
409 template<class _Tp>
410 inline
411 complex<_Tp>
412 operator/(const complex<_Tp>& __x, const _Tp& __y)
413 {
414 return complex<_Tp>(__x.real() / __y, __x.imag() / __y);
415 }
416
417 template<class _Tp>
418 inline
419 complex<_Tp>
420 operator/(const _Tp& __x, const complex<_Tp>& __y)
421 {
422 complex<_Tp> __t(__x);
423 __t /= __y;
424 return __t;
425 }
426
427 template<class _Tp>
428 inline
429 complex<_Tp>
430 operator+(const complex<_Tp>& __x)
431 {
432 return __x;
433 }
434
435 template<class _Tp>
436 inline
437 complex<_Tp>
438 operator-(const complex<_Tp>& __x)
439 {
440 return complex<_Tp>(-__x.real(), -__x.imag());
441 }
442
443 template<class _Tp>
444 inline constexpr
445 bool
446 operator==(const complex<_Tp>& __x, const complex<_Tp>& __y)
447 {
448 return __x.real() == __y.real() && __x.imag() == __y.imag();
449 }
450
451 template<class _Tp>
452 inline constexpr
453 bool
454 operator==(const complex<_Tp>& __x, const _Tp& __y)
455 {
456 return __x.real() == __y && __x.imag() == 0;
457 }
458
459 template<class _Tp>
460 inline constexpr
461 bool
462 operator==(const _Tp& __x, const complex<_Tp>& __y)
463 {
464 return __x == __y.real() && 0 == __y.imag();
465 }
466
467 template<class _Tp>
468 inline constexpr
469 bool
470 operator!=(const complex<_Tp>& __x, const complex<_Tp>& __y)
471 {
472 return !(__x == __y);
473 }
474
475 template<class _Tp>
476 inline constexpr
477 bool
478 operator!=(const complex<_Tp>& __x, const _Tp& __y)
479 {
480 return !(__x == __y);
481 }
482
483 template<class _Tp>
484 inline constexpr
485 bool
486 operator!=(const _Tp& __x, const complex<_Tp>& __y)
487 {
488 return !(__x == __y);
489 }
490
491 template<class _Tp>
492 inline constexpr
493 bool
494 operator&&(const complex<_Tp>& __x, const complex<_Tp>& __y)
495 {
496 return bool(__x) && bool(__y);
497 }
498
499 template<class _Tp>
500 inline constexpr
501 bool
502 isnan(const complex<_Tp>& __x)
503 {
504 return isnan(__x.real()) || isnan(__x.imag());
505 }
506
507 template<class _Tp>
508 inline constexpr
509 bool
510 operator||(const complex<_Tp>& __x, const complex<_Tp>& __y)
511 {
512 return bool(__x) || bool(__y);
513 }
514
515 // 26.3.7 values:
516
517 template <class _Tp, bool = is_integral<_Tp>::value,
518 bool = is_floating_point<_Tp>::value
519 >
520 struct __libcpp_complex_overload_traits {};
521
522 // Integral Types
523 template <class _Tp>
524 struct __libcpp_complex_overload_traits<_Tp, true, false>
525 {
526 typedef double _ValueType;
527 typedef complex<double> _ComplexType;
528 };
529
530 // Floating point types
531 template <class _Tp>
532 struct __libcpp_complex_overload_traits<_Tp, false, true>
533 {
534 typedef _Tp _ValueType;
535 typedef complex<_Tp> _ComplexType;
536 };
537
538 // real
539
540 template<class _Tp>
541 inline constexpr
542 _Tp
543 real(const complex<_Tp>& __c)
544 {
545 return __c.real();
546 }
547
548 template <class _Tp>
549 inline constexpr
550 typename __libcpp_complex_overload_traits<_Tp>::_ValueType
551 real(_Tp __re)
552 {
553 return __re;
554 }
555
556 // imag
557
558 template<class _Tp>
559 inline constexpr
560 _Tp
561 imag(const complex<_Tp>& __c)
562 {
563 return __c.imag();
564 }
565
566 template <class _Tp>
567 inline constexpr
568 typename __libcpp_complex_overload_traits<_Tp>::_ValueType
569 imag(_Tp)
570 {
571 return 0;
572 }
573
574 // abs
575
576 template<class _Tp>
577 inline
578 _Tp
579 abs(const complex<_Tp>& __c)
580 {
581 return hypot(__c.real(), __c.imag());
582 }
583
584 // arg
585
586 template<class _Tp>
587 inline
588 _Tp
589 arg(const complex<_Tp>& __c)
590 {
591 return atan2(__c.imag(), __c.real());
592 }
593
594 template<class _Tp>
595 inline
596 typename enable_if
597 <
598 is_integral<_Tp>::value || is_same<_Tp, double>::value,
599 double
600 >::type
601 arg(_Tp __re)
602 {
603 return atan2(0., __re);
604 }
605
606 template <class _Tp>
607 inline
608 typename enable_if<
609 is_same<_Tp, float>::value,
610 float
611 >::type
612 arg(_Tp __re)
613 {
614 return atan2f(0.F, __re);
615 }
616
617 }
618
619 )ESCAPE";
620
// Source fragment kept as data: C++ text that specializes `std::complex` for
// `at::Half`; handed out by get_complex_half_body_string().
// The embedded struct stores two `at::Half` members, converts implicitly from
// and to `std::complex<float>` (its own note says arithmetic is meant to run in
// `complex<float>`), and exposes `real()` / `imag()` accessors.
// The text declares neither `at::Half` nor the primary `std::complex` template,
// so both have to be visible before this fragment in the compiled source.
// Everything between R"ESCAPE( and )ESCAPE" is string content, not host code.
621 const std::string complex_half_body = R"ESCAPE(
622 namespace std {
623 template <>
624 struct alignas(2) complex<at::Half> {
625 at::Half real_;
626 at::Half imag_;
627
628 // Constructors
629 complex() = default;
630
631 // implicit casting to and from `complex<float>`.
632 // NOTE: computation of `complex<Half>` will occur in `complex<float>`
633 __host__ __device__ inline complex(const std::complex<float>& value)
634 : real_(value.real()), imag_(value.imag()) {}
635
636 inline __host__ __device__ operator std::complex<float>() const {
637 return {real_, imag_};
638 }
639
640 at::Half real() const {return real_;}
641 at::Half imag() const {return imag_;}
642
643 };
644 }
645 )ESCAPE";
646
647
get_complex_body_string()648 const std::string &get_complex_body_string() {
649 return complex_body;
650 }
651
get_complex_half_body_string()652 const std::string &get_complex_half_body_string() {
653 return complex_half_body;
654 }
655
// Source fragment kept as data: C++ text with the complex math functions,
// handed out by get_complex_math_string().
// Inside `namespace std` the embedded text provides `norm`, `conj`, `proj`,
// `polar`, `log`, `log10`, `log2`, `sqrt`, `exp`, `pow`, the helper `__sqr`,
// `asinh`, `acosh`, `atanh`, `sinh`, `cosh`, `tanh`, `asin`, `acos`, `atan`,
// `sin`, `cos`, `tan`, and the `i` / `if` literal operators.
// Names it uses but does not declare:
//   - `complex`, `abs`, `arg`, `__libcpp_complex_overload_traits`: declared in
//     the `complex_body` text, which therefore has to precede this fragment;
//   - `__promote`, `enable_if`, `is_arithmetic`, `is_integral`,
//     `is_floating_point`, `INFINITY`, `NAN` and the scalar math functions:
//     not declared in any string in this file, so the consumer must supply them.
// Everything between R"ESCAPE( and )ESCAPE" is string content, not host code:
// editing it, even whitespace or comments, changes the emitted source text.
656 const std::string complex_math = R"ESCAPE(
657
658 namespace std {
659
660 // norm
661
662 template<class _Tp>
663 inline
664 _Tp
665 norm(const complex<_Tp>& __c)
666 {
667 if (isinf(__c.real()))
668 return abs(__c.real());
669 if (isinf(__c.imag()))
670 return abs(__c.imag());
671 return __c.real() * __c.real() + __c.imag() * __c.imag();
672 }
673
674 template <class _Tp>
675 inline
676 typename __libcpp_complex_overload_traits<_Tp>::_ValueType
677 norm(_Tp __re)
678 {
679 typedef typename __libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType;
680 return static_cast<_ValueType>(__re) * __re;
681 }
682
683 // conj
684
685 template<class _Tp>
686 inline
687 complex<_Tp>
688 conj(const complex<_Tp>& __c)
689 {
690 return complex<_Tp>(__c.real(), -__c.imag());
691 }
692
693 template <class _Tp>
694 inline
695 typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
696 conj(_Tp __re)
697 {
698 typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType;
699 return _ComplexType(__re);
700 }
701
702
703
704 // proj
705
706 template<class _Tp>
707 inline
708 complex<_Tp>
709 proj(const complex<_Tp>& __c)
710 {
711 complex<_Tp> __r = __c;
712 if (isinf(__c.real()) || isinf(__c.imag()))
713 __r = complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag()));
714 return __r;
715 }
716
717 template <class _Tp>
718 inline
719 typename enable_if
720 <
721 is_floating_point<_Tp>::value,
722 typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
723 >::type
724 proj(_Tp __re)
725 {
726 if (isinf(__re))
727 __re = abs(__re);
728 return complex<_Tp>(__re);
729 }
730
731 template <class _Tp>
732 inline
733 typename enable_if
734 <
735 is_integral<_Tp>::value,
736 typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
737 >::type
738 proj(_Tp __re)
739 {
740 typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType;
741 return _ComplexType(__re);
742 }
743
744 // polar
745
746 template<class _Tp>
747 complex<_Tp>
748 polar(const _Tp& __rho, const _Tp& __theta = _Tp())
749 {
750 if (isnan(__rho) || signbit(__rho))
751 return complex<_Tp>(_Tp(NAN), _Tp(NAN));
752 if (isnan(__theta))
753 {
754 if (isinf(__rho))
755 return complex<_Tp>(__rho, __theta);
756 return complex<_Tp>(__theta, __theta);
757 }
758 if (isinf(__theta))
759 {
760 if (isinf(__rho))
761 return complex<_Tp>(__rho, _Tp(NAN));
762 return complex<_Tp>(_Tp(NAN), _Tp(NAN));
763 }
764 _Tp __x = __rho * cos(__theta);
765 if (isnan(__x))
766 __x = 0;
767 _Tp __y = __rho * sin(__theta);
768 if (isnan(__y))
769 __y = 0;
770 return complex<_Tp>(__x, __y);
771 }
772
773 // log
774
775 template<class _Tp>
776 inline
777 complex<_Tp>
778 log(const complex<_Tp>& __x)
779 {
780 return complex<_Tp>(log(abs(__x)), arg(__x));
781 }
782
783 // log10
784
785 template<class _Tp>
786 inline
787 complex<_Tp>
788 log10(const complex<_Tp>& __x)
789 {
790 return log(__x) / log(_Tp(10));
791 }
792
793 // log2
794
795 template<class _Tp>
796 inline
797 complex<_Tp>
798 log2(const complex<_Tp>& __x)
799 {
800 return log(__x) / log(_Tp(2));
801 }
802
803 // sqrt
804
805 template<class _Tp>
806 complex<_Tp>
807 sqrt(const complex<_Tp>& __x)
808 {
809 if (isinf(__x.imag()))
810 return complex<_Tp>(_Tp(INFINITY), __x.imag());
811 if (isinf(__x.real()))
812 {
813 if (__x.real() > _Tp(0))
814 return complex<_Tp>(__x.real(), isnan(__x.imag()) ? __x.imag() : copysign(_Tp(0), __x.imag()));
815 return complex<_Tp>(isnan(__x.imag()) ? __x.imag() : _Tp(0), copysign(__x.real(), __x.imag()));
816 }
817 return polar(sqrt(abs(__x)), arg(__x) / _Tp(2));
818 }
819
820 // exp
821
822 template<class _Tp>
823 complex<_Tp>
824 exp(const complex<_Tp>& __x)
825 {
826 _Tp __i = __x.imag();
827 if (__i == 0) {
828 return complex<_Tp>(exp(__x.real()), copysign(_Tp(0), __x.imag()));
829 }
830 if (isinf(__x.real()))
831 {
832 if (__x.real() < _Tp(0))
833 {
834 if (!isfinite(__i))
835 __i = _Tp(1);
836 }
837 else if (__i == 0 || !isfinite(__i))
838 {
839 if (isinf(__i))
840 __i = _Tp(NAN);
841 return complex<_Tp>(__x.real(), __i);
842 }
843 }
844 _Tp __e = exp(__x.real());
845 return complex<_Tp>(__e * cos(__i), __e * sin(__i));
846 }
847
848 // pow
849
850 template<class _Tp>
851 inline
852 complex<_Tp>
853 pow(const complex<_Tp>& __x, const complex<_Tp>& __y)
854 {
855 return exp(__y * log(__x));
856 }
857
858 template<class _Tp, class _Up>
859 inline
860 complex<typename __promote<_Tp, _Up>::type>
861 pow(const complex<_Tp>& __x, const complex<_Up>& __y)
862 {
863 typedef complex<typename __promote<_Tp, _Up>::type> result_type;
864 return std::pow(result_type(__x), result_type(__y));
865 }
866
867 template<class _Tp, class _Up>
868 inline
869 typename enable_if
870 <
871 is_arithmetic<_Up>::value,
872 complex<typename __promote<_Tp, _Up>::type>
873 >::type
874 pow(const complex<_Tp>& __x, const _Up& __y)
875 {
876 typedef complex<typename __promote<_Tp, _Up>::type> result_type;
877 return std::pow(result_type(__x), result_type(__y));
878 }
879
880 template<class _Tp, class _Up>
881 inline
882 typename enable_if
883 <
884 is_arithmetic<_Tp>::value,
885 complex<typename __promote<_Tp, _Up>::type>
886 >::type
887 pow(const _Tp& __x, const complex<_Up>& __y)
888 {
889 typedef complex<typename __promote<_Tp, _Up>::type> result_type;
890 return std::pow(result_type(__x), result_type(__y));
891 }
892
893 // __sqr, computes pow(x, 2)
894
895 template<class _Tp>
896 inline
897 complex<_Tp>
898 __sqr(const complex<_Tp>& __x)
899 {
900 return complex<_Tp>((__x.real() - __x.imag()) * (__x.real() + __x.imag()),
901 _Tp(2) * __x.real() * __x.imag());
902 }
903
904 // asinh
905
906 template<class _Tp>
907 complex<_Tp>
908 asinh(const complex<_Tp>& __x)
909 {
910 const _Tp __pi(atan2(+0., -0.));
911 if (isinf(__x.real()))
912 {
913 if (isnan(__x.imag()))
914 return __x;
915 if (isinf(__x.imag()))
916 return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag()));
917 return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
918 }
919 if (isnan(__x.real()))
920 {
921 if (isinf(__x.imag()))
922 return complex<_Tp>(__x.imag(), __x.real());
923 if (__x.imag() == 0)
924 return __x;
925 return complex<_Tp>(__x.real(), __x.real());
926 }
927 if (isinf(__x.imag()))
928 return complex<_Tp>(copysign(__x.imag(), __x.real()), copysign(__pi/_Tp(2), __x.imag()));
929 complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1)));
930 return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag()));
931 }
932
933 // acosh
934
935 template<class _Tp>
936 complex<_Tp>
937 acosh(const complex<_Tp>& __x)
938 {
939 const _Tp __pi(atan2(+0., -0.));
940 if (isinf(__x.real()))
941 {
942 if (isnan(__x.imag()))
943 return complex<_Tp>(abs(__x.real()), __x.imag());
944 if (isinf(__x.imag()))
945 {
946 if (__x.real() > 0)
947 return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag()));
948 else
949 return complex<_Tp>(-__x.real(), copysign(__pi * _Tp(0.75), __x.imag()));
950 }
951 if (__x.real() < 0)
952 return complex<_Tp>(-__x.real(), copysign(__pi, __x.imag()));
953 return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
954 }
955 if (isnan(__x.real()))
956 {
957 if (isinf(__x.imag()))
958 return complex<_Tp>(abs(__x.imag()), __x.real());
959 return complex<_Tp>(__x.real(), __x.real());
960 }
961 if (isinf(__x.imag()))
962 return complex<_Tp>(abs(__x.imag()), copysign(__pi/_Tp(2), __x.imag()));
963 complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1)));
964 return complex<_Tp>(copysign(__z.real(), _Tp(0)), copysign(__z.imag(), __x.imag()));
965 }
966
967 // atanh
968
969 template<class _Tp>
970 complex<_Tp>
971 atanh(const complex<_Tp>& __x)
972 {
973 const _Tp __pi(atan2(+0., -0.));
974 if (isinf(__x.imag()))
975 {
976 return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag()));
977 }
978 if (isnan(__x.imag()))
979 {
980 if (isinf(__x.real()) || __x.real() == 0)
981 return complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag());
982 return complex<_Tp>(__x.imag(), __x.imag());
983 }
984 if (isnan(__x.real()))
985 {
986 return complex<_Tp>(__x.real(), __x.real());
987 }
988 if (isinf(__x.real()))
989 {
990 return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag()));
991 }
992 if (abs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0))
993 {
994 return complex<_Tp>(copysign(_Tp(INFINITY), __x.real()), copysign(_Tp(0), __x.imag()));
995 }
996 complex<_Tp> __z = log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2);
997 return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag()));
998 }
999
1000 // sinh
1001
1002 template<class _Tp>
1003 complex<_Tp>
1004 sinh(const complex<_Tp>& __x)
1005 {
1006 if (isinf(__x.real()) && !isfinite(__x.imag()))
1007 return complex<_Tp>(__x.real(), _Tp(NAN));
1008 if (__x.real() == 0 && !isfinite(__x.imag()))
1009 return complex<_Tp>(__x.real(), _Tp(NAN));
1010 if (__x.imag() == 0 && !isfinite(__x.real()))
1011 return __x;
1012 return complex<_Tp>(sinh(__x.real()) * cos(__x.imag()), cosh(__x.real()) * sin(__x.imag()));
1013 }
1014
1015 // cosh
1016
1017 template<class _Tp>
1018 complex<_Tp>
1019 cosh(const complex<_Tp>& __x)
1020 {
1021 if (isinf(__x.real()) && !isfinite(__x.imag()))
1022 return complex<_Tp>(abs(__x.real()), _Tp(NAN));
1023 if (__x.real() == 0 && !isfinite(__x.imag()))
1024 return complex<_Tp>(_Tp(NAN), __x.real());
1025 if (__x.real() == 0 && __x.imag() == 0)
1026 return complex<_Tp>(_Tp(1), __x.imag());
1027 if (__x.imag() == 0 && !isfinite(__x.real()))
1028 return complex<_Tp>(abs(__x.real()), __x.imag());
1029 return complex<_Tp>(cosh(__x.real()) * cos(__x.imag()), sinh(__x.real()) * sin(__x.imag()));
1030 }
1031
1032 // tanh
1033
1034 template<class _Tp>
1035 complex<_Tp>
1036 tanh(const complex<_Tp>& __x)
1037 {
1038 if (isinf(__x.real()))
1039 {
1040 if (!isfinite(__x.imag()))
1041 return complex<_Tp>(copysign(_Tp(1), __x.real()), _Tp(0));
1042 return complex<_Tp>(copysign(_Tp(1), __x.real()), copysign(_Tp(0), sin(_Tp(2) * __x.imag())));
1043 }
1044 if (isnan(__x.real()) && __x.imag() == 0)
1045 return __x;
1046 _Tp __2r(_Tp(2) * __x.real());
1047 _Tp __2i(_Tp(2) * __x.imag());
1048 _Tp __d(cosh(__2r) + cos(__2i));
1049 _Tp __2rsh(sinh(__2r));
1050 if (isinf(__2rsh) && isinf(__d))
1051 return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1),
1052 __2i > _Tp(0) ? _Tp(0) : _Tp(-0.));
1053 return complex<_Tp>(__2rsh/__d, sin(__2i)/__d);
1054 }
1055
1056 // asin
1057
1058 template<class _Tp>
1059 complex<_Tp>
1060 asin(const complex<_Tp>& __x)
1061 {
1062 complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real()));
1063 return complex<_Tp>(__z.imag(), -__z.real());
1064 }
1065
1066 // acos
1067
1068 template<class _Tp>
1069 complex<_Tp>
1070 acos(const complex<_Tp>& __x)
1071 {
1072 const _Tp __pi(atan2(+0., -0.));
1073 if (isinf(__x.real()))
1074 {
1075 if (isnan(__x.imag()))
1076 return complex<_Tp>(__x.imag(), __x.real());
1077 if (isinf(__x.imag()))
1078 {
1079 if (__x.real() < _Tp(0))
1080 return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag());
1081 return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag());
1082 }
1083 if (__x.real() < _Tp(0))
1084 return complex<_Tp>(__pi, signbit(__x.imag()) ? -__x.real() : __x.real());
1085 return complex<_Tp>(_Tp(0), signbit(__x.imag()) ? __x.real() : -__x.real());
1086 }
1087 if (isnan(__x.real()))
1088 {
1089 if (isinf(__x.imag()))
1090 return complex<_Tp>(__x.real(), -__x.imag());
1091 return complex<_Tp>(__x.real(), __x.real());
1092 }
1093 if (isinf(__x.imag()))
1094 return complex<_Tp>(__pi/_Tp(2), -__x.imag());
1095 if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag())))
1096 return complex<_Tp>(__pi/_Tp(2), -__x.imag());
1097 complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1)));
1098 if (signbit(__x.imag()))
1099 return complex<_Tp>(abs(__z.imag()), abs(__z.real()));
1100 return complex<_Tp>(abs(__z.imag()), -abs(__z.real()));
1101 }
1102
1103 // atan
1104
1105 template<class _Tp>
1106 complex<_Tp>
1107 atan(const complex<_Tp>& __x)
1108 {
1109 complex<_Tp> __z = atanh(complex<_Tp>(-__x.imag(), __x.real()));
1110 return complex<_Tp>(__z.imag(), -__z.real());
1111 }
1112
1113 // sin
1114
1115 template<class _Tp>
1116 complex<_Tp>
1117 sin(const complex<_Tp>& __x)
1118 {
1119 complex<_Tp> __z = sinh(complex<_Tp>(-__x.imag(), __x.real()));
1120 return complex<_Tp>(__z.imag(), -__z.real());
1121 }
1122
1123 // cos
1124
1125 template<class _Tp>
1126 inline
1127 complex<_Tp>
1128 cos(const complex<_Tp>& __x)
1129 {
1130 return cosh(complex<_Tp>(-__x.imag(), __x.real()));
1131 }
1132
1133 // tan
1134
1135 template<class _Tp>
1136 complex<_Tp>
1137 tan(const complex<_Tp>& __x)
1138 {
1139 complex<_Tp> __z = tanh(complex<_Tp>(-__x.imag(), __x.real()));
1140 return complex<_Tp>(__z.imag(), -__z.real());
1141 }
1142
1143 // Literal suffix for complex number literals [complex.literals]
1144 inline namespace literals
1145 {
1146 inline namespace complex_literals
1147 {
1148 constexpr complex<double> operator""i(long double __im)
1149 {
1150 return { 0.0, static_cast<double>(__im) };
1151 }
1152
1153 constexpr complex<double> operator""i(unsigned long long __im)
1154 {
1155 return { 0.0, static_cast<double>(__im) };
1156 }
1157
1158
1159 constexpr complex<float> operator""if(long double __im)
1160 {
1161 return { 0.0f, static_cast<float>(__im) };
1162 }
1163
1164 constexpr complex<float> operator""if(unsigned long long __im)
1165 {
1166 return { 0.0f, static_cast<float>(__im) };
1167 }
1168 } // namespace complex_literals
1169 } // namespace literals
1170
1171 } // namespace std
1172
1173 )ESCAPE";
1174
get_complex_math_string()1175 const std::string &get_complex_math_string() {
1176 return complex_math;
1177 }
1178
1179 } // namespace at::cuda
1180