xref: /aosp_15_r20/external/eigen/Eigen/src/Core/arch/CUDA/Complex.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <[email protected]>
5 // Copyright (C) 2021 C. Antonio Sanchez <[email protected]>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_COMPLEX_CUDA_H
12 #define EIGEN_COMPLEX_CUDA_H
13 
14 // clang-format off
// Many std::complex operators, such as operator+, operator-, operator* and
// operator/, are not constexpr. Because of this, GCC and older versions of clang
// do not treat them as device functions, so Eigen functors that use these
// operators fail to compile. Here, when building for CUDA, we manually provide
// device-callable overloads of these operators for std::complex<float> and
// std::complex<double> to enable their use on device.
21 
22 #if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE)
23 
24 // ICC already specializes std::complex<float> and std::complex<double>
25 // operators, preventing us from making them device functions here.
26 // This will lead to silent runtime errors if the operators are used on device.
27 //
28 // To allow std::complex operator use on device, define _OVERRIDE_COMPLEX_SPECIALIZATION_
29 // prior to first inclusion of <complex>.  This prevents ICC from adding
30 // its own specializations, so our custom ones below can be used instead.
31 #if !(defined(EIGEN_COMP_ICC) && defined(_USE_COMPLEX_SPECIALIZATION_))
32 
// Expands to using-declarations that bring every device-callable std::complex
// operator overload from Eigen::complex_operator_detail into the current scope,
// making them visible to unqualified operator lookup there.
#define EIGEN_USING_STD_COMPLEX_OPERATORS            \
  using Eigen::complex_operator_detail::operator+;   \
  using Eigen::complex_operator_detail::operator+=;  \
  using Eigen::complex_operator_detail::operator-;   \
  using Eigen::complex_operator_detail::operator-=;  \
  using Eigen::complex_operator_detail::operator*;   \
  using Eigen::complex_operator_detail::operator*=;  \
  using Eigen::complex_operator_detail::operator/;   \
  using Eigen::complex_operator_detail::operator/=;  \
  using Eigen::complex_operator_detail::operator==;  \
  using Eigen::complex_operator_detail::operator!=;
45 
46 namespace Eigen {
47 
48 // Specialized std::complex overloads.
49 namespace complex_operator_detail {
50 
51 template<typename T>
52 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
complex_multiply(const std::complex<T> & a,const std::complex<T> & b)53 std::complex<T> complex_multiply(const std::complex<T>& a, const std::complex<T>& b) {
54   const T a_real = numext::real(a);
55   const T a_imag = numext::imag(a);
56   const T b_real = numext::real(b);
57   const T b_imag = numext::imag(b);
58   return std::complex<T>(
59       a_real * b_real - a_imag * b_imag,
60       a_imag * b_real + a_real * b_imag);
61 }
62 
63 template<typename T>
64 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
complex_divide_fast(const std::complex<T> & a,const std::complex<T> & b)65 std::complex<T> complex_divide_fast(const std::complex<T>& a, const std::complex<T>& b) {
66   const T a_real = numext::real(a);
67   const T a_imag = numext::imag(a);
68   const T b_real = numext::real(b);
69   const T b_imag = numext::imag(b);
70   const T norm = (b_real * b_real + b_imag * b_imag);
71   return std::complex<T>((a_real * b_real + a_imag * b_imag) / norm,
72                           (a_imag * b_real - a_real * b_imag) / norm);
73 }
74 
75 template<typename T>
76 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
complex_divide_stable(const std::complex<T> & a,const std::complex<T> & b)77 std::complex<T> complex_divide_stable(const std::complex<T>& a, const std::complex<T>& b) {
78   const T a_real = numext::real(a);
79   const T a_imag = numext::imag(a);
80   const T b_real = numext::real(b);
81   const T b_imag = numext::imag(b);
82   // Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
83   // guards against over/under-flow.
84   const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real);
85   const T rscale = scale_imag ? T(1) : b_real / b_imag;
86   const T iscale = scale_imag ? b_imag / b_real : T(1);
87   const T denominator = b_real * rscale + b_imag * iscale;
88   return std::complex<T>((a_real * rscale + a_imag * iscale) / denominator,
89                          (a_imag * rscale - a_real * iscale) / denominator);
90 }
91 
92 template<typename T>
93 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
complex_divide(const std::complex<T> & a,const std::complex<T> & b)94 std::complex<T> complex_divide(const std::complex<T>& a, const std::complex<T>& b) {
95 #if EIGEN_FAST_MATH
96   return complex_divide_fast(a, b);
97 #else
98   return complex_divide_stable(a, b);
99 #endif
100 }
101 
102 // NOTE: We cannot specialize compound assignment operators with Scalar T,
103 //         (i.e.  operator@=(const T&), for @=+,-,*,/)
104 //       since they are already specialized for float/double/long double within
105 //       the standard <complex> header. We also do not specialize the stream
106 //       operators.
107 #define EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(T)                                    \
108                                                                                                 \
109 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
110 std::complex<T> operator+(const std::complex<T>& a) { return a; }                               \
111                                                                                                 \
112 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
113 std::complex<T> operator-(const std::complex<T>& a) {                                           \
114   return std::complex<T>(-numext::real(a), -numext::imag(a));                                   \
115 }                                                                                               \
116                                                                                                 \
117 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
118 std::complex<T> operator+(const std::complex<T>& a, const std::complex<T>& b) {                 \
119   return std::complex<T>(numext::real(a) + numext::real(b), numext::imag(a) + numext::imag(b)); \
120 }                                                                                               \
121                                                                                                 \
122 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
123 std::complex<T> operator+(const std::complex<T>& a, const T& b) {                               \
124   return std::complex<T>(numext::real(a) + b, numext::imag(a));                                 \
125 }                                                                                               \
126                                                                                                 \
127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
128 std::complex<T> operator+(const T& a, const std::complex<T>& b) {                               \
129   return std::complex<T>(a + numext::real(b), numext::imag(b));                                 \
130 }                                                                                               \
131                                                                                                 \
132 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
133 std::complex<T> operator-(const std::complex<T>& a, const std::complex<T>& b) {                 \
134   return std::complex<T>(numext::real(a) - numext::real(b), numext::imag(a) - numext::imag(b)); \
135 }                                                                                               \
136                                                                                                 \
137 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
138 std::complex<T> operator-(const std::complex<T>& a, const T& b) {                               \
139   return std::complex<T>(numext::real(a) - b, numext::imag(a));                                 \
140 }                                                                                               \
141                                                                                                 \
142 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
143 std::complex<T> operator-(const T& a, const std::complex<T>& b) {                               \
144   return std::complex<T>(a - numext::real(b), -numext::imag(b));                                \
145 }                                                                                               \
146                                                                                                 \
147 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
148 std::complex<T> operator*(const std::complex<T>& a, const std::complex<T>& b) {                 \
149   return complex_multiply(a, b);                                                                \
150 }                                                                                               \
151                                                                                                 \
152 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
153 std::complex<T> operator*(const std::complex<T>& a, const T& b) {                               \
154   return std::complex<T>(numext::real(a) * b, numext::imag(a) * b);                             \
155 }                                                                                               \
156                                                                                                 \
157 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
158 std::complex<T> operator*(const T& a, const std::complex<T>& b) {                               \
159   return std::complex<T>(a * numext::real(b), a * numext::imag(b));                             \
160 }                                                                                               \
161                                                                                                 \
162 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
163 std::complex<T> operator/(const std::complex<T>& a, const std::complex<T>& b) {                 \
164   return complex_divide(a, b);                                                                  \
165 }                                                                                               \
166                                                                                                 \
167 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
168 std::complex<T> operator/(const std::complex<T>& a, const T& b) {                               \
169   return std::complex<T>(numext::real(a) / b, numext::imag(a) / b);                             \
170 }                                                                                               \
171                                                                                                 \
172 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
173 std::complex<T> operator/(const T& a, const std::complex<T>& b) {                               \
174   return complex_divide(std::complex<T>(a, 0), b);                                              \
175 }                                                                                               \
176                                                                                                 \
177 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
178 std::complex<T>& operator+=(std::complex<T>& a, const std::complex<T>& b) {                     \
179   numext::real_ref(a) += numext::real(b);                                                       \
180   numext::imag_ref(a) += numext::imag(b);                                                       \
181   return a;                                                                                     \
182 }                                                                                               \
183                                                                                                 \
184 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
185 std::complex<T>& operator-=(std::complex<T>& a, const std::complex<T>& b) {                     \
186   numext::real_ref(a) -= numext::real(b);                                                       \
187   numext::imag_ref(a) -= numext::imag(b);                                                       \
188   return a;                                                                                     \
189 }                                                                                               \
190                                                                                                 \
191 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
192 std::complex<T>& operator*=(std::complex<T>& a, const std::complex<T>& b) {                     \
193   a = complex_multiply(a, b);                                                                   \
194   return a;                                                                                     \
195 }                                                                                               \
196                                                                                                 \
197 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
198 std::complex<T>& operator/=(std::complex<T>& a, const std::complex<T>& b) {                     \
199   a = complex_divide(a, b);                                                                     \
200   return  a;                                                                                    \
201 }                                                                                               \
202                                                                                                 \
203 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
204 bool operator==(const std::complex<T>& a, const std::complex<T>& b) {                           \
205   return numext::real(a) == numext::real(b) && numext::imag(a) == numext::imag(b);              \
206 }                                                                                               \
207                                                                                                 \
208 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
209 bool operator==(const std::complex<T>& a, const T& b) {                                         \
210   return numext::real(a) == b && numext::imag(a) == 0;                                          \
211 }                                                                                               \
212                                                                                                 \
213 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
214 bool operator==(const T& a, const std::complex<T>& b) {                                         \
215   return a  == numext::real(b) && 0 == numext::imag(b);                                         \
216 }                                                                                               \
217                                                                                                 \
218 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
219 bool operator!=(const std::complex<T>& a, const std::complex<T>& b) {                           \
220   return !(a == b);                                                                             \
221 }                                                                                               \
222                                                                                                 \
223 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
224 bool operator!=(const std::complex<T>& a, const T& b) {                                         \
225   return !(a == b);                                                                             \
226 }                                                                                               \
227                                                                                                 \
228 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE                                                           \
229 bool operator!=(const T& a, const std::complex<T>& b) {                                         \
230   return !(a == b);                                                                             \
231 }
232 
233 // Do not specialize for long double, since that reduces to double on device.
234 EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(float)
235 EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(double)
236 
237 #undef EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS
238 
239 
240 }  // namespace complex_operator_detail
241 
242 EIGEN_USING_STD_COMPLEX_OPERATORS
243 
244 namespace numext {
245 EIGEN_USING_STD_COMPLEX_OPERATORS
246 }  // namespace numext
247 
248 namespace internal {
249 EIGEN_USING_STD_COMPLEX_OPERATORS
250 
251 }  // namespace internal
252 }  // namespace Eigen
253 
254 #endif  // !(EIGEN_COMP_ICC && _USE_COMPLEX_SPECIALIZATION_)
255 
256 #endif  // EIGEN_CUDACC && EIGEN_GPU_COMPILE_PHASE
257 
258 #endif  // EIGEN_COMPLEX_CUDA_H
259