xref: /aosp_15_r20/external/eigen/bench/dense_solvers.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1*bf2c3715SXin Li #include <iostream>
2*bf2c3715SXin Li #include "BenchTimer.h"
3*bf2c3715SXin Li #include <Eigen/Dense>
4*bf2c3715SXin Li #include <map>
5*bf2c3715SXin Li #include <vector>
6*bf2c3715SXin Li #include <string>
7*bf2c3715SXin Li #include <sstream>
8*bf2c3715SXin Li using namespace Eigen;
9*bf2c3715SXin Li 
10*bf2c3715SXin Li std::map<std::string,Array<float,1,8,DontAlign|RowMajor> > results;
11*bf2c3715SXin Li std::vector<std::string> labels;
12*bf2c3715SXin Li std::vector<Array2i> sizes;
13*bf2c3715SXin Li 
14*bf2c3715SXin Li template<typename Solver,typename MatrixType>
15*bf2c3715SXin Li EIGEN_DONT_INLINE
compute_norm_equation(Solver & solver,const MatrixType & A)16*bf2c3715SXin Li void compute_norm_equation(Solver &solver, const MatrixType &A) {
17*bf2c3715SXin Li   if(A.rows()!=A.cols())
18*bf2c3715SXin Li     solver.compute(A.transpose()*A);
19*bf2c3715SXin Li   else
20*bf2c3715SXin Li     solver.compute(A);
21*bf2c3715SXin Li }
22*bf2c3715SXin Li 
23*bf2c3715SXin Li template<typename Solver,typename MatrixType>
24*bf2c3715SXin Li EIGEN_DONT_INLINE
compute(Solver & solver,const MatrixType & A)25*bf2c3715SXin Li void compute(Solver &solver, const MatrixType &A) {
26*bf2c3715SXin Li   solver.compute(A);
27*bf2c3715SXin Li }
28*bf2c3715SXin Li 
29*bf2c3715SXin Li template<typename Scalar,int Size>
bench(int id,int rows,int size=Size)30*bf2c3715SXin Li void bench(int id, int rows, int size = Size)
31*bf2c3715SXin Li {
32*bf2c3715SXin Li   typedef Matrix<Scalar,Dynamic,Size> Mat;
33*bf2c3715SXin Li   typedef Matrix<Scalar,Dynamic,Dynamic> MatDyn;
34*bf2c3715SXin Li   typedef Matrix<Scalar,Size,Size> MatSquare;
35*bf2c3715SXin Li   Mat A(rows,size);
36*bf2c3715SXin Li   A.setRandom();
37*bf2c3715SXin Li   if(rows==size)
38*bf2c3715SXin Li     A = A*A.adjoint();
39*bf2c3715SXin Li   BenchTimer t_llt, t_ldlt, t_lu, t_fplu, t_qr, t_cpqr, t_cod, t_fpqr, t_jsvd, t_bdcsvd;
40*bf2c3715SXin Li 
41*bf2c3715SXin Li   int svd_opt = ComputeThinU|ComputeThinV;
42*bf2c3715SXin Li 
43*bf2c3715SXin Li   int tries = 5;
44*bf2c3715SXin Li   int rep = 1000/size;
45*bf2c3715SXin Li   if(rep==0) rep = 1;
46*bf2c3715SXin Li //   rep = rep*rep;
47*bf2c3715SXin Li 
48*bf2c3715SXin Li   LLT<MatSquare> llt(size);
49*bf2c3715SXin Li   LDLT<MatSquare> ldlt(size);
50*bf2c3715SXin Li   PartialPivLU<MatSquare> lu(size);
51*bf2c3715SXin Li   FullPivLU<MatSquare> fplu(size,size);
52*bf2c3715SXin Li   HouseholderQR<Mat> qr(A.rows(),A.cols());
53*bf2c3715SXin Li   ColPivHouseholderQR<Mat> cpqr(A.rows(),A.cols());
54*bf2c3715SXin Li   CompleteOrthogonalDecomposition<Mat> cod(A.rows(),A.cols());
55*bf2c3715SXin Li   FullPivHouseholderQR<Mat> fpqr(A.rows(),A.cols());
56*bf2c3715SXin Li   JacobiSVD<MatDyn> jsvd(A.rows(),A.cols());
57*bf2c3715SXin Li   BDCSVD<MatDyn> bdcsvd(A.rows(),A.cols());
58*bf2c3715SXin Li 
59*bf2c3715SXin Li   BENCH(t_llt, tries, rep, compute_norm_equation(llt,A));
60*bf2c3715SXin Li   BENCH(t_ldlt, tries, rep, compute_norm_equation(ldlt,A));
61*bf2c3715SXin Li   BENCH(t_lu, tries, rep, compute_norm_equation(lu,A));
62*bf2c3715SXin Li   if(size<=1000)
63*bf2c3715SXin Li     BENCH(t_fplu, tries, rep, compute_norm_equation(fplu,A));
64*bf2c3715SXin Li   BENCH(t_qr, tries, rep, compute(qr,A));
65*bf2c3715SXin Li   BENCH(t_cpqr, tries, rep, compute(cpqr,A));
66*bf2c3715SXin Li   BENCH(t_cod, tries, rep, compute(cod,A));
67*bf2c3715SXin Li   if(size*rows<=10000000)
68*bf2c3715SXin Li     BENCH(t_fpqr, tries, rep, compute(fpqr,A));
69*bf2c3715SXin Li   if(size<500) // JacobiSVD is really too slow for too large matrices
70*bf2c3715SXin Li     BENCH(t_jsvd, tries, rep, jsvd.compute(A,svd_opt));
71*bf2c3715SXin Li //   if(size*rows<=20000000)
72*bf2c3715SXin Li     BENCH(t_bdcsvd, tries, rep, bdcsvd.compute(A,svd_opt));
73*bf2c3715SXin Li 
74*bf2c3715SXin Li   results["LLT"][id] = t_llt.best();
75*bf2c3715SXin Li   results["LDLT"][id] = t_ldlt.best();
76*bf2c3715SXin Li   results["PartialPivLU"][id] = t_lu.best();
77*bf2c3715SXin Li   results["FullPivLU"][id] = t_fplu.best();
78*bf2c3715SXin Li   results["HouseholderQR"][id] = t_qr.best();
79*bf2c3715SXin Li   results["ColPivHouseholderQR"][id] = t_cpqr.best();
80*bf2c3715SXin Li   results["CompleteOrthogonalDecomposition"][id] = t_cod.best();
81*bf2c3715SXin Li   results["FullPivHouseholderQR"][id] = t_fpqr.best();
82*bf2c3715SXin Li   results["JacobiSVD"][id] = t_jsvd.best();
83*bf2c3715SXin Li   results["BDCSVD"][id] = t_bdcsvd.best();
84*bf2c3715SXin Li }
85*bf2c3715SXin Li 
86*bf2c3715SXin Li 
main()87*bf2c3715SXin Li int main()
88*bf2c3715SXin Li {
89*bf2c3715SXin Li   labels.push_back("LLT");
90*bf2c3715SXin Li   labels.push_back("LDLT");
91*bf2c3715SXin Li   labels.push_back("PartialPivLU");
92*bf2c3715SXin Li   labels.push_back("FullPivLU");
93*bf2c3715SXin Li   labels.push_back("HouseholderQR");
94*bf2c3715SXin Li   labels.push_back("ColPivHouseholderQR");
95*bf2c3715SXin Li   labels.push_back("CompleteOrthogonalDecomposition");
96*bf2c3715SXin Li   labels.push_back("FullPivHouseholderQR");
97*bf2c3715SXin Li   labels.push_back("JacobiSVD");
98*bf2c3715SXin Li   labels.push_back("BDCSVD");
99*bf2c3715SXin Li 
100*bf2c3715SXin Li   for(int i=0; i<labels.size(); ++i)
101*bf2c3715SXin Li     results[labels[i]].fill(-1);
102*bf2c3715SXin Li 
103*bf2c3715SXin Li   const int small = 8;
104*bf2c3715SXin Li   sizes.push_back(Array2i(small,small));
105*bf2c3715SXin Li   sizes.push_back(Array2i(100,100));
106*bf2c3715SXin Li   sizes.push_back(Array2i(1000,1000));
107*bf2c3715SXin Li   sizes.push_back(Array2i(4000,4000));
108*bf2c3715SXin Li   sizes.push_back(Array2i(10000,small));
109*bf2c3715SXin Li   sizes.push_back(Array2i(10000,100));
110*bf2c3715SXin Li   sizes.push_back(Array2i(10000,1000));
111*bf2c3715SXin Li   sizes.push_back(Array2i(10000,4000));
112*bf2c3715SXin Li 
113*bf2c3715SXin Li   using namespace std;
114*bf2c3715SXin Li 
115*bf2c3715SXin Li   for(int k=0; k<sizes.size(); ++k)
116*bf2c3715SXin Li   {
117*bf2c3715SXin Li     cout << sizes[k](0) << "x" << sizes[k](1) << "...\n";
118*bf2c3715SXin Li     bench<float,Dynamic>(k,sizes[k](0),sizes[k](1));
119*bf2c3715SXin Li   }
120*bf2c3715SXin Li 
121*bf2c3715SXin Li   cout.width(32);
122*bf2c3715SXin Li   cout << "solver/size";
123*bf2c3715SXin Li   cout << "  ";
124*bf2c3715SXin Li   for(int k=0; k<sizes.size(); ++k)
125*bf2c3715SXin Li   {
126*bf2c3715SXin Li     std::stringstream ss;
127*bf2c3715SXin Li     ss << sizes[k](0) << "x" << sizes[k](1);
128*bf2c3715SXin Li     cout.width(10); cout << ss.str(); cout << " ";
129*bf2c3715SXin Li   }
130*bf2c3715SXin Li   cout << endl;
131*bf2c3715SXin Li 
132*bf2c3715SXin Li 
133*bf2c3715SXin Li   for(int i=0; i<labels.size(); ++i)
134*bf2c3715SXin Li   {
135*bf2c3715SXin Li     cout.width(32); cout << labels[i]; cout << "  ";
136*bf2c3715SXin Li     ArrayXf r = (results[labels[i]]*100000.f).floor()/100.f;
137*bf2c3715SXin Li     for(int k=0; k<sizes.size(); ++k)
138*bf2c3715SXin Li     {
139*bf2c3715SXin Li       cout.width(10);
140*bf2c3715SXin Li       if(r(k)>=1e6)  cout << "-";
141*bf2c3715SXin Li       else           cout << r(k);
142*bf2c3715SXin Li       cout << " ";
143*bf2c3715SXin Li     }
144*bf2c3715SXin Li     cout << endl;
145*bf2c3715SXin Li   }
146*bf2c3715SXin Li 
147*bf2c3715SXin Li   // HTML output
148*bf2c3715SXin Li   cout << "<table class=\"manual\">" << endl;
149*bf2c3715SXin Li   cout << "<tr><th>solver/size</th>" << endl;
150*bf2c3715SXin Li   for(int k=0; k<sizes.size(); ++k)
151*bf2c3715SXin Li     cout << "  <th>" << sizes[k](0) << "x" << sizes[k](1) << "</th>";
152*bf2c3715SXin Li   cout << "</tr>" << endl;
153*bf2c3715SXin Li   for(int i=0; i<labels.size(); ++i)
154*bf2c3715SXin Li   {
155*bf2c3715SXin Li     cout << "<tr";
156*bf2c3715SXin Li     if(i%2==1) cout << " class=\"alt\"";
157*bf2c3715SXin Li     cout << "><td>" << labels[i] << "</td>";
158*bf2c3715SXin Li     ArrayXf r = (results[labels[i]]*100000.f).floor()/100.f;
159*bf2c3715SXin Li     for(int k=0; k<sizes.size(); ++k)
160*bf2c3715SXin Li     {
161*bf2c3715SXin Li       if(r(k)>=1e6) cout << "<td>-</td>";
162*bf2c3715SXin Li       else
163*bf2c3715SXin Li       {
164*bf2c3715SXin Li         cout << "<td>" << r(k);
165*bf2c3715SXin Li         if(i>0)
166*bf2c3715SXin Li           cout << " (x" << numext::round(10.f*results[labels[i]](k)/results["LLT"](k))/10.f << ")";
167*bf2c3715SXin Li         if(i<4 && sizes[k](0)!=sizes[k](1))
168*bf2c3715SXin Li           cout << " <sup><a href=\"#note_ls\">*</a></sup>";
169*bf2c3715SXin Li         cout << "</td>";
170*bf2c3715SXin Li       }
171*bf2c3715SXin Li     }
172*bf2c3715SXin Li     cout << "</tr>" << endl;
173*bf2c3715SXin Li   }
174*bf2c3715SXin Li   cout << "</table>" << endl;
175*bf2c3715SXin Li 
176*bf2c3715SXin Li //   cout << "LLT                             (ms)  " << (results["LLT"]*1000.).format(fmt) << "\n";
177*bf2c3715SXin Li //   cout << "LDLT                             (%)  " << (results["LDLT"]/results["LLT"]).format(fmt) << "\n";
178*bf2c3715SXin Li //   cout << "PartialPivLU                     (%)  " << (results["PartialPivLU"]/results["LLT"]).format(fmt) << "\n";
179*bf2c3715SXin Li //   cout << "FullPivLU                        (%)  " << (results["FullPivLU"]/results["LLT"]).format(fmt) << "\n";
180*bf2c3715SXin Li //   cout << "HouseholderQR                    (%)  " << (results["HouseholderQR"]/results["LLT"]).format(fmt) << "\n";
181*bf2c3715SXin Li //   cout << "ColPivHouseholderQR              (%)  " << (results["ColPivHouseholderQR"]/results["LLT"]).format(fmt) << "\n";
182*bf2c3715SXin Li //   cout << "CompleteOrthogonalDecomposition  (%)  " << (results["CompleteOrthogonalDecomposition"]/results["LLT"]).format(fmt) << "\n";
183*bf2c3715SXin Li //   cout << "FullPivHouseholderQR             (%)  " << (results["FullPivHouseholderQR"]/results["LLT"]).format(fmt) << "\n";
184*bf2c3715SXin Li //   cout << "JacobiSVD                        (%)  " << (results["JacobiSVD"]/results["LLT"]).format(fmt) << "\n";
185*bf2c3715SXin Li //   cout << "BDCSVD                           (%)  " << (results["BDCSVD"]/results["LLT"]).format(fmt) << "\n";
186*bf2c3715SXin Li }
187