xref: /aosp_15_r20/external/eigen/unsupported/test/polynomialsolver.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2010 Manuel Yguel <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9*bf2c3715SXin Li 
10*bf2c3715SXin Li #include "main.h"
11*bf2c3715SXin Li #include <unsupported/Eigen/Polynomials>
12*bf2c3715SXin Li #include <iostream>
13*bf2c3715SXin Li #include <algorithm>
14*bf2c3715SXin Li 
15*bf2c3715SXin Li using namespace std;
16*bf2c3715SXin Li 
17*bf2c3715SXin Li namespace Eigen {
18*bf2c3715SXin Li namespace internal {
19*bf2c3715SXin Li template<int Size>
20*bf2c3715SXin Li struct increment_if_fixed_size
21*bf2c3715SXin Li {
22*bf2c3715SXin Li   enum {
23*bf2c3715SXin Li     ret = (Size == Dynamic) ? Dynamic : Size+1
24*bf2c3715SXin Li   };
25*bf2c3715SXin Li };
26*bf2c3715SXin Li }
27*bf2c3715SXin Li }
28*bf2c3715SXin Li 
29*bf2c3715SXin Li template<typename PolynomialType>
polyder(const PolynomialType & p)30*bf2c3715SXin Li PolynomialType polyder(const PolynomialType& p)
31*bf2c3715SXin Li {
32*bf2c3715SXin Li   typedef typename PolynomialType::Scalar Scalar;
33*bf2c3715SXin Li   PolynomialType res(p.size());
34*bf2c3715SXin Li   for(Index i=1; i<p.size(); ++i)
35*bf2c3715SXin Li     res[i-1] = p[i]*Scalar(i);
36*bf2c3715SXin Li   res[p.size()-1] = 0.;
37*bf2c3715SXin Li   return res;
38*bf2c3715SXin Li }
39*bf2c3715SXin Li 
40*bf2c3715SXin Li template<int Deg, typename POLYNOMIAL, typename SOLVER>
aux_evalSolver(const POLYNOMIAL & pols,SOLVER & psolve)41*bf2c3715SXin Li bool aux_evalSolver( const POLYNOMIAL& pols, SOLVER& psolve )
42*bf2c3715SXin Li {
43*bf2c3715SXin Li   typedef typename POLYNOMIAL::Scalar Scalar;
44*bf2c3715SXin Li   typedef typename POLYNOMIAL::RealScalar RealScalar;
45*bf2c3715SXin Li 
46*bf2c3715SXin Li   typedef typename SOLVER::RootsType    RootsType;
47*bf2c3715SXin Li   typedef Matrix<RealScalar,Deg,1>      EvalRootsType;
48*bf2c3715SXin Li 
49*bf2c3715SXin Li   const Index deg = pols.size()-1;
50*bf2c3715SXin Li 
51*bf2c3715SXin Li   // Test template constructor from coefficient vector
52*bf2c3715SXin Li   SOLVER solve_constr (pols);
53*bf2c3715SXin Li 
54*bf2c3715SXin Li   psolve.compute( pols );
55*bf2c3715SXin Li   const RootsType& roots( psolve.roots() );
56*bf2c3715SXin Li   EvalRootsType evr( deg );
57*bf2c3715SXin Li   POLYNOMIAL pols_der = polyder(pols);
58*bf2c3715SXin Li   EvalRootsType der( deg );
59*bf2c3715SXin Li   for( int i=0; i<roots.size(); ++i ){
60*bf2c3715SXin Li     evr[i] = std::abs( poly_eval( pols, roots[i] ) );
61*bf2c3715SXin Li     der[i] = numext::maxi(RealScalar(1.), std::abs( poly_eval( pols_der, roots[i] ) ));
62*bf2c3715SXin Li   }
63*bf2c3715SXin Li 
64*bf2c3715SXin Li   // we need to divide by the magnitude of the derivative because
65*bf2c3715SXin Li   // with a high derivative is very small error in the value of the root
66*bf2c3715SXin Li   // yiels a very large error in the polynomial evaluation.
67*bf2c3715SXin Li   bool evalToZero = (evr.cwiseQuotient(der)).isZero( test_precision<Scalar>() );
68*bf2c3715SXin Li   if( !evalToZero )
69*bf2c3715SXin Li   {
70*bf2c3715SXin Li     cerr << "WRONG root: " << endl;
71*bf2c3715SXin Li     cerr << "Polynomial: " << pols.transpose() << endl;
72*bf2c3715SXin Li     cerr << "Roots found: " << roots.transpose() << endl;
73*bf2c3715SXin Li     cerr << "Abs value of the polynomial at the roots: " << evr.transpose() << endl;
74*bf2c3715SXin Li     cerr << endl;
75*bf2c3715SXin Li   }
76*bf2c3715SXin Li 
77*bf2c3715SXin Li   std::vector<RealScalar> rootModuli( roots.size() );
78*bf2c3715SXin Li   Map< EvalRootsType > aux( &rootModuli[0], roots.size() );
79*bf2c3715SXin Li   aux = roots.array().abs();
80*bf2c3715SXin Li   std::sort( rootModuli.begin(), rootModuli.end() );
81*bf2c3715SXin Li   bool distinctModuli=true;
82*bf2c3715SXin Li   for( size_t i=1; i<rootModuli.size() && distinctModuli; ++i )
83*bf2c3715SXin Li   {
84*bf2c3715SXin Li     if( internal::isApprox( rootModuli[i], rootModuli[i-1] ) ){
85*bf2c3715SXin Li       distinctModuli = false; }
86*bf2c3715SXin Li   }
87*bf2c3715SXin Li   VERIFY( evalToZero || !distinctModuli );
88*bf2c3715SXin Li 
89*bf2c3715SXin Li   return distinctModuli;
90*bf2c3715SXin Li }
91*bf2c3715SXin Li 
92*bf2c3715SXin Li 
93*bf2c3715SXin Li 
94*bf2c3715SXin Li 
95*bf2c3715SXin Li 
96*bf2c3715SXin Li 
97*bf2c3715SXin Li 
98*bf2c3715SXin Li template<int Deg, typename POLYNOMIAL>
evalSolver(const POLYNOMIAL & pols)99*bf2c3715SXin Li void evalSolver( const POLYNOMIAL& pols )
100*bf2c3715SXin Li {
101*bf2c3715SXin Li   typedef typename POLYNOMIAL::Scalar Scalar;
102*bf2c3715SXin Li 
103*bf2c3715SXin Li   typedef PolynomialSolver<Scalar, Deg > PolynomialSolverType;
104*bf2c3715SXin Li 
105*bf2c3715SXin Li   PolynomialSolverType psolve;
106*bf2c3715SXin Li   aux_evalSolver<Deg, POLYNOMIAL, PolynomialSolverType>( pols, psolve );
107*bf2c3715SXin Li }
108*bf2c3715SXin Li 
109*bf2c3715SXin Li 
110*bf2c3715SXin Li 
111*bf2c3715SXin Li 
112*bf2c3715SXin Li template< int Deg, typename POLYNOMIAL, typename ROOTS, typename REAL_ROOTS >
evalSolverSugarFunction(const POLYNOMIAL & pols,const ROOTS & roots,const REAL_ROOTS & real_roots)113*bf2c3715SXin Li void evalSolverSugarFunction( const POLYNOMIAL& pols, const ROOTS& roots, const REAL_ROOTS& real_roots )
114*bf2c3715SXin Li {
115*bf2c3715SXin Li   using std::sqrt;
116*bf2c3715SXin Li   typedef typename POLYNOMIAL::Scalar Scalar;
117*bf2c3715SXin Li   typedef typename POLYNOMIAL::RealScalar RealScalar;
118*bf2c3715SXin Li 
119*bf2c3715SXin Li   typedef PolynomialSolver<Scalar, Deg >              PolynomialSolverType;
120*bf2c3715SXin Li 
121*bf2c3715SXin Li   PolynomialSolverType psolve;
122*bf2c3715SXin Li   if( aux_evalSolver<Deg, POLYNOMIAL, PolynomialSolverType>( pols, psolve ) )
123*bf2c3715SXin Li   {
124*bf2c3715SXin Li     //It is supposed that
125*bf2c3715SXin Li     // 1) the roots found are correct
126*bf2c3715SXin Li     // 2) the roots have distinct moduli
127*bf2c3715SXin Li 
128*bf2c3715SXin Li     //Test realRoots
129*bf2c3715SXin Li     std::vector< RealScalar > calc_realRoots;
130*bf2c3715SXin Li     psolve.realRoots( calc_realRoots,  test_precision<RealScalar>());
131*bf2c3715SXin Li     VERIFY_IS_EQUAL( calc_realRoots.size() , (size_t)real_roots.size() );
132*bf2c3715SXin Li 
133*bf2c3715SXin Li     const RealScalar psPrec = sqrt( test_precision<RealScalar>() );
134*bf2c3715SXin Li 
135*bf2c3715SXin Li     for( size_t i=0; i<calc_realRoots.size(); ++i )
136*bf2c3715SXin Li     {
137*bf2c3715SXin Li       bool found = false;
138*bf2c3715SXin Li       for( size_t j=0; j<calc_realRoots.size()&& !found; ++j )
139*bf2c3715SXin Li       {
140*bf2c3715SXin Li         if( internal::isApprox( calc_realRoots[i], real_roots[j], psPrec ) ){
141*bf2c3715SXin Li           found = true; }
142*bf2c3715SXin Li       }
143*bf2c3715SXin Li       VERIFY( found );
144*bf2c3715SXin Li     }
145*bf2c3715SXin Li 
146*bf2c3715SXin Li     //Test greatestRoot
147*bf2c3715SXin Li     VERIFY( internal::isApprox( roots.array().abs().maxCoeff(),
148*bf2c3715SXin Li           abs( psolve.greatestRoot() ), psPrec ) );
149*bf2c3715SXin Li 
150*bf2c3715SXin Li     //Test smallestRoot
151*bf2c3715SXin Li     VERIFY( internal::isApprox( roots.array().abs().minCoeff(),
152*bf2c3715SXin Li           abs( psolve.smallestRoot() ), psPrec ) );
153*bf2c3715SXin Li 
154*bf2c3715SXin Li     bool hasRealRoot;
155*bf2c3715SXin Li     //Test absGreatestRealRoot
156*bf2c3715SXin Li     RealScalar r = psolve.absGreatestRealRoot( hasRealRoot );
157*bf2c3715SXin Li     VERIFY( hasRealRoot == (real_roots.size() > 0 ) );
158*bf2c3715SXin Li     if( hasRealRoot ){
159*bf2c3715SXin Li       VERIFY( internal::isApprox( real_roots.array().abs().maxCoeff(), abs(r), psPrec ) );  }
160*bf2c3715SXin Li 
161*bf2c3715SXin Li     //Test absSmallestRealRoot
162*bf2c3715SXin Li     r = psolve.absSmallestRealRoot( hasRealRoot );
163*bf2c3715SXin Li     VERIFY( hasRealRoot == (real_roots.size() > 0 ) );
164*bf2c3715SXin Li     if( hasRealRoot ){
165*bf2c3715SXin Li       VERIFY( internal::isApprox( real_roots.array().abs().minCoeff(), abs( r ), psPrec ) ); }
166*bf2c3715SXin Li 
167*bf2c3715SXin Li     //Test greatestRealRoot
168*bf2c3715SXin Li     r = psolve.greatestRealRoot( hasRealRoot );
169*bf2c3715SXin Li     VERIFY( hasRealRoot == (real_roots.size() > 0 ) );
170*bf2c3715SXin Li     if( hasRealRoot ){
171*bf2c3715SXin Li       VERIFY( internal::isApprox( real_roots.array().maxCoeff(), r, psPrec ) ); }
172*bf2c3715SXin Li 
173*bf2c3715SXin Li     //Test smallestRealRoot
174*bf2c3715SXin Li     r = psolve.smallestRealRoot( hasRealRoot );
175*bf2c3715SXin Li     VERIFY( hasRealRoot == (real_roots.size() > 0 ) );
176*bf2c3715SXin Li     if( hasRealRoot ){
177*bf2c3715SXin Li     VERIFY( internal::isApprox( real_roots.array().minCoeff(), r, psPrec ) ); }
178*bf2c3715SXin Li   }
179*bf2c3715SXin Li }
180*bf2c3715SXin Li 
181*bf2c3715SXin Li 
182*bf2c3715SXin Li template<typename _Scalar, int _Deg>
polynomialsolver(int deg)183*bf2c3715SXin Li void polynomialsolver(int deg)
184*bf2c3715SXin Li {
185*bf2c3715SXin Li   typedef typename NumTraits<_Scalar>::Real RealScalar;
186*bf2c3715SXin Li   typedef internal::increment_if_fixed_size<_Deg>     Dim;
187*bf2c3715SXin Li   typedef Matrix<_Scalar,Dim::ret,1>                  PolynomialType;
188*bf2c3715SXin Li   typedef Matrix<_Scalar,_Deg,1>                      EvalRootsType;
189*bf2c3715SXin Li   typedef Matrix<RealScalar,_Deg,1>                   RealRootsType;
190*bf2c3715SXin Li 
191*bf2c3715SXin Li   cout << "Standard cases" << endl;
192*bf2c3715SXin Li   PolynomialType pols = PolynomialType::Random(deg+1);
193*bf2c3715SXin Li   evalSolver<_Deg,PolynomialType>( pols );
194*bf2c3715SXin Li 
195*bf2c3715SXin Li   cout << "Hard cases" << endl;
196*bf2c3715SXin Li   _Scalar multipleRoot = internal::random<_Scalar>();
197*bf2c3715SXin Li   EvalRootsType allRoots = EvalRootsType::Constant(deg,multipleRoot);
198*bf2c3715SXin Li   roots_to_monicPolynomial( allRoots, pols );
199*bf2c3715SXin Li   evalSolver<_Deg,PolynomialType>( pols );
200*bf2c3715SXin Li 
201*bf2c3715SXin Li   cout << "Test sugar" << endl;
202*bf2c3715SXin Li   RealRootsType realRoots = RealRootsType::Random(deg);
203*bf2c3715SXin Li   roots_to_monicPolynomial( realRoots, pols );
204*bf2c3715SXin Li   evalSolverSugarFunction<_Deg>(
205*bf2c3715SXin Li       pols,
206*bf2c3715SXin Li       realRoots.template cast <std::complex<RealScalar> >().eval(),
207*bf2c3715SXin Li       realRoots );
208*bf2c3715SXin Li }
209*bf2c3715SXin Li 
// Test entry point. Each of the g_repeat repetitions runs the fixed-size
// solvers for degrees 1 to 8, followed by the dynamic-size solvers (float,
// double and std::complex<double>) with fixed or randomly drawn degrees.
EIGEN_DECLARE_TEST(polynomialsolver)
{
  for(int rep = 0; rep < g_repeat; ++rep)
  {
    // Fixed-size polynomials.
    CALL_SUBTEST_1( (polynomialsolver<float,1>(1)) );
    CALL_SUBTEST_2( (polynomialsolver<double,2>(2)) );
    CALL_SUBTEST_3( (polynomialsolver<double,3>(3)) );
    CALL_SUBTEST_4( (polynomialsolver<float,4>(4)) );
    CALL_SUBTEST_5( (polynomialsolver<double,5>(5)) );
    CALL_SUBTEST_6( (polynomialsolver<float,6>(6)) );
    CALL_SUBTEST_7( (polynomialsolver<float,7>(7)) );
    CALL_SUBTEST_8( (polynomialsolver<double,8>(8)) );

    // Dynamic-size polynomials.
    CALL_SUBTEST_9( (polynomialsolver<float,Dynamic>( internal::random<int>(9,13) )) );
    CALL_SUBTEST_10((polynomialsolver<double,Dynamic>( internal::random<int>(9,13) )) );
    CALL_SUBTEST_11((polynomialsolver<float,Dynamic>(1)) );
    CALL_SUBTEST_12((polynomialsolver<std::complex<double>,Dynamic>(internal::random<int>(2,13))) );
  }
}
233