xref: /aosp_15_r20/external/eigen/test/numext.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 template<typename T, typename U>
check_if_equal_or_nans(const T & actual,const U & expected)13 bool check_if_equal_or_nans(const T& actual, const U& expected) {
14   return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
15 }
16 
17 template<typename T, typename U>
check_if_equal_or_nans(const std::complex<T> & actual,const std::complex<U> & expected)18 bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
19   return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
20          && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
21 }
22 
// Verification helper used by VERIFY_IS_EQUAL_OR_NANS.
// Returns the result of check_if_equal_or_nans(actual, expected). On a
// mismatch it first prints both values to std::cerr so the failing VERIFY
// has some diagnostic context.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
  const bool matches = check_if_equal_or_nans(actual, expected);
  if (!matches) {
    std::cerr
        << "\n    actual   = " << actual
        << "\n    expected = " << expected << "\n\n";
  }
  return matches;
}
36 
// VERIFY that `actual` equals `expected`, where a NaN (or a NaN component of a
// std::complex) in both positions counts as equal.
#define VERIFY_IS_EQUAL_OR_NANS(actual, expected) VERIFY(test_is_equal_or_nans(actual, expected))
38 
39 template<typename T>
check_abs()40 void check_abs() {
41   typedef typename NumTraits<T>::Real Real;
42   Real zero(0);
43 
44   if(NumTraits<T>::IsSigned)
45     VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
46   VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
47   VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
48 
49   for(int k=0; k<100; ++k)
50   {
51     T x = internal::random<T>();
52     if(!internal::is_same<T,bool>::value)
53       x = x/Real(2);
54     if(NumTraits<T>::IsSigned)
55     {
56       VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
57       VERIFY( numext::abs(-x) >= zero );
58     }
59     VERIFY( numext::abs(x) >= zero );
60     VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
61   }
62 }
63 
64 template<typename T>
check_arg()65 void check_arg() {
66   typedef typename NumTraits<T>::Real Real;
67   VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
68   VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
69 
70   for(int k=0; k<100; ++k)
71   {
72     T x = internal::random<T>();
73     Real y = numext::arg(x);
74     VERIFY_IS_APPROX( y, std::arg(x) );
75   }
76 }
77 
78 template<typename T>
79 struct check_sqrt_impl {
runcheck_sqrt_impl80   static void run() {
81     for (int i=0; i<1000; ++i) {
82       const T x = numext::abs(internal::random<T>());
83       const T sqrtx = numext::sqrt(x);
84       VERIFY_IS_APPROX(sqrtx*sqrtx, x);
85     }
86 
87     // Corner cases.
88     const T zero = T(0);
89     const T one = T(1);
90     const T inf = std::numeric_limits<T>::infinity();
91     const T nan = std::numeric_limits<T>::quiet_NaN();
92     VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
93     VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
94     VERIFY((numext::isnan)(numext::sqrt(nan)));
95     VERIFY((numext::isnan)(numext::sqrt(-one)));
96   }
97 };
98 
99 template<typename T>
100 struct check_sqrt_impl<std::complex<T>  > {
runcheck_sqrt_impl101   static void run() {
102     typedef typename std::complex<T> ComplexT;
103 
104     for (int i=0; i<1000; ++i) {
105       const ComplexT x = internal::random<ComplexT>();
106       const ComplexT sqrtx = numext::sqrt(x);
107       VERIFY_IS_APPROX(sqrtx*sqrtx, x);
108     }
109 
110     // Corner cases.
111     const T zero = T(0);
112     const T one = T(1);
113     const T inf = std::numeric_limits<T>::infinity();
114     const T nan = std::numeric_limits<T>::quiet_NaN();
115 
116     // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
117     const int kNumCorners = 20;
118     const ComplexT corners[kNumCorners][2] = {
119       {ComplexT(zero, zero), ComplexT(zero, zero)},
120       {ComplexT(-zero, zero), ComplexT(zero, zero)},
121       {ComplexT(zero, -zero), ComplexT(zero, zero)},
122       {ComplexT(-zero, -zero), ComplexT(zero, zero)},
123       {ComplexT(one, inf), ComplexT(inf, inf)},
124       {ComplexT(nan, inf), ComplexT(inf, inf)},
125       {ComplexT(one, -inf), ComplexT(inf, -inf)},
126       {ComplexT(nan, -inf), ComplexT(inf, -inf)},
127       {ComplexT(-inf, one), ComplexT(zero, inf)},
128       {ComplexT(inf, one), ComplexT(inf, zero)},
129       {ComplexT(-inf, -one), ComplexT(zero, -inf)},
130       {ComplexT(inf, -one), ComplexT(inf, -zero)},
131       {ComplexT(-inf, nan), ComplexT(nan, inf)},
132       {ComplexT(inf, nan), ComplexT(inf, nan)},
133       {ComplexT(zero, nan), ComplexT(nan, nan)},
134       {ComplexT(one, nan), ComplexT(nan, nan)},
135       {ComplexT(nan, zero), ComplexT(nan, nan)},
136       {ComplexT(nan, one), ComplexT(nan, nan)},
137       {ComplexT(nan, -one), ComplexT(nan, nan)},
138       {ComplexT(nan, nan), ComplexT(nan, nan)},
139     };
140 
141     for (int i=0; i<kNumCorners; ++i) {
142       const ComplexT& x = corners[i][0];
143       const ComplexT sqrtx = corners[i][1];
144       VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
145     }
146   }
147 };
148 
149 template<typename T>
check_sqrt()150 void check_sqrt() {
151   check_sqrt_impl<T>::run();
152 }
153 
154 template<typename T>
155 struct check_rsqrt_impl {
runcheck_rsqrt_impl156   static void run() {
157     const T zero = T(0);
158     const T one = T(1);
159     const T inf = std::numeric_limits<T>::infinity();
160     const T nan = std::numeric_limits<T>::quiet_NaN();
161 
162     for (int i=0; i<1000; ++i) {
163       const T x = numext::abs(internal::random<T>());
164       const T rsqrtx = numext::rsqrt(x);
165       const T invx = one / x;
166       VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
167     }
168 
169     // Corner cases.
170     VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
171     VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
172     VERIFY((numext::isnan)(numext::rsqrt(nan)));
173     VERIFY((numext::isnan)(numext::rsqrt(-one)));
174   }
175 };
176 
177 template<typename T>
178 struct check_rsqrt_impl<std::complex<T> > {
runcheck_rsqrt_impl179   static void run() {
180     typedef typename std::complex<T> ComplexT;
181     const T zero = T(0);
182     const T one = T(1);
183     const T inf = std::numeric_limits<T>::infinity();
184     const T nan = std::numeric_limits<T>::quiet_NaN();
185 
186     for (int i=0; i<1000; ++i) {
187       const ComplexT x = internal::random<ComplexT>();
188       const ComplexT invx = ComplexT(one, zero) / x;
189       const ComplexT rsqrtx = numext::rsqrt(x);
190       VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
191     }
192 
193     // GCC and MSVC differ in their treatment of 1/(0 + 0i)
194     //   GCC/clang = (inf, nan)
195     //   MSVC = (nan, nan)
196     // and 1 / (x + inf i)
197     //   GCC/clang = (0, 0)
198     //   MSVC = (nan, nan)
199     #if (EIGEN_COMP_GNUC)
200     {
201       const int kNumCorners = 20;
202       const ComplexT corners[kNumCorners][2] = {
203         // Only consistent across GCC, clang
204         {ComplexT(zero, zero), ComplexT(zero, zero)},
205         {ComplexT(-zero, zero), ComplexT(zero, zero)},
206         {ComplexT(zero, -zero), ComplexT(zero, zero)},
207         {ComplexT(-zero, -zero), ComplexT(zero, zero)},
208         {ComplexT(one, inf), ComplexT(inf, inf)},
209         {ComplexT(nan, inf), ComplexT(inf, inf)},
210         {ComplexT(one, -inf), ComplexT(inf, -inf)},
211         {ComplexT(nan, -inf), ComplexT(inf, -inf)},
212         // Consistent across GCC, clang, MSVC
213         {ComplexT(-inf, one), ComplexT(zero, inf)},
214         {ComplexT(inf, one), ComplexT(inf, zero)},
215         {ComplexT(-inf, -one), ComplexT(zero, -inf)},
216         {ComplexT(inf, -one), ComplexT(inf, -zero)},
217         {ComplexT(-inf, nan), ComplexT(nan, inf)},
218         {ComplexT(inf, nan), ComplexT(inf, nan)},
219         {ComplexT(zero, nan), ComplexT(nan, nan)},
220         {ComplexT(one, nan), ComplexT(nan, nan)},
221         {ComplexT(nan, zero), ComplexT(nan, nan)},
222         {ComplexT(nan, one), ComplexT(nan, nan)},
223         {ComplexT(nan, -one), ComplexT(nan, nan)},
224         {ComplexT(nan, nan), ComplexT(nan, nan)},
225       };
226 
227       for (int i=0; i<kNumCorners; ++i) {
228         const ComplexT& x = corners[i][0];
229         const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
230         VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
231       }
232     }
233     #endif
234   }
235 };
236 
237 template<typename T>
check_rsqrt()238 void check_rsqrt() {
239   check_rsqrt_impl<T>::run();
240 }
241 
// Test entry point: runs every numext check g_repeat times over bool, signed
// and unsigned integer types, half, bfloat16, float, double, long double and
// std::complex types. Subtests run in a fixed order.
EIGEN_DECLARE_TEST(numext) {
  for (int repetition = 0; repetition < g_repeat; ++repetition) {
    // numext::abs / numext::abs2
    CALL_SUBTEST( check_abs<bool>() );
    CALL_SUBTEST( check_abs<signed char>() );
    CALL_SUBTEST( check_abs<unsigned char>() );
    CALL_SUBTEST( check_abs<short>() );
    CALL_SUBTEST( check_abs<unsigned short>() );
    CALL_SUBTEST( check_abs<int>() );
    CALL_SUBTEST( check_abs<unsigned int>() );
    CALL_SUBTEST( check_abs<long>() );
    CALL_SUBTEST( check_abs<unsigned long>() );
    CALL_SUBTEST( check_abs<half>() );
    CALL_SUBTEST( check_abs<bfloat16>() );
    CALL_SUBTEST( check_abs<float>() );
    CALL_SUBTEST( check_abs<double>() );
    CALL_SUBTEST( check_abs<long double>() );
    CALL_SUBTEST( check_abs<std::complex<float> >() );
    CALL_SUBTEST( check_abs<std::complex<double> >() );

    // numext::arg
    CALL_SUBTEST( check_arg<std::complex<float> >() );
    CALL_SUBTEST( check_arg<std::complex<double> >() );

    // numext::sqrt
    CALL_SUBTEST( check_sqrt<float>() );
    CALL_SUBTEST( check_sqrt<double>() );
    CALL_SUBTEST( check_sqrt<std::complex<float> >() );
    CALL_SUBTEST( check_sqrt<std::complex<double> >() );

    // numext::rsqrt
    CALL_SUBTEST( check_rsqrt<float>() );
    CALL_SUBTEST( check_rsqrt<double>() );
    CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
    CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
  }
}
276