xref: /aosp_15_r20/external/pytorch/c10/util/complex.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <complex>
4 
5 #include <c10/macros/Macros.h>
6 
7 #if defined(__CUDACC__) || defined(__HIPCC__)
8 #include <thrust/complex.h>
9 #endif
10 
11 C10_CLANG_DIAGNOSTIC_PUSH()
12 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
13 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
14 #endif
15 #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
16 C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
17 #endif
18 
19 namespace c10 {
20 
21 // c10::complex is an implementation of complex numbers that aims
22 // to work on all devices supported by PyTorch
23 //
// Most of the APIs duplicate std::complex
25 // Reference: https://en.cppreference.com/w/cpp/numeric/complex
26 //
27 // [NOTE: Complex Operator Unification]
28 // Operators currently use a mix of std::complex, thrust::complex, and
29 // c10::complex internally. The end state is that all operators will use
30 // c10::complex internally.  Until then, there may be some hacks to support all
31 // variants.
32 //
33 //
34 // [Note on Constructors]
35 //
36 // The APIs of constructors are mostly copied from C++ standard:
37 //   https://en.cppreference.com/w/cpp/numeric/complex/complex
38 //
39 // Since C++14, all constructors are constexpr in std::complex
40 //
41 // There are three types of constructors:
42 // - initializing from real and imag:
43 //     `constexpr complex( const T& re = T(), const T& im = T() );`
44 // - implicitly-declared copy constructor
45 // - converting constructors
46 //
47 // Converting constructors:
// - std::complex defines converting constructors between float/double/long
//   double, while we define converting constructors between float/double.
51 // - For these converting constructors, upcasting is implicit, downcasting is
52 //   explicit.
53 // - We also define explicit casting from std::complex/thrust::complex
//   - Note that the conversion from thrust is not constexpr, because
//     thrust does not appear to define those constructors as constexpr.
56 //
57 //
58 // [Operator =]
59 //
60 // The APIs of operator = are mostly copied from C++ standard:
61 //   https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
62 //
63 // Since C++20, all operator= are constexpr. Although we are not building with
64 // C++20, we also obey this behavior.
65 //
66 // There are three types of assign operator:
67 // - Assign a real value from the same scalar type
//   - In std, this is templated as `complex& operator=(const T& x)`,
//     with the specialization `complex& operator=(T x)` for float/double/long
//     double. Since we only support float and double, we use
//     `complex& operator=(T x)`.
72 // - Copy assignment operator and converting assignment operator
//   - There is no specialization of converting assignment operators; which
//     types are convertible depends solely on whether the scalar type is
//     convertible.
76 //
77 // In addition to the standard assignment, we also provide assignment operators
78 // with std and thrust
79 //
80 //
81 // [Casting operators]
82 //
83 // std::complex does not have casting operators. We define casting operators
84 // casting to std::complex and thrust::complex
85 //
86 //
87 // [Operator ""]
88 //
89 // std::complex has custom literals `i`, `if` and `il` defined in namespace
90 // `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex; instead, we define _if and _id to
// construct float/double complex literals.
94 //
95 //
96 // [real() and imag()]
97 //
// In C++20, there are two overloads of each of these functions: one returns
// the real/imag part and the other sets it; both are constexpr. We follow
// this design.
101 //
102 //
103 // [Operator +=,-=,*=,/=]
104 //
105 // Since C++20, these operators become constexpr. In our implementation, they
106 // are also constexpr.
107 //
108 // There are two types of such operators: operating with a real number, or
109 // operating with another complex number. For the operating with a real number,
110 // the generic template form has argument type `const T &`, while the overload
111 // for float/double/long double has `T`. We will follow the same type as
112 // float/double/long double in std.
113 //
114 // [Unary operator +-]
115 //
// Since C++20, they are constexpr. We also make them constexpr.
117 //
118 // [Binary operators +-*/]
119 //
120 // Each operator has three versions (taking + as example):
121 // - complex + complex
122 // - complex + real
123 // - real + complex
124 //
125 // [Operator ==, !=]
126 //
127 // Each operator has three versions (taking == as example):
128 // - complex == complex
129 // - complex == real
130 // - real == complex
131 //
// Some of them are removed in C++20, but we decided to keep them.
133 //
134 // [Operator <<, >>]
135 //
136 // These are implemented by casting to std::complex
137 //
138 //
139 //
140 // TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
141 // because:
142 //  - lots of members and functions of c10::Half are not constexpr
//  - thrust::complex only supports float and double
144 
145 template <typename T>
146 struct alignas(sizeof(T) * 2) complex {
147   using value_type = T;
148 
149   T real_ = T(0);
150   T imag_ = T(0);
151 
152   constexpr complex() = default;
153   C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
real_complex154       : real_(re), imag_(im) {}
155   template <typename U>
complexcomplex156   explicit constexpr complex(const std::complex<U>& other)
157       : complex(other.real(), other.imag()) {}
158 #if defined(__CUDACC__) || defined(__HIPCC__)
159   template <typename U>
complexcomplex160   explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
161       : real_(other.real()), imag_(other.imag()) {}
162 // NOTE can not be implemented as follow due to ROCm bug:
163 //   explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
164 //   complex(other.real(), other.imag()) {}
165 #endif
166 
167   // Use SFINAE to specialize casting constructor for c10::complex<float> and
168   // c10::complex<double>
169   template <typename U = T>
complexcomplex170   C10_HOST_DEVICE explicit constexpr complex(
171       const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
172       : real_(other.real_), imag_(other.imag_) {}
173   template <typename U = T>
complexcomplex174   C10_HOST_DEVICE constexpr complex(
175       const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
176       : real_(other.real_), imag_(other.imag_) {}
177 
178   constexpr complex<T>& operator=(T re) {
179     real_ = re;
180     imag_ = 0;
181     return *this;
182   }
183 
184   constexpr complex<T>& operator+=(T re) {
185     real_ += re;
186     return *this;
187   }
188 
189   constexpr complex<T>& operator-=(T re) {
190     real_ -= re;
191     return *this;
192   }
193 
194   constexpr complex<T>& operator*=(T re) {
195     real_ *= re;
196     imag_ *= re;
197     return *this;
198   }
199 
200   constexpr complex<T>& operator/=(T re) {
201     real_ /= re;
202     imag_ /= re;
203     return *this;
204   }
205 
206   template <typename U>
207   constexpr complex<T>& operator=(const complex<U>& rhs) {
208     real_ = rhs.real();
209     imag_ = rhs.imag();
210     return *this;
211   }
212 
213   template <typename U>
214   constexpr complex<T>& operator+=(const complex<U>& rhs) {
215     real_ += rhs.real();
216     imag_ += rhs.imag();
217     return *this;
218   }
219 
220   template <typename U>
221   constexpr complex<T>& operator-=(const complex<U>& rhs) {
222     real_ -= rhs.real();
223     imag_ -= rhs.imag();
224     return *this;
225   }
226 
227   template <typename U>
228   constexpr complex<T>& operator*=(const complex<U>& rhs) {
229     // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
230     T a = real_;
231     T b = imag_;
232     U c = rhs.real();
233     U d = rhs.imag();
234     real_ = a * c - b * d;
235     imag_ = a * d + b * c;
236     return *this;
237   }
238 
239 #ifdef __APPLE__
240 #define FORCE_INLINE_APPLE __attribute__((always_inline))
241 #else
242 #define FORCE_INLINE_APPLE
243 #endif
244   template <typename U>
245   constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
246       __ubsan_ignore_float_divide_by_zero__ {
247     // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
248     // the calculation below follows numpy's complex division
249     T a = real_;
250     T b = imag_;
251     U c = rhs.real();
252     U d = rhs.imag();
253 
254 #if defined(__GNUC__) && !defined(__clang__)
255     // std::abs is already constexpr by gcc
256     auto abs_c = std::abs(c);
257     auto abs_d = std::abs(d);
258 #else
259     auto abs_c = c < 0 ? -c : c;
260     auto abs_d = d < 0 ? -d : d;
261 #endif
262 
263     if (abs_c >= abs_d) {
264       if (abs_c == U(0) && abs_d == U(0)) {
265         /* divide by zeros should yield a complex inf or nan */
266         real_ = a / abs_c;
267         imag_ = b / abs_d;
268       } else {
269         auto rat = d / c;
270         auto scl = U(1.0) / (c + d * rat);
271         real_ = (a + b * rat) * scl;
272         imag_ = (b - a * rat) * scl;
273       }
274     } else {
275       auto rat = c / d;
276       auto scl = U(1.0) / (d + c * rat);
277       real_ = (a * rat + b) * scl;
278       imag_ = (b * rat - a) * scl;
279     }
280     return *this;
281   }
282 #undef FORCE_INLINE_APPLE
283 
284   template <typename U>
285   constexpr complex<T>& operator=(const std::complex<U>& rhs) {
286     real_ = rhs.real();
287     imag_ = rhs.imag();
288     return *this;
289   }
290 
291 #if defined(__CUDACC__) || defined(__HIPCC__)
292   template <typename U>
293   C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
294     real_ = rhs.real();
295     imag_ = rhs.imag();
296     return *this;
297   }
298 #endif
299 
300   template <typename U>
complexcomplex301   explicit constexpr operator std::complex<U>() const {
302     return std::complex<U>(std::complex<T>(real(), imag()));
303   }
304 
305 #if defined(__CUDACC__) || defined(__HIPCC__)
306   template <typename U>
complexcomplex307   C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
308     return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
309   }
310 #endif
311 
312   // consistent with NumPy behavior
313   explicit constexpr operator bool() const {
314     return real() || imag();
315   }
316 
realcomplex317   C10_HOST_DEVICE constexpr T real() const {
318     return real_;
319   }
realcomplex320   constexpr void real(T value) {
321     real_ = value;
322   }
imagcomplex323   C10_HOST_DEVICE constexpr T imag() const {
324     return imag_;
325   }
imagcomplex326   constexpr void imag(T value) {
327     imag_ = value;
328   }
329 };
330 
331 namespace complex_literals {
332 
333 constexpr complex<float> operator""_if(long double imag) {
334   return complex<float>(0.0f, static_cast<float>(imag));
335 }
336 
337 constexpr complex<double> operator""_id(long double imag) {
338   return complex<double>(0.0, static_cast<double>(imag));
339 }
340 
341 constexpr complex<float> operator""_if(unsigned long long imag) {
342   return complex<float>(0.0f, static_cast<float>(imag));
343 }
344 
345 constexpr complex<double> operator""_id(unsigned long long imag) {
346   return complex<double>(0.0, static_cast<double>(imag));
347 }
348 
349 } // namespace complex_literals
350 
351 template <typename T>
352 constexpr complex<T> operator+(const complex<T>& val) {
353   return val;
354 }
355 
356 template <typename T>
357 constexpr complex<T> operator-(const complex<T>& val) {
358   return complex<T>(-val.real(), -val.imag());
359 }
360 
361 template <typename T>
362 constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
363   complex<T> result = lhs;
364   return result += rhs;
365 }
366 
367 template <typename T>
368 constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
369   complex<T> result = lhs;
370   return result += rhs;
371 }
372 
373 template <typename T>
374 constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
375   return complex<T>(lhs + rhs.real(), rhs.imag());
376 }
377 
378 template <typename T>
379 constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
380   complex<T> result = lhs;
381   return result -= rhs;
382 }
383 
384 template <typename T>
385 constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
386   complex<T> result = lhs;
387   return result -= rhs;
388 }
389 
390 template <typename T>
391 constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
392   complex<T> result = -rhs;
393   return result += lhs;
394 }
395 
396 template <typename T>
397 constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
398   complex<T> result = lhs;
399   return result *= rhs;
400 }
401 
402 template <typename T>
403 constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
404   complex<T> result = lhs;
405   return result *= rhs;
406 }
407 
408 template <typename T>
409 constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
410   complex<T> result = rhs;
411   return result *= lhs;
412 }
413 
414 template <typename T>
415 constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
416   complex<T> result = lhs;
417   return result /= rhs;
418 }
419 
420 template <typename T>
421 constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
422   complex<T> result = lhs;
423   return result /= rhs;
424 }
425 
426 template <typename T>
427 constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
428   complex<T> result(lhs, T());
429   return result /= rhs;
430 }
431 
432 // Define operators between integral scalars and c10::complex. std::complex does
433 // not support this when T is a floating-point number. This is useful because it
434 // saves a lot of "static_cast" when operate a complex and an integer. This
435 // makes the code both less verbose and potentially more efficient.
436 #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION                 \
437   typename std::enable_if_t<                                  \
438       std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
439       int> = 0
440 
441 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
442 constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
443   return a + static_cast<fT>(b);
444 }
445 
446 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
447 constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
448   return static_cast<fT>(a) + b;
449 }
450 
451 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
452 constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
453   return a - static_cast<fT>(b);
454 }
455 
456 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
457 constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
458   return static_cast<fT>(a) - b;
459 }
460 
461 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
462 constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
463   return a * static_cast<fT>(b);
464 }
465 
466 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
467 constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
468   return static_cast<fT>(a) * b;
469 }
470 
471 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
472 constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
473   return a / static_cast<fT>(b);
474 }
475 
476 template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
477 constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
478   return static_cast<fT>(a) / b;
479 }
480 
481 #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
482 
483 template <typename T>
484 constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
485   return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
486 }
487 
488 template <typename T>
489 constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
490   return (lhs.real() == rhs) && (lhs.imag() == T());
491 }
492 
493 template <typename T>
494 constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
495   return (lhs == rhs.real()) && (T() == rhs.imag());
496 }
497 
498 template <typename T>
499 constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
500   return !(lhs == rhs);
501 }
502 
503 template <typename T>
504 constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
505   return !(lhs == rhs);
506 }
507 
508 template <typename T>
509 constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
510   return !(lhs == rhs);
511 }
512 
513 template <typename T, typename CharT, typename Traits>
514 std::basic_ostream<CharT, Traits>& operator<<(
515     std::basic_ostream<CharT, Traits>& os,
516     const complex<T>& x) {
517   return (os << static_cast<std::complex<T>>(x));
518 }
519 
520 template <typename T, typename CharT, typename Traits>
521 std::basic_istream<CharT, Traits>& operator>>(
522     std::basic_istream<CharT, Traits>& is,
523     complex<T>& x) {
524   std::complex<T> tmp;
525   is >> tmp;
526   x = tmp;
527   return is;
528 }
529 
530 } // namespace c10
531 
532 // std functions
533 //
534 // The implementation of these functions also follow the design of C++20
535 
536 namespace std {
537 
538 template <typename T>
real(const c10::complex<T> & z)539 constexpr T real(const c10::complex<T>& z) {
540   return z.real();
541 }
542 
543 template <typename T>
imag(const c10::complex<T> & z)544 constexpr T imag(const c10::complex<T>& z) {
545   return z.imag();
546 }
547 
548 template <typename T>
abs(const c10::complex<T> & z)549 C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
550 #if defined(__CUDACC__) || defined(__HIPCC__)
551   return thrust::abs(static_cast<thrust::complex<T>>(z));
552 #else
553   return std::abs(static_cast<std::complex<T>>(z));
554 #endif
555 }
556 
557 #if defined(USE_ROCM)
558 #define ROCm_Bug(x)
559 #else
560 #define ROCm_Bug(x) x
561 #endif
562 
563 template <typename T>
arg(const c10::complex<T> & z)564 C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
565   return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
566 }
567 
568 #undef ROCm_Bug
569 
570 template <typename T>
norm(const c10::complex<T> & z)571 constexpr T norm(const c10::complex<T>& z) {
572   return z.real() * z.real() + z.imag() * z.imag();
573 }
574 
575 // For std::conj, there are other versions of it:
576 //   constexpr std::complex<float> conj( float z );
577 //   template< class DoubleOrInteger >
578 //   constexpr std::complex<double> conj( DoubleOrInteger z );
579 //   constexpr std::complex<long double> conj( long double z );
580 // These are not implemented
581 // TODO(@zasdfgbnm): implement them as c10::conj
582 template <typename T>
conj(const c10::complex<T> & z)583 constexpr c10::complex<T> conj(const c10::complex<T>& z) {
584   return c10::complex<T>(z.real(), -z.imag());
585 }
586 
587 // Thrust does not have complex --> complex version of thrust::proj,
588 // so this function is not implemented at c10 right now.
589 // TODO(@zasdfgbnm): implement it by ourselves
590 
// There is no c10 version of std::polar, because std::polar always
// returns std::complex. Use c10::polar instead.
593 
594 } // namespace std
595 
596 namespace c10 {
597 
598 template <typename T>
599 C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
600 #if defined(__CUDACC__) || defined(__HIPCC__)
601   return static_cast<complex<T>>(thrust::polar(r, theta));
602 #else
603   // std::polar() requires r >= 0, so spell out the explicit implementation to
604   // avoid a branch.
605   return complex<T>(r * std::cos(theta), r * std::sin(theta));
606 #endif
607 }
608 
609 } // namespace c10
610 
611 C10_CLANG_DIAGNOSTIC_POP()
612 
613 #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
614 // math functions are included in a separate file
615 #include <c10/util/complex_math.h> // IWYU pragma: keep
616 // utilities for complex types
617 #include <c10/util/complex_utils.h> // IWYU pragma: keep
618 #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
619