xref: /aosp_15_r20/external/pytorch/c10/test/util/complex_test_common.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/macros/Macros.h>
2 #include <c10/util/complex.h>
3 #include <c10/util/hash.h>
4 #include <gtest/gtest.h>
5 #include <sstream>
6 #include <tuple>
7 #include <type_traits>
8 #include <unordered_map>
9 
// When this header is compiled by nvcc or hipcc (__CUDACC__ / __HIPCC__
// defined), MAYBE_GLOBAL expands to __global__ so the functions it marks are
// compiled as device kernels; for a plain host compiler it expands to nothing.
#if defined(__CUDACC__) || defined(__HIPCC__)
#define MAYBE_GLOBAL __global__
#else
#define MAYBE_GLOBAL
#endif

// Pi, used by the arg/polar checks in this file.
#define PI 3.141592653589793238463
17 
18 namespace memory {
19 
test_size()20 MAYBE_GLOBAL void test_size() {
21   static_assert(sizeof(c10::complex<float>) == 2 * sizeof(float), "");
22   static_assert(sizeof(c10::complex<double>) == 2 * sizeof(double), "");
23 }
24 
test_align()25 MAYBE_GLOBAL void test_align() {
26   static_assert(alignof(c10::complex<float>) == 2 * sizeof(float), "");
27   static_assert(alignof(c10::complex<double>) == 2 * sizeof(double), "");
28 }
29 
test_pod()30 MAYBE_GLOBAL void test_pod() {
31   static_assert(std::is_standard_layout<c10::complex<float>>::value, "");
32   static_assert(std::is_standard_layout<c10::complex<double>>::value, "");
33 }
34 
TEST(TestMemory,ReinterpretCast)35 TEST(TestMemory, ReinterpretCast) {
36   {
37     std::complex<float> z(1, 2);
38     c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
39     ASSERT_EQ(zz.real(), float(1));
40     ASSERT_EQ(zz.imag(), float(2));
41   }
42 
43   {
44     c10::complex<float> z(3, 4);
45     std::complex<float> zz = *reinterpret_cast<std::complex<float>*>(&z);
46     ASSERT_EQ(zz.real(), float(3));
47     ASSERT_EQ(zz.imag(), float(4));
48   }
49 
50   {
51     std::complex<double> z(1, 2);
52     c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
53     ASSERT_EQ(zz.real(), double(1));
54     ASSERT_EQ(zz.imag(), double(2));
55   }
56 
57   {
58     c10::complex<double> z(3, 4);
59     std::complex<double> zz = *reinterpret_cast<std::complex<double>*>(&z);
60     ASSERT_EQ(zz.real(), double(3));
61     ASSERT_EQ(zz.imag(), double(4));
62   }
63 }
64 
65 #if defined(__CUDACC__) || defined(__HIPCC__)
TEST(TestMemory,ThrustReinterpretCast)66 TEST(TestMemory, ThrustReinterpretCast) {
67   {
68     thrust::complex<float> z(1, 2);
69     c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
70     ASSERT_EQ(zz.real(), float(1));
71     ASSERT_EQ(zz.imag(), float(2));
72   }
73 
74   {
75     c10::complex<float> z(3, 4);
76     thrust::complex<float> zz = *reinterpret_cast<thrust::complex<float>*>(&z);
77     ASSERT_EQ(zz.real(), float(3));
78     ASSERT_EQ(zz.imag(), float(4));
79   }
80 
81   {
82     thrust::complex<double> z(1, 2);
83     c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
84     ASSERT_EQ(zz.real(), double(1));
85     ASSERT_EQ(zz.imag(), double(2));
86   }
87 
88   {
89     c10::complex<double> z(3, 4);
90     thrust::complex<double> zz =
91         *reinterpret_cast<thrust::complex<double>*>(&z);
92     ASSERT_EQ(zz.real(), double(3));
93     ASSERT_EQ(zz.imag(), double(4));
94   }
95 }
96 #endif
97 
98 } // namespace memory
99 
100 namespace constructors {
101 
102 template <typename scalar_t>
test_construct_from_scalar()103 C10_HOST_DEVICE void test_construct_from_scalar() {
104   constexpr scalar_t num1 = scalar_t(1.23);
105   constexpr scalar_t num2 = scalar_t(4.56);
106   constexpr scalar_t zero = scalar_t();
107   static_assert(c10::complex<scalar_t>(num1, num2).real() == num1, "");
108   static_assert(c10::complex<scalar_t>(num1, num2).imag() == num2, "");
109   static_assert(c10::complex<scalar_t>(num1).real() == num1, "");
110   static_assert(c10::complex<scalar_t>(num1).imag() == zero, "");
111   static_assert(c10::complex<scalar_t>().real() == zero, "");
112   static_assert(c10::complex<scalar_t>().imag() == zero, "");
113 }
114 
115 template <typename scalar_t, typename other_t>
test_construct_from_other()116 C10_HOST_DEVICE void test_construct_from_other() {
117   constexpr other_t num1 = other_t(1.23);
118   constexpr other_t num2 = other_t(4.56);
119   constexpr scalar_t num3 = scalar_t(num1);
120   constexpr scalar_t num4 = scalar_t(num2);
121   static_assert(
122       c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3,
123       "");
124   static_assert(
125       c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4,
126       "");
127 }
128 
test_convert_constructors()129 MAYBE_GLOBAL void test_convert_constructors() {
130   test_construct_from_scalar<float>();
131   test_construct_from_scalar<double>();
132 
133   static_assert(
134       std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "");
135   static_assert(
136       !std::is_convertible<c10::complex<double>, c10::complex<float>>::value,
137       "");
138   static_assert(
139       std::is_convertible<c10::complex<float>, c10::complex<double>>::value,
140       "");
141   static_assert(
142       std::is_convertible<c10::complex<double>, c10::complex<double>>::value,
143       "");
144 
145   static_assert(
146       std::is_constructible<c10::complex<float>, c10::complex<float>>::value,
147       "");
148   static_assert(
149       std::is_constructible<c10::complex<double>, c10::complex<float>>::value,
150       "");
151   static_assert(
152       std::is_constructible<c10::complex<float>, c10::complex<double>>::value,
153       "");
154   static_assert(
155       std::is_constructible<c10::complex<double>, c10::complex<double>>::value,
156       "");
157 
158   test_construct_from_other<float, float>();
159   test_construct_from_other<float, double>();
160   test_construct_from_other<double, float>();
161   test_construct_from_other<double, double>();
162 }
163 
164 template <typename scalar_t>
test_construct_from_std()165 C10_HOST_DEVICE void test_construct_from_std() {
166   constexpr scalar_t num1 = scalar_t(1.23);
167   constexpr scalar_t num2 = scalar_t(4.56);
168   static_assert(
169       c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1,
170       "");
171   static_assert(
172       c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2,
173       "");
174 }
175 
test_std_conversion()176 MAYBE_GLOBAL void test_std_conversion() {
177   test_construct_from_std<float>();
178   test_construct_from_std<double>();
179 }
180 
181 #if defined(__CUDACC__) || defined(__HIPCC__)
182 template <typename scalar_t>
test_construct_from_thrust()183 void test_construct_from_thrust() {
184   constexpr scalar_t num1 = scalar_t(1.23);
185   constexpr scalar_t num2 = scalar_t(4.56);
186   ASSERT_EQ(
187       c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(),
188       num1);
189   ASSERT_EQ(
190       c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(),
191       num2);
192 }
193 
TEST(TestConstructors,FromThrust)194 TEST(TestConstructors, FromThrust) {
195   test_construct_from_thrust<float>();
196   test_construct_from_thrust<double>();
197 }
198 #endif
199 
TEST(TestConstructors,UnorderedMap)200 TEST(TestConstructors, UnorderedMap) {
201   std::unordered_map<
202       c10::complex<double>,
203       c10::complex<double>,
204       c10::hash<c10::complex<double>>>
205       m;
206   auto key1 = c10::complex<double>(2.5, 3);
207   auto key2 = c10::complex<double>(2, 0);
208   auto val1 = c10::complex<double>(2, -3.2);
209   auto val2 = c10::complex<double>(0, -3);
210   m[key1] = val1;
211   m[key2] = val2;
212   ASSERT_EQ(m[key1], val1);
213   ASSERT_EQ(m[key2], val2);
214 }
215 
216 } // namespace constructors
217 
218 namespace assignment {
219 
220 template <typename scalar_t>
one()221 constexpr c10::complex<scalar_t> one() {
222   c10::complex<scalar_t> result(3, 4);
223   result = scalar_t(1);
224   return result;
225 }
226 
test_assign_real()227 MAYBE_GLOBAL void test_assign_real() {
228   static_assert(one<float>().real() == float(1), "");
229   static_assert(one<float>().imag() == float(), "");
230   static_assert(one<double>().real() == double(1), "");
231   static_assert(one<double>().imag() == double(), "");
232 }
233 
one_two()234 constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two() {
235   constexpr c10::complex<float> src(1, 2);
236   c10::complex<double> ret0;
237   c10::complex<float> ret1;
238   ret0 = ret1 = src;
239   return std::make_tuple(ret0, ret1);
240 }
241 
test_assign_other()242 MAYBE_GLOBAL void test_assign_other() {
243   constexpr auto tup = one_two();
244   static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "");
245   static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "");
246   static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "");
247   static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "");
248 }
249 
one_two_std()250 constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two_std() {
251   constexpr std::complex<float> src(1, 1);
252   c10::complex<double> ret0;
253   c10::complex<float> ret1;
254   ret0 = ret1 = src;
255   return std::make_tuple(ret0, ret1);
256 }
257 
test_assign_std()258 MAYBE_GLOBAL void test_assign_std() {
259   constexpr auto tup = one_two();
260   static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "");
261   static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "");
262   static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "");
263   static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "");
264 }
265 
266 #if defined(__CUDACC__) || defined(__HIPCC__)
267 C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>>
one_two_thrust()268 one_two_thrust() {
269   thrust::complex<float> src(1, 2);
270   c10::complex<double> ret0;
271   c10::complex<float> ret1;
272   ret0 = ret1 = src;
273   return std::make_tuple(ret0, ret1);
274 }
275 
TEST(TestAssignment,FromThrust)276 TEST(TestAssignment, FromThrust) {
277   auto tup = one_two_thrust();
278   ASSERT_EQ(std::get<c10::complex<double>>(tup).real(), double(1));
279   ASSERT_EQ(std::get<c10::complex<double>>(tup).imag(), double(2));
280   ASSERT_EQ(std::get<c10::complex<float>>(tup).real(), float(1));
281   ASSERT_EQ(std::get<c10::complex<float>>(tup).imag(), float(2));
282 }
283 #endif
284 
285 } // namespace assignment
286 
287 namespace literals {
288 
test_complex_literals()289 MAYBE_GLOBAL void test_complex_literals() {
290   using namespace c10::complex_literals;
291   static_assert(std::is_same<decltype(0.5_if), c10::complex<float>>::value, "");
292   static_assert((0.5_if).real() == float(), "");
293   static_assert((0.5_if).imag() == float(0.5), "");
294   static_assert(
295       std::is_same<decltype(0.5_id), c10::complex<double>>::value, "");
296   static_assert((0.5_id).real() == float(), "");
297   static_assert((0.5_id).imag() == float(0.5), "");
298 
299   static_assert(std::is_same<decltype(1_if), c10::complex<float>>::value, "");
300   static_assert((1_if).real() == float(), "");
301   static_assert((1_if).imag() == float(1), "");
302   static_assert(std::is_same<decltype(1_id), c10::complex<double>>::value, "");
303   static_assert((1_id).real() == double(), "");
304   static_assert((1_id).imag() == double(1), "");
305 }
306 
307 } // namespace literals
308 
309 namespace real_imag {
310 
311 template <typename scalar_t>
zero_one()312 constexpr c10::complex<scalar_t> zero_one() {
313   c10::complex<scalar_t> result;
314   result.imag(scalar_t(1));
315   return result;
316 }
317 
318 template <typename scalar_t>
one_zero()319 constexpr c10::complex<scalar_t> one_zero() {
320   c10::complex<scalar_t> result;
321   result.real(scalar_t(1));
322   return result;
323 }
324 
test_real_imag_modify()325 MAYBE_GLOBAL void test_real_imag_modify() {
326   static_assert(zero_one<float>().real() == float(0), "");
327   static_assert(zero_one<float>().imag() == float(1), "");
328   static_assert(zero_one<double>().real() == double(0), "");
329   static_assert(zero_one<double>().imag() == double(1), "");
330 
331   static_assert(one_zero<float>().real() == float(1), "");
332   static_assert(one_zero<float>().imag() == float(0), "");
333   static_assert(one_zero<double>().real() == double(1), "");
334   static_assert(one_zero<double>().imag() == double(0), "");
335 }
336 
337 } // namespace real_imag
338 
339 namespace arithmetic_assign {
340 
341 template <typename scalar_t>
p(scalar_t value)342 constexpr c10::complex<scalar_t> p(scalar_t value) {
343   c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
344   result += value;
345   return result;
346 }
347 
348 template <typename scalar_t>
m(scalar_t value)349 constexpr c10::complex<scalar_t> m(scalar_t value) {
350   c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
351   result -= value;
352   return result;
353 }
354 
355 template <typename scalar_t>
t(scalar_t value)356 constexpr c10::complex<scalar_t> t(scalar_t value) {
357   c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
358   result *= value;
359   return result;
360 }
361 
362 template <typename scalar_t>
d(scalar_t value)363 constexpr c10::complex<scalar_t> d(scalar_t value) {
364   c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
365   result /= value;
366   return result;
367 }
368 
369 template <typename scalar_t>
test_arithmetic_assign_scalar()370 C10_HOST_DEVICE void test_arithmetic_assign_scalar() {
371   constexpr c10::complex<scalar_t> x = p(scalar_t(1));
372   static_assert(x.real() == scalar_t(3), "");
373   static_assert(x.imag() == scalar_t(2), "");
374   constexpr c10::complex<scalar_t> y = m(scalar_t(1));
375   static_assert(y.real() == scalar_t(1), "");
376   static_assert(y.imag() == scalar_t(2), "");
377   constexpr c10::complex<scalar_t> z = t(scalar_t(2));
378   static_assert(z.real() == scalar_t(4), "");
379   static_assert(z.imag() == scalar_t(4), "");
380   constexpr c10::complex<scalar_t> t = d(scalar_t(2));
381   static_assert(t.real() == scalar_t(1), "");
382   static_assert(t.imag() == scalar_t(1), "");
383 }
384 
385 template <typename scalar_t, typename rhs_t>
p(scalar_t real,scalar_t imag,c10::complex<rhs_t> rhs)386 constexpr c10::complex<scalar_t> p(
387     scalar_t real,
388     scalar_t imag,
389     c10::complex<rhs_t> rhs) {
390   c10::complex<scalar_t> result(real, imag);
391   result += rhs;
392   return result;
393 }
394 
395 template <typename scalar_t, typename rhs_t>
m(scalar_t real,scalar_t imag,c10::complex<rhs_t> rhs)396 constexpr c10::complex<scalar_t> m(
397     scalar_t real,
398     scalar_t imag,
399     c10::complex<rhs_t> rhs) {
400   c10::complex<scalar_t> result(real, imag);
401   result -= rhs;
402   return result;
403 }
404 
405 template <typename scalar_t, typename rhs_t>
t(scalar_t real,scalar_t imag,c10::complex<rhs_t> rhs)406 constexpr c10::complex<scalar_t> t(
407     scalar_t real,
408     scalar_t imag,
409     c10::complex<rhs_t> rhs) {
410   c10::complex<scalar_t> result(real, imag);
411   result *= rhs;
412   return result;
413 }
414 
415 template <typename scalar_t, typename rhs_t>
d(scalar_t real,scalar_t imag,c10::complex<rhs_t> rhs)416 constexpr c10::complex<scalar_t> d(
417     scalar_t real,
418     scalar_t imag,
419     c10::complex<rhs_t> rhs) {
420   c10::complex<scalar_t> result(real, imag);
421   result /= rhs;
422   return result;
423 }
424 
425 template <typename scalar_t>
test_arithmetic_assign_complex()426 C10_HOST_DEVICE void test_arithmetic_assign_complex() {
427   using namespace c10::complex_literals;
428   constexpr c10::complex<scalar_t> x2 = p(scalar_t(2), scalar_t(2), 1.0_if);
429   static_assert(x2.real() == scalar_t(2), "");
430   static_assert(x2.imag() == scalar_t(3), "");
431   constexpr c10::complex<scalar_t> x3 = p(scalar_t(2), scalar_t(2), 1.0_id);
432   static_assert(x3.real() == scalar_t(2), "");
433 
434   // this test is skipped due to a bug in constexpr evaluation
435   // in nvcc. This bug has already been fixed since CUDA 11.2
436 #if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020)
437   static_assert(x3.imag() == scalar_t(3), "");
438 #endif
439 
440   constexpr c10::complex<scalar_t> y2 = m(scalar_t(2), scalar_t(2), 1.0_if);
441   static_assert(y2.real() == scalar_t(2), "");
442   static_assert(y2.imag() == scalar_t(1), "");
443   constexpr c10::complex<scalar_t> y3 = m(scalar_t(2), scalar_t(2), 1.0_id);
444   static_assert(y3.real() == scalar_t(2), "");
445 
446   // this test is skipped due to a bug in constexpr evaluation
447   // in nvcc. This bug has already been fixed since CUDA 11.2
448 #if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020)
449   static_assert(y3.imag() == scalar_t(1), "");
450 #endif
451 
452   constexpr c10::complex<scalar_t> z2 = t(scalar_t(1), scalar_t(-2), 1.0_if);
453   static_assert(z2.real() == scalar_t(2), "");
454   static_assert(z2.imag() == scalar_t(1), "");
455   constexpr c10::complex<scalar_t> z3 = t(scalar_t(1), scalar_t(-2), 1.0_id);
456   static_assert(z3.real() == scalar_t(2), "");
457   static_assert(z3.imag() == scalar_t(1), "");
458 
459   constexpr c10::complex<scalar_t> t2 = d(scalar_t(-1), scalar_t(2), 1.0_if);
460   static_assert(t2.real() == scalar_t(2), "");
461   static_assert(t2.imag() == scalar_t(1), "");
462   constexpr c10::complex<scalar_t> t3 = d(scalar_t(-1), scalar_t(2), 1.0_id);
463   static_assert(t3.real() == scalar_t(2), "");
464   static_assert(t3.imag() == scalar_t(1), "");
465 }
466 
test_arithmetic_assign()467 MAYBE_GLOBAL void test_arithmetic_assign() {
468   test_arithmetic_assign_scalar<float>();
469   test_arithmetic_assign_scalar<double>();
470   test_arithmetic_assign_complex<float>();
471   test_arithmetic_assign_complex<double>();
472 }
473 
474 } // namespace arithmetic_assign
475 
476 namespace arithmetic {
477 
478 template <typename scalar_t>
test_arithmetic_()479 C10_HOST_DEVICE void test_arithmetic_() {
480   static_assert(
481       c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "");
482   static_assert(
483       c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "");
484 
485   static_assert(
486       c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) ==
487           c10::complex<scalar_t>(4, 6),
488       "");
489   static_assert(
490       c10::complex<scalar_t>(1, 2) + scalar_t(3) ==
491           c10::complex<scalar_t>(4, 2),
492       "");
493   static_assert(
494       scalar_t(3) + c10::complex<scalar_t>(1, 2) ==
495           c10::complex<scalar_t>(4, 2),
496       "");
497 
498   static_assert(
499       c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) ==
500           c10::complex<scalar_t>(-2, -2),
501       "");
502   static_assert(
503       c10::complex<scalar_t>(1, 2) - scalar_t(3) ==
504           c10::complex<scalar_t>(-2, 2),
505       "");
506   static_assert(
507       scalar_t(3) - c10::complex<scalar_t>(1, 2) ==
508           c10::complex<scalar_t>(2, -2),
509       "");
510 
511   static_assert(
512       c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) ==
513           c10::complex<scalar_t>(-5, 10),
514       "");
515   static_assert(
516       c10::complex<scalar_t>(1, 2) * scalar_t(3) ==
517           c10::complex<scalar_t>(3, 6),
518       "");
519   static_assert(
520       scalar_t(3) * c10::complex<scalar_t>(1, 2) ==
521           c10::complex<scalar_t>(3, 6),
522       "");
523 
524   static_assert(
525       c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) ==
526           c10::complex<scalar_t>(1, 2),
527       "");
528   static_assert(
529       c10::complex<scalar_t>(5, 10) / scalar_t(5) ==
530           c10::complex<scalar_t>(1, 2),
531       "");
532   static_assert(
533       scalar_t(25) / c10::complex<scalar_t>(3, 4) ==
534           c10::complex<scalar_t>(3, -4),
535       "");
536 }
537 
test_arithmetic()538 MAYBE_GLOBAL void test_arithmetic() {
539   test_arithmetic_<float>();
540   test_arithmetic_<double>();
541 }
542 
543 template <typename T, typename int_t>
test_binary_ops_for_int_type_(T real,T img,int_t num)544 void test_binary_ops_for_int_type_(T real, T img, int_t num) {
545   c10::complex<T> c(real, img);
546   ASSERT_EQ(c + num, c10::complex<T>(real + num, img));
547   ASSERT_EQ(num + c, c10::complex<T>(num + real, img));
548   ASSERT_EQ(c - num, c10::complex<T>(real - num, img));
549   ASSERT_EQ(num - c, c10::complex<T>(num - real, -img));
550   ASSERT_EQ(c * num, c10::complex<T>(real * num, img * num));
551   ASSERT_EQ(num * c, c10::complex<T>(num * real, num * img));
552   ASSERT_EQ(c / num, c10::complex<T>(real / num, img / num));
553   ASSERT_EQ(
554       num / c,
555       c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c)));
556 }
557 
558 template <typename T>
test_binary_ops_for_all_int_types_(T real,T img,int8_t i)559 void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) {
560   test_binary_ops_for_int_type_<T, int8_t>(real, img, i);
561   test_binary_ops_for_int_type_<T, int16_t>(real, img, i);
562   test_binary_ops_for_int_type_<T, int32_t>(real, img, i);
563   test_binary_ops_for_int_type_<T, int64_t>(real, img, i);
564 }
565 
TEST(TestArithmeticIntScalar,All)566 TEST(TestArithmeticIntScalar, All) {
567   test_binary_ops_for_all_int_types_<float>(1.0, 0.1, 1);
568   test_binary_ops_for_all_int_types_<double>(-1.3, -0.2, -2);
569 }
570 
571 } // namespace arithmetic
572 
573 namespace equality {
574 
575 template <typename scalar_t>
test_equality_()576 C10_HOST_DEVICE void test_equality_() {
577   static_assert(
578       c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "");
579   static_assert(c10::complex<scalar_t>(1, 0) == scalar_t(1), "");
580   static_assert(scalar_t(1) == c10::complex<scalar_t>(1, 0), "");
581   static_assert(
582       c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "");
583   static_assert(c10::complex<scalar_t>(1, 2) != scalar_t(1), "");
584   static_assert(scalar_t(1) != c10::complex<scalar_t>(1, 2), "");
585 }
586 
test_equality()587 MAYBE_GLOBAL void test_equality() {
588   test_equality_<float>();
589   test_equality_<double>();
590 }
591 
592 } // namespace equality
593 
594 namespace io {
595 
596 template <typename scalar_t>
test_io_()597 void test_io_() {
598   std::stringstream ss;
599   c10::complex<scalar_t> a(1, 2);
600   ss << a;
601   ASSERT_EQ(ss.str(), "(1,2)");
602   ss.str("(3,4)");
603   ss >> a;
604   ASSERT_TRUE(a == c10::complex<scalar_t>(3, 4));
605 }
606 
TEST(TestIO,All)607 TEST(TestIO, All) {
608   test_io_<float>();
609   test_io_<double>();
610 }
611 
612 } // namespace io
613 
614 namespace test_std {
615 
616 template <typename scalar_t>
test_callable_()617 C10_HOST_DEVICE void test_callable_() {
618   static_assert(std::real(c10::complex<scalar_t>(1, 2)) == scalar_t(1), "");
619   static_assert(std::imag(c10::complex<scalar_t>(1, 2)) == scalar_t(2), "");
620   std::abs(c10::complex<scalar_t>(1, 2));
621   std::arg(c10::complex<scalar_t>(1, 2));
622   static_assert(std::norm(c10::complex<scalar_t>(3, 4)) == scalar_t(25), "");
623   static_assert(
624       std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4),
625       "");
626   c10::polar(float(1), float(PI / 2));
627   c10::polar(double(1), double(PI / 2));
628 }
629 
test_callable()630 MAYBE_GLOBAL void test_callable() {
631   test_callable_<float>();
632   test_callable_<double>();
633 }
634 
635 template <typename scalar_t>
test_values_()636 void test_values_() {
637   ASSERT_EQ(std::abs(c10::complex<scalar_t>(3, 4)), scalar_t(5));
638   ASSERT_LT(std::abs(std::arg(c10::complex<scalar_t>(0, 1)) - PI / 2), 1e-6);
639   ASSERT_LT(
640       std::abs(
641           c10::polar(scalar_t(1), scalar_t(PI / 2)) -
642           c10::complex<scalar_t>(0, 1)),
643       1e-6);
644 }
645 
TEST(TestStd,BasicFunctions)646 TEST(TestStd, BasicFunctions) {
647   test_values_<float>();
648   test_values_<double>();
649   // CSQRT edge cases: checks for overflows which are likely to occur
650   // if square root is computed using polar form
651   ASSERT_LT(
652       std::abs(std::sqrt(c10::complex<float>(-1e20, -4988429.2)).real()), 3e-4);
653   ASSERT_LT(
654       std::abs(std::sqrt(c10::complex<double>(-1e60, -4988429.2)).real()),
655       3e-4);
656 }
657 
658 } // namespace test_std
659