1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han
15*5f39d1b3SJooyung Han #ifdef __APPLE__
16*5f39d1b3SJooyung Han #include <sys/time.h>
17*5f39d1b3SJooyung Han #endif
18*5f39d1b3SJooyung Han
19*5f39d1b3SJooyung Han #include <cstdint>
20*5f39d1b3SJooyung Han #include <cstdlib>
21*5f39d1b3SJooyung Han #include <ctime>
22*5f39d1b3SJooyung Han #include <iostream>
23*5f39d1b3SJooyung Han #include <map>
24*5f39d1b3SJooyung Han #include <vector>
25*5f39d1b3SJooyung Han #ifdef __APPLE__
26*5f39d1b3SJooyung Han #include <TargetConditionals.h>
27*5f39d1b3SJooyung Han #endif
28*5f39d1b3SJooyung Han
29*5f39d1b3SJooyung Han #include "test.h"
30*5f39d1b3SJooyung Han
31*5f39d1b3SJooyung Han #ifndef GEMMLOWP_TEST_BIT_DEPTH_PARAMS
32*5f39d1b3SJooyung Han #define GEMMLOWP_TEST_BIT_DEPTH_PARAMS DefaultL8R8BitDepthParams
33*5f39d1b3SJooyung Han #endif
34*5f39d1b3SJooyung Han
35*5f39d1b3SJooyung Han #if defined(__arm__) && !defined(GEMMLOWP_NEON)
36*5f39d1b3SJooyung Han #warning "Building without NEON support on ARM, check your compiler setup!"
37*5f39d1b3SJooyung Han #endif
38*5f39d1b3SJooyung Han
39*5f39d1b3SJooyung Han #if defined(__mips) && !defined(GEMMLOWP_MSA)
40*5f39d1b3SJooyung Han #warning "Building without MSA support on MIPS, check your compiler setup!"
41*5f39d1b3SJooyung Han #endif
42*5f39d1b3SJooyung Han
43*5f39d1b3SJooyung Han #if defined(__AVX2__) && !defined(GEMMLOWP_AVX2)
44*5f39d1b3SJooyung Han #warning \
45*5f39d1b3SJooyung Han "Building without AVX2 support on AVX2 enabled machine, check your compiler setup!"
46*5f39d1b3SJooyung Han #endif
47*5f39d1b3SJooyung Han
48*5f39d1b3SJooyung Han #if defined(__SSE4_2__) && !defined(GEMMLOWP_AVX2) && !defined(GEMMLOWP_SSE4)
49*5f39d1b3SJooyung Han #warning \
50*5f39d1b3SJooyung Han "Building without SSE4.2 support on SSE4.2 enabled machine, check your compiler setup!"
51*5f39d1b3SJooyung Han #endif
52*5f39d1b3SJooyung Han
53*5f39d1b3SJooyung Han namespace gemmlowp {
54*5f39d1b3SJooyung Han
55*5f39d1b3SJooyung Han const double min_accurate_duration = 1e-1;
56*5f39d1b3SJooyung Han const std::size_t min_working_set_size = 16 * 1024 * 1024;
57*5f39d1b3SJooyung Han
58*5f39d1b3SJooyung Han struct gemm_t {
59*5f39d1b3SJooyung Han int rows, depth, cols;
gemm_tgemmlowp::gemm_t60*5f39d1b3SJooyung Han gemm_t() : rows(0), depth(0), cols(0) {}
gemm_tgemmlowp::gemm_t61*5f39d1b3SJooyung Han gemm_t(int r, int d, int c) : rows(r), depth(d), cols(c) {}
62*5f39d1b3SJooyung Han };
63*5f39d1b3SJooyung Han
operator <(const gemm_t & a,const gemm_t & b)64*5f39d1b3SJooyung Han bool operator<(const gemm_t& a, const gemm_t& b) {
65*5f39d1b3SJooyung Han return a.rows < b.rows ||
66*5f39d1b3SJooyung Han (a.rows <= b.rows &&
67*5f39d1b3SJooyung Han (a.depth < b.depth || (a.depth <= b.depth && (a.cols < b.cols))));
68*5f39d1b3SJooyung Han }
69*5f39d1b3SJooyung Han
70*5f39d1b3SJooyung Han template <typename LhsType, typename RhsType, typename ResultType>
time_for_gemms(GemmContext * context,const std::vector<gemm_t> & gemms)71*5f39d1b3SJooyung Han double time_for_gemms(GemmContext* context, const std::vector<gemm_t>& gemms) {
72*5f39d1b3SJooyung Han typedef std::uint8_t Scalar;
73*5f39d1b3SJooyung Han
74*5f39d1b3SJooyung Han // set up the matrix pool
75*5f39d1b3SJooyung Han
76*5f39d1b3SJooyung Han std::size_t combined_gemm_sizes = 0;
77*5f39d1b3SJooyung Han for (auto gemm : gemms) {
78*5f39d1b3SJooyung Han int rows = gemm.rows;
79*5f39d1b3SJooyung Han int depth = gemm.depth;
80*5f39d1b3SJooyung Han int cols = gemm.cols;
81*5f39d1b3SJooyung Han combined_gemm_sizes +=
82*5f39d1b3SJooyung Han sizeof(Scalar) * (rows * depth + depth * cols + rows * cols);
83*5f39d1b3SJooyung Han }
84*5f39d1b3SJooyung Han
85*5f39d1b3SJooyung Han const std::size_t pool_size = 1 + min_working_set_size / combined_gemm_sizes;
86*5f39d1b3SJooyung Han
87*5f39d1b3SJooyung Han std::vector<LhsType> lhs(pool_size * gemms.size());
88*5f39d1b3SJooyung Han std::vector<RhsType> rhs(pool_size * gemms.size());
89*5f39d1b3SJooyung Han std::vector<ResultType> result(pool_size * gemms.size());
90*5f39d1b3SJooyung Han
91*5f39d1b3SJooyung Han for (std::size_t i = 0; i < pool_size; i++) {
92*5f39d1b3SJooyung Han for (std::size_t j = 0; j < gemms.size(); j++) {
93*5f39d1b3SJooyung Han int k = i * gemms.size() + j;
94*5f39d1b3SJooyung Han lhs[k].Resize(gemms[j].rows, gemms[j].depth);
95*5f39d1b3SJooyung Han MakeConstant(&lhs[k], 0);
96*5f39d1b3SJooyung Han rhs[k].Resize(gemms[j].depth, gemms[j].cols);
97*5f39d1b3SJooyung Han MakeConstant(&rhs[k], 0);
98*5f39d1b3SJooyung Han result[k].Resize(gemms[j].rows, gemms[j].cols);
99*5f39d1b3SJooyung Han MakeConstant(&result[k], 0);
100*5f39d1b3SJooyung Han }
101*5f39d1b3SJooyung Han }
102*5f39d1b3SJooyung Han
103*5f39d1b3SJooyung Han // main benchmark loop
104*5f39d1b3SJooyung Han
105*5f39d1b3SJooyung Han int iters_at_a_time = 1;
106*5f39d1b3SJooyung Han float time_per_iter = 0.0f;
107*5f39d1b3SJooyung Han std::size_t pool_index = 0;
108*5f39d1b3SJooyung Han
109*5f39d1b3SJooyung Han while (true) {
110*5f39d1b3SJooyung Han double starttime = real_time_in_seconds();
111*5f39d1b3SJooyung Han for (int i = 0; i < iters_at_a_time; i++) {
112*5f39d1b3SJooyung Han for (size_t j = 0; j < gemms.size(); j++) {
113*5f39d1b3SJooyung Han size_t k = pool_index * gemms.size() + j;
114*5f39d1b3SJooyung Han Gemm<std::uint8_t, GEMMLOWP_TEST_BIT_DEPTH_PARAMS>(
115*5f39d1b3SJooyung Han context, lhs[k].const_map(), rhs[k].const_map(), &result[k].map(),
116*5f39d1b3SJooyung Han -75, -91, 74980, 123, 20);
117*5f39d1b3SJooyung Han }
118*5f39d1b3SJooyung Han pool_index++;
119*5f39d1b3SJooyung Han if (pool_index == pool_size) {
120*5f39d1b3SJooyung Han pool_index = 0;
121*5f39d1b3SJooyung Han }
122*5f39d1b3SJooyung Han }
123*5f39d1b3SJooyung Han double endtime = real_time_in_seconds();
124*5f39d1b3SJooyung Han
125*5f39d1b3SJooyung Han const float timing = static_cast<float>(endtime - starttime);
126*5f39d1b3SJooyung Han
127*5f39d1b3SJooyung Han if (timing >= min_accurate_duration) {
128*5f39d1b3SJooyung Han time_per_iter = timing / iters_at_a_time;
129*5f39d1b3SJooyung Han break;
130*5f39d1b3SJooyung Han }
131*5f39d1b3SJooyung Han
132*5f39d1b3SJooyung Han iters_at_a_time *= 2;
133*5f39d1b3SJooyung Han }
134*5f39d1b3SJooyung Han
135*5f39d1b3SJooyung Han return time_per_iter;
136*5f39d1b3SJooyung Han }
137*5f39d1b3SJooyung Han
138*5f39d1b3SJooyung Han template <typename LhsType, typename RhsType, typename ResultType>
gflops_for_gemms(GemmContext * context,const std::vector<gemm_t> & gemms)139*5f39d1b3SJooyung Han double gflops_for_gemms(GemmContext* context,
140*5f39d1b3SJooyung Han const std::vector<gemm_t>& gemms) {
141*5f39d1b3SJooyung Han const double time_per_iter =
142*5f39d1b3SJooyung Han time_for_gemms<LhsType, RhsType, ResultType>(context, gemms);
143*5f39d1b3SJooyung Han double ops = 0;
144*5f39d1b3SJooyung Han for (auto gemm : gemms) {
145*5f39d1b3SJooyung Han ops += 2.0 * gemm.rows * gemm.depth * gemm.cols;
146*5f39d1b3SJooyung Han }
147*5f39d1b3SJooyung Han return 1e-9 * ops / time_per_iter;
148*5f39d1b3SJooyung Han }
149*5f39d1b3SJooyung Han
benchmark(GemmContext * context)150*5f39d1b3SJooyung Han void benchmark(GemmContext* context) {
151*5f39d1b3SJooyung Han std::map<gemm_t, std::vector<double>> benchmark_results;
152*5f39d1b3SJooyung Han
153*5f39d1b3SJooyung Han std::vector<gemm_t> benchmark_gemms;
154*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(10, 10, 10);
155*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(20, 20, 20);
156*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(30, 30, 30);
157*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(40, 40, 40);
158*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(50, 50, 50);
159*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(60, 60, 60);
160*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(64, 256, 147);
161*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(100, 100, 1);
162*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(100, 100, 100);
163*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(100, 1000, 100);
164*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(1000, 1000, 1);
165*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(1000, 1000, 10);
166*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(1000, 1000, 100);
167*5f39d1b3SJooyung Han benchmark_gemms.emplace_back(1000, 1000, 1000);
168*5f39d1b3SJooyung Han
169*5f39d1b3SJooyung Han const int repeat = 2;
170*5f39d1b3SJooyung Han
171*5f39d1b3SJooyung Han typedef Matrix<std::uint8_t, MapOrder::RowMajor> LhsType;
172*5f39d1b3SJooyung Han typedef Matrix<std::uint8_t, MapOrder::ColMajor> RhsType;
173*5f39d1b3SJooyung Han typedef Matrix<std::uint8_t, MapOrder::ColMajor> ResultType;
174*5f39d1b3SJooyung Han
175*5f39d1b3SJooyung Han #ifdef GEMMLOWP_TEST_PROFILE
176*5f39d1b3SJooyung Han gemmlowp::RegisterCurrentThreadForProfiling();
177*5f39d1b3SJooyung Han gemmlowp::StartProfiling();
178*5f39d1b3SJooyung Han #endif
179*5f39d1b3SJooyung Han
180*5f39d1b3SJooyung Han // We don't record the first repetition, it's just warm-up.
181*5f39d1b3SJooyung Han for (int r = 0; r < repeat + 1; r++) {
182*5f39d1b3SJooyung Han std::cout << "repetition " << r + 1 << "/" << repeat + 1 << "...\r"
183*5f39d1b3SJooyung Han << std::flush;
184*5f39d1b3SJooyung Han for (auto gemm : benchmark_gemms) {
185*5f39d1b3SJooyung Han double gflops = 0;
186*5f39d1b3SJooyung Han std::vector<gemm_t> unique_gemm;
187*5f39d1b3SJooyung Han unique_gemm.push_back(gemm);
188*5f39d1b3SJooyung Han gflops =
189*5f39d1b3SJooyung Han gflops_for_gemms<LhsType, RhsType, ResultType>(context, unique_gemm);
190*5f39d1b3SJooyung Han if (r > 0) {
191*5f39d1b3SJooyung Han benchmark_results[gemm].emplace_back(gflops);
192*5f39d1b3SJooyung Han }
193*5f39d1b3SJooyung Han }
194*5f39d1b3SJooyung Han }
195*5f39d1b3SJooyung Han
196*5f39d1b3SJooyung Han #ifdef GEMMLOWP_TEST_PROFILE
197*5f39d1b3SJooyung Han gemmlowp::FinishProfiling();
198*5f39d1b3SJooyung Han #endif
199*5f39d1b3SJooyung Han
200*5f39d1b3SJooyung Han std::cout << " \r"
201*5f39d1b3SJooyung Han << std::flush;
202*5f39d1b3SJooyung Han
203*5f39d1b3SJooyung Han std::cout.precision(4);
204*5f39d1b3SJooyung Han
205*5f39d1b3SJooyung Han for (auto b : benchmark_results) {
206*5f39d1b3SJooyung Han sort(b.second.begin(), b.second.end());
207*5f39d1b3SJooyung Han std::cout << b.first.rows << "x" << b.first.depth << "x" << b.first.cols
208*5f39d1b3SJooyung Han << " : " << b.second.back() << " GFlops/s" << std::endl;
209*5f39d1b3SJooyung Han }
210*5f39d1b3SJooyung Han std::cout << std::endl;
211*5f39d1b3SJooyung Han }
212*5f39d1b3SJooyung Han
benchmark_gemm_sizes(GemmContext * context,const std::vector<gemm_t> & gemms,double mintime)213*5f39d1b3SJooyung Han void benchmark_gemm_sizes(GemmContext* context,
214*5f39d1b3SJooyung Han const std::vector<gemm_t>& gemms, double mintime) {
215*5f39d1b3SJooyung Han typedef Matrix<std::uint8_t, MapOrder::RowMajor> LhsType;
216*5f39d1b3SJooyung Han typedef Matrix<std::uint8_t, MapOrder::ColMajor> RhsType;
217*5f39d1b3SJooyung Han typedef Matrix<std::uint8_t, MapOrder::ColMajor> ResultType;
218*5f39d1b3SJooyung Han
219*5f39d1b3SJooyung Han std::vector<float> gemm_times;
220*5f39d1b3SJooyung Han std::cout << "running for " << mintime << " seconds..." << std::endl;
221*5f39d1b3SJooyung Han
222*5f39d1b3SJooyung Han #ifdef GEMMLOWP_TEST_PROFILE
223*5f39d1b3SJooyung Han gemmlowp::RegisterCurrentThreadForProfiling();
224*5f39d1b3SJooyung Han gemmlowp::StartProfiling();
225*5f39d1b3SJooyung Han #endif
226*5f39d1b3SJooyung Han
227*5f39d1b3SJooyung Han double starttime = real_time_in_seconds();
228*5f39d1b3SJooyung Han while (real_time_in_seconds() < starttime + mintime) {
229*5f39d1b3SJooyung Han gemm_times.push_back(
230*5f39d1b3SJooyung Han time_for_gemms<LhsType, RhsType, ResultType>(context, gemms));
231*5f39d1b3SJooyung Han }
232*5f39d1b3SJooyung Han
233*5f39d1b3SJooyung Han #ifdef GEMMLOWP_TEST_PROFILE
234*5f39d1b3SJooyung Han gemmlowp::FinishProfiling();
235*5f39d1b3SJooyung Han #endif
236*5f39d1b3SJooyung Han
237*5f39d1b3SJooyung Han std::sort(gemm_times.begin(), gemm_times.end());
238*5f39d1b3SJooyung Han
239*5f39d1b3SJooyung Han double sum_gemm_times = 0;
240*5f39d1b3SJooyung Han double sum_gemm_times_trimmed = 0;
241*5f39d1b3SJooyung Han int count_gemm_times_trimmed = 0;
242*5f39d1b3SJooyung Han const float trim_ratio = 0.25;
243*5f39d1b3SJooyung Han const size_t count_trimmed = gemm_times.size() * trim_ratio;
244*5f39d1b3SJooyung Han double sum_gemm_times_best = 0;
245*5f39d1b3SJooyung Han int count_gemm_times_best = 0;
246*5f39d1b3SJooyung Han const float best_ratio = 0.1;
247*5f39d1b3SJooyung Han const size_t count_best = gemm_times.size() * best_ratio;
248*5f39d1b3SJooyung Han
249*5f39d1b3SJooyung Han for (size_t i = 0; i < gemm_times.size(); i++) {
250*5f39d1b3SJooyung Han sum_gemm_times += gemm_times[i];
251*5f39d1b3SJooyung Han if (i >= count_trimmed && i < gemm_times.size() - count_trimmed) {
252*5f39d1b3SJooyung Han sum_gemm_times_trimmed += gemm_times[i];
253*5f39d1b3SJooyung Han count_gemm_times_trimmed++;
254*5f39d1b3SJooyung Han }
255*5f39d1b3SJooyung Han if (i < count_best) {
256*5f39d1b3SJooyung Han sum_gemm_times_best += gemm_times[i];
257*5f39d1b3SJooyung Han count_gemm_times_best++;
258*5f39d1b3SJooyung Han }
259*5f39d1b3SJooyung Han }
260*5f39d1b3SJooyung Han
261*5f39d1b3SJooyung Han const double min_latency = gemm_times.front();
262*5f39d1b3SJooyung Han const double max_latency = gemm_times.back();
263*5f39d1b3SJooyung Han const double mean_latency = sum_gemm_times / gemm_times.size();
264*5f39d1b3SJooyung Han const double trimmed_mean_latency =
265*5f39d1b3SJooyung Han sum_gemm_times_trimmed / count_gemm_times_trimmed;
266*5f39d1b3SJooyung Han const double best_mean_latency = sum_gemm_times_best / count_gemm_times_best;
267*5f39d1b3SJooyung Han
268*5f39d1b3SJooyung Han std::cout << "Graph latency (over " << gemm_times.size()
269*5f39d1b3SJooyung Han << " iterations):" << std::endl;
270*5f39d1b3SJooyung Han std::cout << " Best: " << min_latency << "s" << std::endl;
271*5f39d1b3SJooyung Han std::cout << " Worst: " << max_latency << "s" << std::endl;
272*5f39d1b3SJooyung Han std::cout << " Mean: " << mean_latency << "s" << std::endl;
273*5f39d1b3SJooyung Han std::cout << " " << 100 * trim_ratio
274*5f39d1b3SJooyung Han << "% trimmed mean: " << trimmed_mean_latency << "s" << std::endl;
275*5f39d1b3SJooyung Han std::cout << " Mean of " << 100 * best_ratio
276*5f39d1b3SJooyung Han << "% best: " << best_mean_latency << "s" << std::endl;
277*5f39d1b3SJooyung Han }
278*5f39d1b3SJooyung Han
benchmark_googlenet(GemmContext * context)279*5f39d1b3SJooyung Han void benchmark_googlenet(GemmContext* context) {
280*5f39d1b3SJooyung Han // These are the m, n, k sizes for a typical GoogLeNet.
281*5f39d1b3SJooyung Han const int googlenet_gemm_sizes[] = {
282*5f39d1b3SJooyung Han 12544, 64, 147, 3136, 64, 64, 3136, 192, 576, 784, 64, 192,
283*5f39d1b3SJooyung Han 784, 96, 192, 784, 128, 864, 784, 16, 192, 784, 32, 400,
284*5f39d1b3SJooyung Han 784, 32, 192, 784, 128, 256, 784, 128, 256, 784, 192, 1152,
285*5f39d1b3SJooyung Han 784, 32, 256, 784, 96, 800, 784, 64, 256, 196, 192, 480,
286*5f39d1b3SJooyung Han 196, 96, 480, 196, 204, 864, 196, 16, 480, 196, 48, 400,
287*5f39d1b3SJooyung Han 196, 64, 480, 196, 160, 508, 196, 112, 508, 196, 224, 1008,
288*5f39d1b3SJooyung Han 196, 24, 508, 196, 64, 600, 196, 64, 508, 196, 128, 512,
289*5f39d1b3SJooyung Han 196, 128, 512, 196, 256, 1152, 196, 24, 512, 196, 64, 600,
290*5f39d1b3SJooyung Han 196, 64, 512, 196, 112, 512, 196, 144, 512, 196, 288, 1296,
291*5f39d1b3SJooyung Han 196, 32, 512, 196, 64, 800, 196, 64, 512, 196, 256, 528,
292*5f39d1b3SJooyung Han 196, 160, 528, 196, 320, 1440, 196, 32, 528, 196, 128, 800,
293*5f39d1b3SJooyung Han 196, 128, 528, 49, 256, 832, 49, 160, 832, 49, 320, 1440,
294*5f39d1b3SJooyung Han 49, 48, 832, 49, 128, 1200, 49, 128, 832, 49, 384, 832,
295*5f39d1b3SJooyung Han 49, 192, 832, 49, 384, 1728, 49, 48, 832, 49, 128, 1200,
296*5f39d1b3SJooyung Han 49, 128, 832, 16, 128, 508, 1, 1024, 2048, 1, 1008, 1024,
297*5f39d1b3SJooyung Han 16, 128, 528, 1, 1024, 2048, 1, 1008, 1024, 1, 1008, 1024,
298*5f39d1b3SJooyung Han };
299*5f39d1b3SJooyung Han assert(sizeof(googlenet_gemm_sizes) % (3 * sizeof(googlenet_gemm_sizes[0])) ==
300*5f39d1b3SJooyung Han 0);
301*5f39d1b3SJooyung Han const std::size_t num_googlenet_gemms =
302*5f39d1b3SJooyung Han sizeof(googlenet_gemm_sizes) / (3 * sizeof(googlenet_gemm_sizes[0]));
303*5f39d1b3SJooyung Han
304*5f39d1b3SJooyung Han std::vector<gemm_t> googlenet_gemms(num_googlenet_gemms);
305*5f39d1b3SJooyung Han for (std::size_t i = 0; i < num_googlenet_gemms; i++) {
306*5f39d1b3SJooyung Han googlenet_gemms[i].rows = googlenet_gemm_sizes[3 * i + 1];
307*5f39d1b3SJooyung Han googlenet_gemms[i].depth = googlenet_gemm_sizes[3 * i + 2];
308*5f39d1b3SJooyung Han googlenet_gemms[i].cols = googlenet_gemm_sizes[3 * i + 0];
309*5f39d1b3SJooyung Han }
310*5f39d1b3SJooyung Han
311*5f39d1b3SJooyung Han const double mintime = 20.0;
312*5f39d1b3SJooyung Han benchmark_gemm_sizes(context, googlenet_gemms, mintime);
313*5f39d1b3SJooyung Han }
314*5f39d1b3SJooyung Han
benchmark_small_model(GemmContext * context)315*5f39d1b3SJooyung Han void benchmark_small_model(GemmContext* context) {
316*5f39d1b3SJooyung Han // These are the m, n, k sizes for a small model with large batches.
317*5f39d1b3SJooyung Han const int small_model_gemm_sizes[] = {
318*5f39d1b3SJooyung Han 29232, 16, 25, 7308, 6, 400, 203, 3002, 216,
319*5f39d1b3SJooyung Han };
320*5f39d1b3SJooyung Han assert(sizeof(small_model_gemm_sizes) %
321*5f39d1b3SJooyung Han (3 * sizeof(small_model_gemm_sizes[0])) ==
322*5f39d1b3SJooyung Han 0);
323*5f39d1b3SJooyung Han const std::size_t num_small_model_gemms =
324*5f39d1b3SJooyung Han sizeof(small_model_gemm_sizes) / (3 * sizeof(small_model_gemm_sizes[0]));
325*5f39d1b3SJooyung Han
326*5f39d1b3SJooyung Han std::vector<gemm_t> small_model_gemms(num_small_model_gemms);
327*5f39d1b3SJooyung Han for (std::size_t i = 0; i < num_small_model_gemms; i++) {
328*5f39d1b3SJooyung Han small_model_gemms[i].rows = small_model_gemm_sizes[3 * i + 1];
329*5f39d1b3SJooyung Han small_model_gemms[i].depth = small_model_gemm_sizes[3 * i + 2];
330*5f39d1b3SJooyung Han small_model_gemms[i].cols = small_model_gemm_sizes[3 * i + 0];
331*5f39d1b3SJooyung Han }
332*5f39d1b3SJooyung Han
333*5f39d1b3SJooyung Han const double mintime = 10.0;
334*5f39d1b3SJooyung Han benchmark_gemm_sizes(context, small_model_gemms, mintime);
335*5f39d1b3SJooyung Han }
336*5f39d1b3SJooyung Han
benchmark_all()337*5f39d1b3SJooyung Han void benchmark_all() {
338*5f39d1b3SJooyung Han {
339*5f39d1b3SJooyung Han gemmlowp::GemmContext context;
340*5f39d1b3SJooyung Han std::cout << "Benchmarking small model GEMMs..." << std::endl;
341*5f39d1b3SJooyung Han gemmlowp::benchmark_small_model(&context);
342*5f39d1b3SJooyung Han }
343*5f39d1b3SJooyung Han
344*5f39d1b3SJooyung Han {
345*5f39d1b3SJooyung Han gemmlowp::GemmContext context;
346*5f39d1b3SJooyung Han std::cout << "Benchmarking typical GoogLeNet GEMMs..." << std::endl;
347*5f39d1b3SJooyung Han gemmlowp::benchmark_googlenet(&context);
348*5f39d1b3SJooyung Han }
349*5f39d1b3SJooyung Han
350*5f39d1b3SJooyung Han {
351*5f39d1b3SJooyung Han gemmlowp::GemmContext context;
352*5f39d1b3SJooyung Han context.set_max_num_threads(0);
353*5f39d1b3SJooyung Han std::cout << "Benchmarking multi-threaded mode..." << std::endl;
354*5f39d1b3SJooyung Han gemmlowp::benchmark(&context);
355*5f39d1b3SJooyung Han }
356*5f39d1b3SJooyung Han
357*5f39d1b3SJooyung Han {
358*5f39d1b3SJooyung Han gemmlowp::GemmContext context;
359*5f39d1b3SJooyung Han context.set_max_num_threads(1);
360*5f39d1b3SJooyung Han std::cout << "Benchmarking single-threaded mode..." << std::endl;
361*5f39d1b3SJooyung Han gemmlowp::benchmark(&context);
362*5f39d1b3SJooyung Han }
363*5f39d1b3SJooyung Han }
364*5f39d1b3SJooyung Han
365*5f39d1b3SJooyung Han } // end namespace gemmlowp
366*5f39d1b3SJooyung Han
367*5f39d1b3SJooyung Han // For iOS, we need to define our own main(), so skip it here.
368*5f39d1b3SJooyung Han #if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
main()369*5f39d1b3SJooyung Han int main() { gemmlowp::benchmark_all(); }
370*5f39d1b3SJooyung Han #endif
371