xref: /aosp_15_r20/external/eigen/bench/product_threshold.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1*bf2c3715SXin Li 
2*bf2c3715SXin Li #include <iostream>
3*bf2c3715SXin Li #include <Eigen/Core>
4*bf2c3715SXin Li #include <bench/BenchTimer.h>
5*bf2c3715SXin Li 
6*bf2c3715SXin Li using namespace Eigen;
7*bf2c3715SXin Li using namespace std;
8*bf2c3715SXin Li 
// Exclusive upper bound for every compile-time loop index (K, M and N) below.
#define END 9

// Maps a loop index to the size label printed in table headers.
// Identity by default; indices 10..13 map to larger labels.
template<int Index> struct map_size { enum { ret = Index }; };

template<> struct map_size<10> { enum { ret =  20 }; };
template<> struct map_size<11> { enum { ret =  50 }; };
template<> struct map_size<12> { enum { ret = 100 }; };
template<> struct map_size<13> { enum { ret = 300 }; };
16*bf2c3715SXin Li 
17*bf2c3715SXin Li template<int M, int N,int K> struct alt_prod
18*bf2c3715SXin Li {
19*bf2c3715SXin Li   enum {
20*bf2c3715SXin Li     ret = M==1 && N==1 ? InnerProduct
21*bf2c3715SXin Li         : K==1 ? OuterProduct
22*bf2c3715SXin Li         : M==1 ? GemvProduct
23*bf2c3715SXin Li         : N==1 ? GemvProduct
24*bf2c3715SXin Li         : GemmProduct
25*bf2c3715SXin Li   };
26*bf2c3715SXin Li };
27*bf2c3715SXin Li 
print_mode(int mode)28*bf2c3715SXin Li void print_mode(int mode)
29*bf2c3715SXin Li {
30*bf2c3715SXin Li   if(mode==InnerProduct) std::cout << "i";
31*bf2c3715SXin Li   if(mode==OuterProduct) std::cout << "o";
32*bf2c3715SXin Li   if(mode==CoeffBasedProductMode) std::cout << "c";
33*bf2c3715SXin Li   if(mode==LazyCoeffBasedProductMode) std::cout << "l";
34*bf2c3715SXin Li   if(mode==GemvProduct) std::cout << "v";
35*bf2c3715SXin Li   if(mode==GemmProduct) std::cout << "m";
36*bf2c3715SXin Li }
37*bf2c3715SXin Li 
38*bf2c3715SXin Li template<int Mode, typename Lhs, typename Rhs, typename Res>
prod(const Lhs & a,const Rhs & b,Res & c)39*bf2c3715SXin Li EIGEN_DONT_INLINE void prod(const Lhs& a, const Rhs& b, Res& c)
40*bf2c3715SXin Li {
41*bf2c3715SXin Li   c.noalias() += typename ProductReturnType<Lhs,Rhs,Mode>::Type(a,b);
42*bf2c3715SXin Li }
43*bf2c3715SXin Li 
44*bf2c3715SXin Li template<int M, int N, int K, typename Scalar, int Mode>
bench_prod()45*bf2c3715SXin Li EIGEN_DONT_INLINE void bench_prod()
46*bf2c3715SXin Li {
47*bf2c3715SXin Li   typedef Matrix<Scalar,M,K> Lhs; Lhs a; a.setRandom();
48*bf2c3715SXin Li   typedef Matrix<Scalar,K,N> Rhs; Rhs b; b.setRandom();
49*bf2c3715SXin Li   typedef Matrix<Scalar,M,N> Res; Res c; c.setRandom();
50*bf2c3715SXin Li 
51*bf2c3715SXin Li   BenchTimer t;
52*bf2c3715SXin Li   double n = 2.*double(M)*double(N)*double(K);
53*bf2c3715SXin Li   int rep = 100000./n;
54*bf2c3715SXin Li   rep /= 2;
55*bf2c3715SXin Li   if(rep<1) rep = 1;
56*bf2c3715SXin Li   do {
57*bf2c3715SXin Li     rep *= 2;
58*bf2c3715SXin Li     t.reset();
59*bf2c3715SXin Li     BENCH(t,1,rep,prod<CoeffBasedProductMode>(a,b,c));
60*bf2c3715SXin Li   } while(t.best()<0.1);
61*bf2c3715SXin Li 
62*bf2c3715SXin Li   t.reset();
63*bf2c3715SXin Li   BENCH(t,5,rep,prod<Mode>(a,b,c));
64*bf2c3715SXin Li 
65*bf2c3715SXin Li   print_mode(Mode);
66*bf2c3715SXin Li   std::cout << int(1e-6*n*rep/t.best()) << "\t";
67*bf2c3715SXin Li }
68*bf2c3715SXin Li 
69*bf2c3715SXin Li template<int N> struct print_n;
70*bf2c3715SXin Li template<int M, int N, int K> struct loop_on_m;
71*bf2c3715SXin Li template<int M, int N, int K, typename Scalar, int Mode> struct loop_on_n;
72*bf2c3715SXin Li 
73*bf2c3715SXin Li template<int M, int N, int K>
74*bf2c3715SXin Li struct loop_on_k
75*bf2c3715SXin Li {
runloop_on_k76*bf2c3715SXin Li   static void run()
77*bf2c3715SXin Li   {
78*bf2c3715SXin Li     std::cout << "K=" << K << "\t";
79*bf2c3715SXin Li     print_n<N>::run();
80*bf2c3715SXin Li     std::cout << "\n";
81*bf2c3715SXin Li 
82*bf2c3715SXin Li     loop_on_m<M,N,K>::run();
83*bf2c3715SXin Li     std::cout << "\n\n";
84*bf2c3715SXin Li 
85*bf2c3715SXin Li     loop_on_k<M,N,K+1>::run();
86*bf2c3715SXin Li   }
87*bf2c3715SXin Li };
88*bf2c3715SXin Li 
89*bf2c3715SXin Li template<int M, int N>
runloop_on_k90*bf2c3715SXin Li struct loop_on_k<M,N,END> { static void run(){} };
91*bf2c3715SXin Li 
92*bf2c3715SXin Li 
93*bf2c3715SXin Li template<int M, int N, int K>
94*bf2c3715SXin Li struct loop_on_m
95*bf2c3715SXin Li {
runloop_on_m96*bf2c3715SXin Li   static void run()
97*bf2c3715SXin Li   {
98*bf2c3715SXin Li     std::cout << M << "f\t";
99*bf2c3715SXin Li     loop_on_n<M,N,K,float,CoeffBasedProductMode>::run();
100*bf2c3715SXin Li     std::cout << "\n";
101*bf2c3715SXin Li 
102*bf2c3715SXin Li     std::cout << M << "f\t";
103*bf2c3715SXin Li     loop_on_n<M,N,K,float,-1>::run();
104*bf2c3715SXin Li     std::cout << "\n";
105*bf2c3715SXin Li 
106*bf2c3715SXin Li     loop_on_m<M+1,N,K>::run();
107*bf2c3715SXin Li   }
108*bf2c3715SXin Li };
109*bf2c3715SXin Li 
110*bf2c3715SXin Li template<int N, int K>
runloop_on_m111*bf2c3715SXin Li struct loop_on_m<END,N,K> { static void run(){} };
112*bf2c3715SXin Li 
113*bf2c3715SXin Li template<int M, int N, int K, typename Scalar, int Mode>
114*bf2c3715SXin Li struct loop_on_n
115*bf2c3715SXin Li {
runloop_on_n116*bf2c3715SXin Li   static void run()
117*bf2c3715SXin Li   {
118*bf2c3715SXin Li     bench_prod<M,N,K,Scalar,Mode==-1? alt_prod<M,N,K>::ret : Mode>();
119*bf2c3715SXin Li 
120*bf2c3715SXin Li     loop_on_n<M,N+1,K,Scalar,Mode>::run();
121*bf2c3715SXin Li   }
122*bf2c3715SXin Li };
123*bf2c3715SXin Li 
124*bf2c3715SXin Li template<int M, int K, typename Scalar, int Mode>
runloop_on_n125*bf2c3715SXin Li struct loop_on_n<M,END,K,Scalar,Mode> { static void run(){} };
126*bf2c3715SXin Li 
127*bf2c3715SXin Li template<int N> struct print_n
128*bf2c3715SXin Li {
runprint_n129*bf2c3715SXin Li   static void run()
130*bf2c3715SXin Li   {
131*bf2c3715SXin Li     std::cout << map_size<N>::ret << "\t";
132*bf2c3715SXin Li     print_n<N+1>::run();
133*bf2c3715SXin Li   }
134*bf2c3715SXin Li };
135*bf2c3715SXin Li 
runprint_n136*bf2c3715SXin Li template<> struct print_n<END> { static void run(){} };
137*bf2c3715SXin Li 
main()138*bf2c3715SXin Li int main()
139*bf2c3715SXin Li {
140*bf2c3715SXin Li   loop_on_k<1,1,1>::run();
141*bf2c3715SXin Li 
142*bf2c3715SXin Li   return 0;
143*bf2c3715SXin Li }
144