xref: /aosp_15_r20/external/pytorch/c10/util/complex.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <complex>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
8*da0073e9SAndroid Build Coastguard Worker #include <thrust/complex.h>
9*da0073e9SAndroid Build Coastguard Worker #endif
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker C10_CLANG_DIAGNOSTIC_PUSH()
12*da0073e9SAndroid Build Coastguard Worker #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
13*da0073e9SAndroid Build Coastguard Worker C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
14*da0073e9SAndroid Build Coastguard Worker #endif
15*da0073e9SAndroid Build Coastguard Worker #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
16*da0073e9SAndroid Build Coastguard Worker C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
17*da0073e9SAndroid Build Coastguard Worker #endif
18*da0073e9SAndroid Build Coastguard Worker 
19*da0073e9SAndroid Build Coastguard Worker namespace c10 {
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker // c10::complex is an implementation of complex numbers that aims
22*da0073e9SAndroid Build Coastguard Worker // to work on all devices supported by PyTorch
23*da0073e9SAndroid Build Coastguard Worker //
// Most of the APIs duplicate std::complex.
25*da0073e9SAndroid Build Coastguard Worker // Reference: https://en.cppreference.com/w/cpp/numeric/complex
26*da0073e9SAndroid Build Coastguard Worker //
27*da0073e9SAndroid Build Coastguard Worker // [NOTE: Complex Operator Unification]
28*da0073e9SAndroid Build Coastguard Worker // Operators currently use a mix of std::complex, thrust::complex, and
29*da0073e9SAndroid Build Coastguard Worker // c10::complex internally. The end state is that all operators will use
30*da0073e9SAndroid Build Coastguard Worker // c10::complex internally.  Until then, there may be some hacks to support all
31*da0073e9SAndroid Build Coastguard Worker // variants.
32*da0073e9SAndroid Build Coastguard Worker //
33*da0073e9SAndroid Build Coastguard Worker //
34*da0073e9SAndroid Build Coastguard Worker // [Note on Constructors]
35*da0073e9SAndroid Build Coastguard Worker //
36*da0073e9SAndroid Build Coastguard Worker // The APIs of constructors are mostly copied from C++ standard:
37*da0073e9SAndroid Build Coastguard Worker //   https://en.cppreference.com/w/cpp/numeric/complex/complex
38*da0073e9SAndroid Build Coastguard Worker //
39*da0073e9SAndroid Build Coastguard Worker // Since C++14, all constructors are constexpr in std::complex
40*da0073e9SAndroid Build Coastguard Worker //
41*da0073e9SAndroid Build Coastguard Worker // There are three types of constructors:
42*da0073e9SAndroid Build Coastguard Worker // - initializing from real and imag:
43*da0073e9SAndroid Build Coastguard Worker //     `constexpr complex( const T& re = T(), const T& im = T() );`
44*da0073e9SAndroid Build Coastguard Worker // - implicitly-declared copy constructor
45*da0073e9SAndroid Build Coastguard Worker // - converting constructors
46*da0073e9SAndroid Build Coastguard Worker //
// Converting constructors:
// - std::complex defines converting constructors between float/double/long
//   double, while we define converting constructors between float/double.
// - For these converting constructors, upcasting is implicit, downcasting is
//   explicit.
// - We also define explicit casting from std::complex/thrust::complex
//   - Note that the conversion from thrust is not constexpr, presumably
//     because thrust does not define them as constexpr.
56*da0073e9SAndroid Build Coastguard Worker //
57*da0073e9SAndroid Build Coastguard Worker //
58*da0073e9SAndroid Build Coastguard Worker // [Operator =]
59*da0073e9SAndroid Build Coastguard Worker //
60*da0073e9SAndroid Build Coastguard Worker // The APIs of operator = are mostly copied from C++ standard:
61*da0073e9SAndroid Build Coastguard Worker //   https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
62*da0073e9SAndroid Build Coastguard Worker //
63*da0073e9SAndroid Build Coastguard Worker // Since C++20, all operator= are constexpr. Although we are not building with
64*da0073e9SAndroid Build Coastguard Worker // C++20, we also obey this behavior.
65*da0073e9SAndroid Build Coastguard Worker //
// There are three types of assign operator:
// - Assign a real value from the same scalar type
//   - In std, this is templated as complex& operator=(const T& x)
//     with specialization `complex& operator=(T x)` for float/double/long
//     double. Since we only support float and double, we use
//     `complex& operator=(T x)`.
// - Copy assignment operator and converting assignment operator
//   - There is no specialization of converting assignment operators; which
//     types are convertible depends solely on whether the scalar type is
//     convertible.
76*da0073e9SAndroid Build Coastguard Worker //
77*da0073e9SAndroid Build Coastguard Worker // In addition to the standard assignment, we also provide assignment operators
78*da0073e9SAndroid Build Coastguard Worker // with std and thrust
79*da0073e9SAndroid Build Coastguard Worker //
80*da0073e9SAndroid Build Coastguard Worker //
81*da0073e9SAndroid Build Coastguard Worker // [Casting operators]
82*da0073e9SAndroid Build Coastguard Worker //
83*da0073e9SAndroid Build Coastguard Worker // std::complex does not have casting operators. We define casting operators
84*da0073e9SAndroid Build Coastguard Worker // casting to std::complex and thrust::complex
85*da0073e9SAndroid Build Coastguard Worker //
86*da0073e9SAndroid Build Coastguard Worker //
87*da0073e9SAndroid Build Coastguard Worker // [Operator ""]
88*da0073e9SAndroid Build Coastguard Worker //
// std::complex has custom literals `i`, `if` and `il` defined in namespace
// `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex; instead, we define _if, _id to construct
// float/double complex literals.
94*da0073e9SAndroid Build Coastguard Worker //
95*da0073e9SAndroid Build Coastguard Worker //
96*da0073e9SAndroid Build Coastguard Worker // [real() and imag()]
97*da0073e9SAndroid Build Coastguard Worker //
// In C++20, there are two overloads of each of these functions: one returns
// the real/imag part and the other sets it; both are constexpr. We follow
// this design.
101*da0073e9SAndroid Build Coastguard Worker //
102*da0073e9SAndroid Build Coastguard Worker //
103*da0073e9SAndroid Build Coastguard Worker // [Operator +=,-=,*=,/=]
104*da0073e9SAndroid Build Coastguard Worker //
105*da0073e9SAndroid Build Coastguard Worker // Since C++20, these operators become constexpr. In our implementation, they
106*da0073e9SAndroid Build Coastguard Worker // are also constexpr.
107*da0073e9SAndroid Build Coastguard Worker //
// There are two types of such operators: operating with a real number, or
// operating with another complex number. For operations with a real number,
// the generic template form has argument type `const T &`, while the overload
// for float/double/long double has `T`. We follow the same type as the
// float/double/long double overloads in std.
113*da0073e9SAndroid Build Coastguard Worker //
114*da0073e9SAndroid Build Coastguard Worker // [Unary operator +-]
115*da0073e9SAndroid Build Coastguard Worker //
// Since C++20, they are constexpr. We also make them constexpr.
117*da0073e9SAndroid Build Coastguard Worker //
118*da0073e9SAndroid Build Coastguard Worker // [Binary operators +-*/]
119*da0073e9SAndroid Build Coastguard Worker //
120*da0073e9SAndroid Build Coastguard Worker // Each operator has three versions (taking + as example):
121*da0073e9SAndroid Build Coastguard Worker // - complex + complex
122*da0073e9SAndroid Build Coastguard Worker // - complex + real
123*da0073e9SAndroid Build Coastguard Worker // - real + complex
124*da0073e9SAndroid Build Coastguard Worker //
125*da0073e9SAndroid Build Coastguard Worker // [Operator ==, !=]
126*da0073e9SAndroid Build Coastguard Worker //
127*da0073e9SAndroid Build Coastguard Worker // Each operator has three versions (taking == as example):
128*da0073e9SAndroid Build Coastguard Worker // - complex == complex
129*da0073e9SAndroid Build Coastguard Worker // - complex == real
130*da0073e9SAndroid Build Coastguard Worker // - real == complex
131*da0073e9SAndroid Build Coastguard Worker //
// Some of them are removed in C++20, but we decided to keep them.
133*da0073e9SAndroid Build Coastguard Worker //
134*da0073e9SAndroid Build Coastguard Worker // [Operator <<, >>]
135*da0073e9SAndroid Build Coastguard Worker //
136*da0073e9SAndroid Build Coastguard Worker // These are implemented by casting to std::complex
137*da0073e9SAndroid Build Coastguard Worker //
138*da0073e9SAndroid Build Coastguard Worker //
139*da0073e9SAndroid Build Coastguard Worker //
// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
// because:
//  - lots of members and functions of c10::Half are not constexpr
//  - thrust::complex only supports float and double
144*da0073e9SAndroid Build Coastguard Worker 
145*da0073e9SAndroid Build Coastguard Worker template <typename T>
146*da0073e9SAndroid Build Coastguard Worker struct alignas(sizeof(T) * 2) complex {
147*da0073e9SAndroid Build Coastguard Worker   using value_type = T;
148*da0073e9SAndroid Build Coastguard Worker 
149*da0073e9SAndroid Build Coastguard Worker   T real_ = T(0);
150*da0073e9SAndroid Build Coastguard Worker   T imag_ = T(0);
151*da0073e9SAndroid Build Coastguard Worker 
152*da0073e9SAndroid Build Coastguard Worker   constexpr complex() = default;
153*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
real_complex154*da0073e9SAndroid Build Coastguard Worker       : real_(re), imag_(im) {}
155*da0073e9SAndroid Build Coastguard Worker   template <typename U>
complexcomplex156*da0073e9SAndroid Build Coastguard Worker   explicit constexpr complex(const std::complex<U>& other)
157*da0073e9SAndroid Build Coastguard Worker       : complex(other.real(), other.imag()) {}
158*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
159*da0073e9SAndroid Build Coastguard Worker   template <typename U>
complexcomplex160*da0073e9SAndroid Build Coastguard Worker   explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
161*da0073e9SAndroid Build Coastguard Worker       : real_(other.real()), imag_(other.imag()) {}
162*da0073e9SAndroid Build Coastguard Worker // NOTE can not be implemented as follow due to ROCm bug:
163*da0073e9SAndroid Build Coastguard Worker //   explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
164*da0073e9SAndroid Build Coastguard Worker //   complex(other.real(), other.imag()) {}
165*da0073e9SAndroid Build Coastguard Worker #endif
166*da0073e9SAndroid Build Coastguard Worker 
167*da0073e9SAndroid Build Coastguard Worker   // Use SFINAE to specialize casting constructor for c10::complex<float> and
168*da0073e9SAndroid Build Coastguard Worker   // c10::complex<double>
169*da0073e9SAndroid Build Coastguard Worker   template <typename U = T>
complexcomplex170*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE explicit constexpr complex(
171*da0073e9SAndroid Build Coastguard Worker       const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
172*da0073e9SAndroid Build Coastguard Worker       : real_(other.real_), imag_(other.imag_) {}
173*da0073e9SAndroid Build Coastguard Worker   template <typename U = T>
complexcomplex174*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE constexpr complex(
175*da0073e9SAndroid Build Coastguard Worker       const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
176*da0073e9SAndroid Build Coastguard Worker       : real_(other.real_), imag_(other.imag_) {}
177*da0073e9SAndroid Build Coastguard Worker 
178*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator=(T re) {
179*da0073e9SAndroid Build Coastguard Worker     real_ = re;
180*da0073e9SAndroid Build Coastguard Worker     imag_ = 0;
181*da0073e9SAndroid Build Coastguard Worker     return *this;
182*da0073e9SAndroid Build Coastguard Worker   }
183*da0073e9SAndroid Build Coastguard Worker 
184*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator+=(T re) {
185*da0073e9SAndroid Build Coastguard Worker     real_ += re;
186*da0073e9SAndroid Build Coastguard Worker     return *this;
187*da0073e9SAndroid Build Coastguard Worker   }
188*da0073e9SAndroid Build Coastguard Worker 
189*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator-=(T re) {
190*da0073e9SAndroid Build Coastguard Worker     real_ -= re;
191*da0073e9SAndroid Build Coastguard Worker     return *this;
192*da0073e9SAndroid Build Coastguard Worker   }
193*da0073e9SAndroid Build Coastguard Worker 
194*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator*=(T re) {
195*da0073e9SAndroid Build Coastguard Worker     real_ *= re;
196*da0073e9SAndroid Build Coastguard Worker     imag_ *= re;
197*da0073e9SAndroid Build Coastguard Worker     return *this;
198*da0073e9SAndroid Build Coastguard Worker   }
199*da0073e9SAndroid Build Coastguard Worker 
200*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator/=(T re) {
201*da0073e9SAndroid Build Coastguard Worker     real_ /= re;
202*da0073e9SAndroid Build Coastguard Worker     imag_ /= re;
203*da0073e9SAndroid Build Coastguard Worker     return *this;
204*da0073e9SAndroid Build Coastguard Worker   }
205*da0073e9SAndroid Build Coastguard Worker 
206*da0073e9SAndroid Build Coastguard Worker   template <typename U>
207*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator=(const complex<U>& rhs) {
208*da0073e9SAndroid Build Coastguard Worker     real_ = rhs.real();
209*da0073e9SAndroid Build Coastguard Worker     imag_ = rhs.imag();
210*da0073e9SAndroid Build Coastguard Worker     return *this;
211*da0073e9SAndroid Build Coastguard Worker   }
212*da0073e9SAndroid Build Coastguard Worker 
213*da0073e9SAndroid Build Coastguard Worker   template <typename U>
214*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator+=(const complex<U>& rhs) {
215*da0073e9SAndroid Build Coastguard Worker     real_ += rhs.real();
216*da0073e9SAndroid Build Coastguard Worker     imag_ += rhs.imag();
217*da0073e9SAndroid Build Coastguard Worker     return *this;
218*da0073e9SAndroid Build Coastguard Worker   }
219*da0073e9SAndroid Build Coastguard Worker 
220*da0073e9SAndroid Build Coastguard Worker   template <typename U>
221*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator-=(const complex<U>& rhs) {
222*da0073e9SAndroid Build Coastguard Worker     real_ -= rhs.real();
223*da0073e9SAndroid Build Coastguard Worker     imag_ -= rhs.imag();
224*da0073e9SAndroid Build Coastguard Worker     return *this;
225*da0073e9SAndroid Build Coastguard Worker   }
226*da0073e9SAndroid Build Coastguard Worker 
227*da0073e9SAndroid Build Coastguard Worker   template <typename U>
228*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator*=(const complex<U>& rhs) {
229*da0073e9SAndroid Build Coastguard Worker     // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
230*da0073e9SAndroid Build Coastguard Worker     T a = real_;
231*da0073e9SAndroid Build Coastguard Worker     T b = imag_;
232*da0073e9SAndroid Build Coastguard Worker     U c = rhs.real();
233*da0073e9SAndroid Build Coastguard Worker     U d = rhs.imag();
234*da0073e9SAndroid Build Coastguard Worker     real_ = a * c - b * d;
235*da0073e9SAndroid Build Coastguard Worker     imag_ = a * d + b * c;
236*da0073e9SAndroid Build Coastguard Worker     return *this;
237*da0073e9SAndroid Build Coastguard Worker   }
238*da0073e9SAndroid Build Coastguard Worker 
239*da0073e9SAndroid Build Coastguard Worker #ifdef __APPLE__
240*da0073e9SAndroid Build Coastguard Worker #define FORCE_INLINE_APPLE __attribute__((always_inline))
241*da0073e9SAndroid Build Coastguard Worker #else
242*da0073e9SAndroid Build Coastguard Worker #define FORCE_INLINE_APPLE
243*da0073e9SAndroid Build Coastguard Worker #endif
244*da0073e9SAndroid Build Coastguard Worker   template <typename U>
245*da0073e9SAndroid Build Coastguard Worker   constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
246*da0073e9SAndroid Build Coastguard Worker       __ubsan_ignore_float_divide_by_zero__ {
247*da0073e9SAndroid Build Coastguard Worker     // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
248*da0073e9SAndroid Build Coastguard Worker     // the calculation below follows numpy's complex division
249*da0073e9SAndroid Build Coastguard Worker     T a = real_;
250*da0073e9SAndroid Build Coastguard Worker     T b = imag_;
251*da0073e9SAndroid Build Coastguard Worker     U c = rhs.real();
252*da0073e9SAndroid Build Coastguard Worker     U d = rhs.imag();
253*da0073e9SAndroid Build Coastguard Worker 
254*da0073e9SAndroid Build Coastguard Worker #if defined(__GNUC__) && !defined(__clang__)
255*da0073e9SAndroid Build Coastguard Worker     // std::abs is already constexpr by gcc
256*da0073e9SAndroid Build Coastguard Worker     auto abs_c = std::abs(c);
257*da0073e9SAndroid Build Coastguard Worker     auto abs_d = std::abs(d);
258*da0073e9SAndroid Build Coastguard Worker #else
259*da0073e9SAndroid Build Coastguard Worker     auto abs_c = c < 0 ? -c : c;
260*da0073e9SAndroid Build Coastguard Worker     auto abs_d = d < 0 ? -d : d;
261*da0073e9SAndroid Build Coastguard Worker #endif
262*da0073e9SAndroid Build Coastguard Worker 
263*da0073e9SAndroid Build Coastguard Worker     if (abs_c >= abs_d) {
264*da0073e9SAndroid Build Coastguard Worker       if (abs_c == U(0) && abs_d == U(0)) {
265*da0073e9SAndroid Build Coastguard Worker         /* divide by zeros should yield a complex inf or nan */
266*da0073e9SAndroid Build Coastguard Worker         real_ = a / abs_c;
267*da0073e9SAndroid Build Coastguard Worker         imag_ = b / abs_d;
268*da0073e9SAndroid Build Coastguard Worker       } else {
269*da0073e9SAndroid Build Coastguard Worker         auto rat = d / c;
270*da0073e9SAndroid Build Coastguard Worker         auto scl = U(1.0) / (c + d * rat);
271*da0073e9SAndroid Build Coastguard Worker         real_ = (a + b * rat) * scl;
272*da0073e9SAndroid Build Coastguard Worker         imag_ = (b - a * rat) * scl;
273*da0073e9SAndroid Build Coastguard Worker       }
274*da0073e9SAndroid Build Coastguard Worker     } else {
275*da0073e9SAndroid Build Coastguard Worker       auto rat = c / d;
276*da0073e9SAndroid Build Coastguard Worker       auto scl = U(1.0) / (d + c * rat);
277*da0073e9SAndroid Build Coastguard Worker       real_ = (a * rat + b) * scl;
278*da0073e9SAndroid Build Coastguard Worker       imag_ = (b * rat - a) * scl;
279*da0073e9SAndroid Build Coastguard Worker     }
280*da0073e9SAndroid Build Coastguard Worker     return *this;
281*da0073e9SAndroid Build Coastguard Worker   }
282*da0073e9SAndroid Build Coastguard Worker #undef FORCE_INLINE_APPLE
283*da0073e9SAndroid Build Coastguard Worker 
284*da0073e9SAndroid Build Coastguard Worker   template <typename U>
285*da0073e9SAndroid Build Coastguard Worker   constexpr complex<T>& operator=(const std::complex<U>& rhs) {
286*da0073e9SAndroid Build Coastguard Worker     real_ = rhs.real();
287*da0073e9SAndroid Build Coastguard Worker     imag_ = rhs.imag();
288*da0073e9SAndroid Build Coastguard Worker     return *this;
289*da0073e9SAndroid Build Coastguard Worker   }
290*da0073e9SAndroid Build Coastguard Worker 
291*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
292*da0073e9SAndroid Build Coastguard Worker   template <typename U>
293*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
294*da0073e9SAndroid Build Coastguard Worker     real_ = rhs.real();
295*da0073e9SAndroid Build Coastguard Worker     imag_ = rhs.imag();
296*da0073e9SAndroid Build Coastguard Worker     return *this;
297*da0073e9SAndroid Build Coastguard Worker   }
298*da0073e9SAndroid Build Coastguard Worker #endif
299*da0073e9SAndroid Build Coastguard Worker 
300*da0073e9SAndroid Build Coastguard Worker   template <typename U>
complexcomplex301*da0073e9SAndroid Build Coastguard Worker   explicit constexpr operator std::complex<U>() const {
302*da0073e9SAndroid Build Coastguard Worker     return std::complex<U>(std::complex<T>(real(), imag()));
303*da0073e9SAndroid Build Coastguard Worker   }
304*da0073e9SAndroid Build Coastguard Worker 
305*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
306*da0073e9SAndroid Build Coastguard Worker   template <typename U>
complexcomplex307*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
308*da0073e9SAndroid Build Coastguard Worker     return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
309*da0073e9SAndroid Build Coastguard Worker   }
310*da0073e9SAndroid Build Coastguard Worker #endif
311*da0073e9SAndroid Build Coastguard Worker 
312*da0073e9SAndroid Build Coastguard Worker   // consistent with NumPy behavior
313*da0073e9SAndroid Build Coastguard Worker   explicit constexpr operator bool() const {
314*da0073e9SAndroid Build Coastguard Worker     return real() || imag();
315*da0073e9SAndroid Build Coastguard Worker   }
316*da0073e9SAndroid Build Coastguard Worker 
realcomplex317*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE constexpr T real() const {
318*da0073e9SAndroid Build Coastguard Worker     return real_;
319*da0073e9SAndroid Build Coastguard Worker   }
realcomplex320*da0073e9SAndroid Build Coastguard Worker   constexpr void real(T value) {
321*da0073e9SAndroid Build Coastguard Worker     real_ = value;
322*da0073e9SAndroid Build Coastguard Worker   }
imagcomplex323*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE constexpr T imag() const {
324*da0073e9SAndroid Build Coastguard Worker     return imag_;
325*da0073e9SAndroid Build Coastguard Worker   }
imagcomplex326*da0073e9SAndroid Build Coastguard Worker   constexpr void imag(T value) {
327*da0073e9SAndroid Build Coastguard Worker     imag_ = value;
328*da0073e9SAndroid Build Coastguard Worker   }
329*da0073e9SAndroid Build Coastguard Worker };
330*da0073e9SAndroid Build Coastguard Worker 
331*da0073e9SAndroid Build Coastguard Worker namespace complex_literals {
332*da0073e9SAndroid Build Coastguard Worker 
333*da0073e9SAndroid Build Coastguard Worker constexpr complex<float> operator""_if(long double imag) {
334*da0073e9SAndroid Build Coastguard Worker   return complex<float>(0.0f, static_cast<float>(imag));
335*da0073e9SAndroid Build Coastguard Worker }
336*da0073e9SAndroid Build Coastguard Worker 
337*da0073e9SAndroid Build Coastguard Worker constexpr complex<double> operator""_id(long double imag) {
338*da0073e9SAndroid Build Coastguard Worker   return complex<double>(0.0, static_cast<double>(imag));
339*da0073e9SAndroid Build Coastguard Worker }
340*da0073e9SAndroid Build Coastguard Worker 
341*da0073e9SAndroid Build Coastguard Worker constexpr complex<float> operator""_if(unsigned long long imag) {
342*da0073e9SAndroid Build Coastguard Worker   return complex<float>(0.0f, static_cast<float>(imag));
343*da0073e9SAndroid Build Coastguard Worker }
344*da0073e9SAndroid Build Coastguard Worker 
345*da0073e9SAndroid Build Coastguard Worker constexpr complex<double> operator""_id(unsigned long long imag) {
346*da0073e9SAndroid Build Coastguard Worker   return complex<double>(0.0, static_cast<double>(imag));
347*da0073e9SAndroid Build Coastguard Worker }
348*da0073e9SAndroid Build Coastguard Worker 
349*da0073e9SAndroid Build Coastguard Worker } // namespace complex_literals
350*da0073e9SAndroid Build Coastguard Worker 
351*da0073e9SAndroid Build Coastguard Worker template <typename T>
352*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator+(const complex<T>& val) {
353*da0073e9SAndroid Build Coastguard Worker   return val;
354*da0073e9SAndroid Build Coastguard Worker }
355*da0073e9SAndroid Build Coastguard Worker 
356*da0073e9SAndroid Build Coastguard Worker template <typename T>
357*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator-(const complex<T>& val) {
358*da0073e9SAndroid Build Coastguard Worker   return complex<T>(-val.real(), -val.imag());
359*da0073e9SAndroid Build Coastguard Worker }
360*da0073e9SAndroid Build Coastguard Worker 
361*da0073e9SAndroid Build Coastguard Worker template <typename T>
362*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
363*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
364*da0073e9SAndroid Build Coastguard Worker   return result += rhs;
365*da0073e9SAndroid Build Coastguard Worker }
366*da0073e9SAndroid Build Coastguard Worker 
367*da0073e9SAndroid Build Coastguard Worker template <typename T>
368*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
369*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
370*da0073e9SAndroid Build Coastguard Worker   return result += rhs;
371*da0073e9SAndroid Build Coastguard Worker }
372*da0073e9SAndroid Build Coastguard Worker 
373*da0073e9SAndroid Build Coastguard Worker template <typename T>
374*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
375*da0073e9SAndroid Build Coastguard Worker   return complex<T>(lhs + rhs.real(), rhs.imag());
376*da0073e9SAndroid Build Coastguard Worker }
377*da0073e9SAndroid Build Coastguard Worker 
378*da0073e9SAndroid Build Coastguard Worker template <typename T>
379*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
380*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
381*da0073e9SAndroid Build Coastguard Worker   return result -= rhs;
382*da0073e9SAndroid Build Coastguard Worker }
383*da0073e9SAndroid Build Coastguard Worker 
384*da0073e9SAndroid Build Coastguard Worker template <typename T>
385*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
386*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
387*da0073e9SAndroid Build Coastguard Worker   return result -= rhs;
388*da0073e9SAndroid Build Coastguard Worker }
389*da0073e9SAndroid Build Coastguard Worker 
390*da0073e9SAndroid Build Coastguard Worker template <typename T>
391*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
392*da0073e9SAndroid Build Coastguard Worker   complex<T> result = -rhs;
393*da0073e9SAndroid Build Coastguard Worker   return result += lhs;
394*da0073e9SAndroid Build Coastguard Worker }
395*da0073e9SAndroid Build Coastguard Worker 
396*da0073e9SAndroid Build Coastguard Worker template <typename T>
397*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
398*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
399*da0073e9SAndroid Build Coastguard Worker   return result *= rhs;
400*da0073e9SAndroid Build Coastguard Worker }
401*da0073e9SAndroid Build Coastguard Worker 
402*da0073e9SAndroid Build Coastguard Worker template <typename T>
403*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
404*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
405*da0073e9SAndroid Build Coastguard Worker   return result *= rhs;
406*da0073e9SAndroid Build Coastguard Worker }
407*da0073e9SAndroid Build Coastguard Worker 
408*da0073e9SAndroid Build Coastguard Worker template <typename T>
409*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
410*da0073e9SAndroid Build Coastguard Worker   complex<T> result = rhs;
411*da0073e9SAndroid Build Coastguard Worker   return result *= lhs;
412*da0073e9SAndroid Build Coastguard Worker }
413*da0073e9SAndroid Build Coastguard Worker 
414*da0073e9SAndroid Build Coastguard Worker template <typename T>
415*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
416*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
417*da0073e9SAndroid Build Coastguard Worker   return result /= rhs;
418*da0073e9SAndroid Build Coastguard Worker }
419*da0073e9SAndroid Build Coastguard Worker 
420*da0073e9SAndroid Build Coastguard Worker template <typename T>
421*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
422*da0073e9SAndroid Build Coastguard Worker   complex<T> result = lhs;
423*da0073e9SAndroid Build Coastguard Worker   return result /= rhs;
424*da0073e9SAndroid Build Coastguard Worker }
425*da0073e9SAndroid Build Coastguard Worker 
426*da0073e9SAndroid Build Coastguard Worker template <typename T>
427*da0073e9SAndroid Build Coastguard Worker constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
428*da0073e9SAndroid Build Coastguard Worker   complex<T> result(lhs, T());
429*da0073e9SAndroid Build Coastguard Worker   return result /= rhs;
430*da0073e9SAndroid Build Coastguard Worker }
431*da0073e9SAndroid Build Coastguard Worker 
// Define operators between integral scalars and c10::complex. std::complex does
// not support this when T is a floating-point number. This is useful because it
// saves a lot of "static_cast" when operating on a complex and an integer. This
// makes the code both less verbose and potentially more efficient.
//
// COMPLEX_INTEGER_OP_TEMPLATE_CONDITION is spliced in as the last template
// parameter of those operators. It enables an overload only when `fT` is a
// floating-point type and `iT` is an integral type.
#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION                 \
  typename std::enable_if_t<                                  \
      std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
      int> = 0
440*da0073e9SAndroid Build Coastguard Worker 
441*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
442*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
443*da0073e9SAndroid Build Coastguard Worker   return a + static_cast<fT>(b);
444*da0073e9SAndroid Build Coastguard Worker }
445*da0073e9SAndroid Build Coastguard Worker 
446*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
447*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
448*da0073e9SAndroid Build Coastguard Worker   return static_cast<fT>(a) + b;
449*da0073e9SAndroid Build Coastguard Worker }
450*da0073e9SAndroid Build Coastguard Worker 
451*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
452*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
453*da0073e9SAndroid Build Coastguard Worker   return a - static_cast<fT>(b);
454*da0073e9SAndroid Build Coastguard Worker }
455*da0073e9SAndroid Build Coastguard Worker 
456*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
457*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
458*da0073e9SAndroid Build Coastguard Worker   return static_cast<fT>(a) - b;
459*da0073e9SAndroid Build Coastguard Worker }
460*da0073e9SAndroid Build Coastguard Worker 
461*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
462*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
463*da0073e9SAndroid Build Coastguard Worker   return a * static_cast<fT>(b);
464*da0073e9SAndroid Build Coastguard Worker }
465*da0073e9SAndroid Build Coastguard Worker 
466*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
467*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
468*da0073e9SAndroid Build Coastguard Worker   return static_cast<fT>(a) * b;
469*da0073e9SAndroid Build Coastguard Worker }
470*da0073e9SAndroid Build Coastguard Worker 
471*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
472*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
473*da0073e9SAndroid Build Coastguard Worker   return a / static_cast<fT>(b);
474*da0073e9SAndroid Build Coastguard Worker }
475*da0073e9SAndroid Build Coastguard Worker 
476*da0073e9SAndroid Build Coastguard Worker template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
477*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
478*da0073e9SAndroid Build Coastguard Worker   return static_cast<fT>(a) / b;
479*da0073e9SAndroid Build Coastguard Worker }
480*da0073e9SAndroid Build Coastguard Worker 
481*da0073e9SAndroid Build Coastguard Worker #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
482*da0073e9SAndroid Build Coastguard Worker 
483*da0073e9SAndroid Build Coastguard Worker template <typename T>
484*da0073e9SAndroid Build Coastguard Worker constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
485*da0073e9SAndroid Build Coastguard Worker   return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
486*da0073e9SAndroid Build Coastguard Worker }
487*da0073e9SAndroid Build Coastguard Worker 
488*da0073e9SAndroid Build Coastguard Worker template <typename T>
489*da0073e9SAndroid Build Coastguard Worker constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
490*da0073e9SAndroid Build Coastguard Worker   return (lhs.real() == rhs) && (lhs.imag() == T());
491*da0073e9SAndroid Build Coastguard Worker }
492*da0073e9SAndroid Build Coastguard Worker 
493*da0073e9SAndroid Build Coastguard Worker template <typename T>
494*da0073e9SAndroid Build Coastguard Worker constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
495*da0073e9SAndroid Build Coastguard Worker   return (lhs == rhs.real()) && (T() == rhs.imag());
496*da0073e9SAndroid Build Coastguard Worker }
497*da0073e9SAndroid Build Coastguard Worker 
498*da0073e9SAndroid Build Coastguard Worker template <typename T>
499*da0073e9SAndroid Build Coastguard Worker constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
500*da0073e9SAndroid Build Coastguard Worker   return !(lhs == rhs);
501*da0073e9SAndroid Build Coastguard Worker }
502*da0073e9SAndroid Build Coastguard Worker 
503*da0073e9SAndroid Build Coastguard Worker template <typename T>
504*da0073e9SAndroid Build Coastguard Worker constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
505*da0073e9SAndroid Build Coastguard Worker   return !(lhs == rhs);
506*da0073e9SAndroid Build Coastguard Worker }
507*da0073e9SAndroid Build Coastguard Worker 
508*da0073e9SAndroid Build Coastguard Worker template <typename T>
509*da0073e9SAndroid Build Coastguard Worker constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
510*da0073e9SAndroid Build Coastguard Worker   return !(lhs == rhs);
511*da0073e9SAndroid Build Coastguard Worker }
512*da0073e9SAndroid Build Coastguard Worker 
513*da0073e9SAndroid Build Coastguard Worker template <typename T, typename CharT, typename Traits>
514*da0073e9SAndroid Build Coastguard Worker std::basic_ostream<CharT, Traits>& operator<<(
515*da0073e9SAndroid Build Coastguard Worker     std::basic_ostream<CharT, Traits>& os,
516*da0073e9SAndroid Build Coastguard Worker     const complex<T>& x) {
517*da0073e9SAndroid Build Coastguard Worker   return (os << static_cast<std::complex<T>>(x));
518*da0073e9SAndroid Build Coastguard Worker }
519*da0073e9SAndroid Build Coastguard Worker 
520*da0073e9SAndroid Build Coastguard Worker template <typename T, typename CharT, typename Traits>
521*da0073e9SAndroid Build Coastguard Worker std::basic_istream<CharT, Traits>& operator>>(
522*da0073e9SAndroid Build Coastguard Worker     std::basic_istream<CharT, Traits>& is,
523*da0073e9SAndroid Build Coastguard Worker     complex<T>& x) {
524*da0073e9SAndroid Build Coastguard Worker   std::complex<T> tmp;
525*da0073e9SAndroid Build Coastguard Worker   is >> tmp;
526*da0073e9SAndroid Build Coastguard Worker   x = tmp;
527*da0073e9SAndroid Build Coastguard Worker   return is;
528*da0073e9SAndroid Build Coastguard Worker }
529*da0073e9SAndroid Build Coastguard Worker 
530*da0073e9SAndroid Build Coastguard Worker } // namespace c10
531*da0073e9SAndroid Build Coastguard Worker 
532*da0073e9SAndroid Build Coastguard Worker // std functions
533*da0073e9SAndroid Build Coastguard Worker //
534*da0073e9SAndroid Build Coastguard Worker // The implementation of these functions also follow the design of C++20
535*da0073e9SAndroid Build Coastguard Worker 
536*da0073e9SAndroid Build Coastguard Worker namespace std {
537*da0073e9SAndroid Build Coastguard Worker 
538*da0073e9SAndroid Build Coastguard Worker template <typename T>
real(const c10::complex<T> & z)539*da0073e9SAndroid Build Coastguard Worker constexpr T real(const c10::complex<T>& z) {
540*da0073e9SAndroid Build Coastguard Worker   return z.real();
541*da0073e9SAndroid Build Coastguard Worker }
542*da0073e9SAndroid Build Coastguard Worker 
543*da0073e9SAndroid Build Coastguard Worker template <typename T>
imag(const c10::complex<T> & z)544*da0073e9SAndroid Build Coastguard Worker constexpr T imag(const c10::complex<T>& z) {
545*da0073e9SAndroid Build Coastguard Worker   return z.imag();
546*da0073e9SAndroid Build Coastguard Worker }
547*da0073e9SAndroid Build Coastguard Worker 
548*da0073e9SAndroid Build Coastguard Worker template <typename T>
abs(const c10::complex<T> & z)549*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
550*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
551*da0073e9SAndroid Build Coastguard Worker   return thrust::abs(static_cast<thrust::complex<T>>(z));
552*da0073e9SAndroid Build Coastguard Worker #else
553*da0073e9SAndroid Build Coastguard Worker   return std::abs(static_cast<std::complex<T>>(z));
554*da0073e9SAndroid Build Coastguard Worker #endif
555*da0073e9SAndroid Build Coastguard Worker }
556*da0073e9SAndroid Build Coastguard Worker 
557*da0073e9SAndroid Build Coastguard Worker #if defined(USE_ROCM)
558*da0073e9SAndroid Build Coastguard Worker #define ROCm_Bug(x)
559*da0073e9SAndroid Build Coastguard Worker #else
560*da0073e9SAndroid Build Coastguard Worker #define ROCm_Bug(x) x
561*da0073e9SAndroid Build Coastguard Worker #endif
562*da0073e9SAndroid Build Coastguard Worker 
563*da0073e9SAndroid Build Coastguard Worker template <typename T>
arg(const c10::complex<T> & z)564*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
565*da0073e9SAndroid Build Coastguard Worker   return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
566*da0073e9SAndroid Build Coastguard Worker }
567*da0073e9SAndroid Build Coastguard Worker 
568*da0073e9SAndroid Build Coastguard Worker #undef ROCm_Bug
569*da0073e9SAndroid Build Coastguard Worker 
570*da0073e9SAndroid Build Coastguard Worker template <typename T>
norm(const c10::complex<T> & z)571*da0073e9SAndroid Build Coastguard Worker constexpr T norm(const c10::complex<T>& z) {
572*da0073e9SAndroid Build Coastguard Worker   return z.real() * z.real() + z.imag() * z.imag();
573*da0073e9SAndroid Build Coastguard Worker }
574*da0073e9SAndroid Build Coastguard Worker 
575*da0073e9SAndroid Build Coastguard Worker // For std::conj, there are other versions of it:
576*da0073e9SAndroid Build Coastguard Worker //   constexpr std::complex<float> conj( float z );
577*da0073e9SAndroid Build Coastguard Worker //   template< class DoubleOrInteger >
578*da0073e9SAndroid Build Coastguard Worker //   constexpr std::complex<double> conj( DoubleOrInteger z );
579*da0073e9SAndroid Build Coastguard Worker //   constexpr std::complex<long double> conj( long double z );
580*da0073e9SAndroid Build Coastguard Worker // These are not implemented
581*da0073e9SAndroid Build Coastguard Worker // TODO(@zasdfgbnm): implement them as c10::conj
582*da0073e9SAndroid Build Coastguard Worker template <typename T>
conj(const c10::complex<T> & z)583*da0073e9SAndroid Build Coastguard Worker constexpr c10::complex<T> conj(const c10::complex<T>& z) {
584*da0073e9SAndroid Build Coastguard Worker   return c10::complex<T>(z.real(), -z.imag());
585*da0073e9SAndroid Build Coastguard Worker }
586*da0073e9SAndroid Build Coastguard Worker 
587*da0073e9SAndroid Build Coastguard Worker // Thrust does not have complex --> complex version of thrust::proj,
588*da0073e9SAndroid Build Coastguard Worker // so this function is not implemented at c10 right now.
589*da0073e9SAndroid Build Coastguard Worker // TODO(@zasdfgbnm): implement it by ourselves
590*da0073e9SAndroid Build Coastguard Worker 
591*da0073e9SAndroid Build Coastguard Worker // There is no c10 version of std::polar, because std::polar always
592*da0073e9SAndroid Build Coastguard Worker // returns std::complex. Use c10::polar instead;
593*da0073e9SAndroid Build Coastguard Worker 
594*da0073e9SAndroid Build Coastguard Worker } // namespace std
595*da0073e9SAndroid Build Coastguard Worker 
596*da0073e9SAndroid Build Coastguard Worker namespace c10 {
597*da0073e9SAndroid Build Coastguard Worker 
598*da0073e9SAndroid Build Coastguard Worker template <typename T>
599*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
600*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
601*da0073e9SAndroid Build Coastguard Worker   return static_cast<complex<T>>(thrust::polar(r, theta));
602*da0073e9SAndroid Build Coastguard Worker #else
603*da0073e9SAndroid Build Coastguard Worker   // std::polar() requires r >= 0, so spell out the explicit implementation to
604*da0073e9SAndroid Build Coastguard Worker   // avoid a branch.
605*da0073e9SAndroid Build Coastguard Worker   return complex<T>(r * std::cos(theta), r * std::sin(theta));
606*da0073e9SAndroid Build Coastguard Worker #endif
607*da0073e9SAndroid Build Coastguard Worker }
608*da0073e9SAndroid Build Coastguard Worker 
609*da0073e9SAndroid Build Coastguard Worker } // namespace c10
610*da0073e9SAndroid Build Coastguard Worker 
611*da0073e9SAndroid Build Coastguard Worker C10_CLANG_DIAGNOSTIC_POP()
612*da0073e9SAndroid Build Coastguard Worker 
613*da0073e9SAndroid Build Coastguard Worker #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
614*da0073e9SAndroid Build Coastguard Worker // math functions are included in a separate file
615*da0073e9SAndroid Build Coastguard Worker #include <c10/util/complex_math.h> // IWYU pragma: keep
616*da0073e9SAndroid Build Coastguard Worker // utilities for complex types
617*da0073e9SAndroid Build Coastguard Worker #include <c10/util/complex_utils.h> // IWYU pragma: keep
618*da0073e9SAndroid Build Coastguard Worker #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
619