xref: /aosp_15_r20/external/gemmlowp/test/test.cc (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han #include "test.h"
16*5f39d1b3SJooyung Han 
17*5f39d1b3SJooyung Han #include <array>
18*5f39d1b3SJooyung Han #include <cstdint>
19*5f39d1b3SJooyung Han #include <cstdlib>
20*5f39d1b3SJooyung Han #include <ctime>
21*5f39d1b3SJooyung Han #include <iostream>
22*5f39d1b3SJooyung Han #include <memory>
23*5f39d1b3SJooyung Han #include <string>
24*5f39d1b3SJooyung Han #include <vector>
25*5f39d1b3SJooyung Han #ifdef __APPLE__
26*5f39d1b3SJooyung Han #include <TargetConditionals.h>
27*5f39d1b3SJooyung Han #endif
28*5f39d1b3SJooyung Han 
29*5f39d1b3SJooyung Han #include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
30*5f39d1b3SJooyung Han #include "../internal/kernel_reference.h"
31*5f39d1b3SJooyung Han #include "test_data.h"
32*5f39d1b3SJooyung Han 
33*5f39d1b3SJooyung Han namespace gemmlowp {
34*5f39d1b3SJooyung Han 
ReferenceEightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc)35*5f39d1b3SJooyung Han void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
36*5f39d1b3SJooyung Han                               bool transpose_c, int m, int n, int k,
37*5f39d1b3SJooyung Han                               const std::uint8_t* a, std::int32_t a_offset,
38*5f39d1b3SJooyung Han                               int lda, const std::uint8_t* b,
39*5f39d1b3SJooyung Han                               std::int32_t b_offset, int ldb, std::uint8_t* c,
40*5f39d1b3SJooyung Han                               std::int32_t c_offset, std::int32_t c_mult_int,
41*5f39d1b3SJooyung Han                               std::int32_t c_shift, int ldc) {
42*5f39d1b3SJooyung Han   ScopedProfilingLabel("ReferenceEightBitIntGemm");
43*5f39d1b3SJooyung Han   assert((c_shift >= 0) && (c_shift <= 32));
44*5f39d1b3SJooyung Han 
45*5f39d1b3SJooyung Han   assert(a != nullptr);
46*5f39d1b3SJooyung Han   assert(b != nullptr);
47*5f39d1b3SJooyung Han   assert(c != nullptr);
48*5f39d1b3SJooyung Han 
49*5f39d1b3SJooyung Han   int a_i_stride;
50*5f39d1b3SJooyung Han   int a_l_stride;
51*5f39d1b3SJooyung Han   if (transpose_a) {
52*5f39d1b3SJooyung Han     a_i_stride = lda;
53*5f39d1b3SJooyung Han     a_l_stride = 1;
54*5f39d1b3SJooyung Han   } else {
55*5f39d1b3SJooyung Han     a_i_stride = 1;
56*5f39d1b3SJooyung Han     a_l_stride = lda;
57*5f39d1b3SJooyung Han   }
58*5f39d1b3SJooyung Han   int b_j_stride;
59*5f39d1b3SJooyung Han   int b_l_stride;
60*5f39d1b3SJooyung Han   if (transpose_b) {
61*5f39d1b3SJooyung Han     b_j_stride = 1;
62*5f39d1b3SJooyung Han     b_l_stride = ldb;
63*5f39d1b3SJooyung Han   } else {
64*5f39d1b3SJooyung Han     b_j_stride = ldb;
65*5f39d1b3SJooyung Han     b_l_stride = 1;
66*5f39d1b3SJooyung Han   }
67*5f39d1b3SJooyung Han   int c_i_stride;
68*5f39d1b3SJooyung Han   int c_j_stride;
69*5f39d1b3SJooyung Han   if (transpose_c) {
70*5f39d1b3SJooyung Han     c_i_stride = ldc;
71*5f39d1b3SJooyung Han     c_j_stride = 1;
72*5f39d1b3SJooyung Han   } else {
73*5f39d1b3SJooyung Han     c_i_stride = 1;
74*5f39d1b3SJooyung Han     c_j_stride = ldc;
75*5f39d1b3SJooyung Han   }
76*5f39d1b3SJooyung Han   int i, j, l;
77*5f39d1b3SJooyung Han 
78*5f39d1b3SJooyung Han   const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));
79*5f39d1b3SJooyung Han 
80*5f39d1b3SJooyung Han   for (j = 0; j < n; j++) {
81*5f39d1b3SJooyung Han     for (i = 0; i < m; i++) {
82*5f39d1b3SJooyung Han       std::int32_t total = 0;
83*5f39d1b3SJooyung Han       for (l = 0; l < k; l++) {
84*5f39d1b3SJooyung Han         const int a_index = i * a_i_stride + l * a_l_stride;
85*5f39d1b3SJooyung Han         const std::uint8_t a_as_byte = a[a_index];
86*5f39d1b3SJooyung Han         const std::int32_t a_as_int =
87*5f39d1b3SJooyung Han             static_cast<std::int32_t>(a_as_byte) + a_offset;
88*5f39d1b3SJooyung Han         const int b_index = j * b_j_stride + l * b_l_stride;
89*5f39d1b3SJooyung Han         const std::uint8_t b_as_byte = b[b_index];
90*5f39d1b3SJooyung Han         const std::int32_t b_as_int =
91*5f39d1b3SJooyung Han             static_cast<std::int32_t>(b_as_byte) + b_offset;
92*5f39d1b3SJooyung Han         const std::int32_t mult_as_int = a_as_int * b_as_int;
93*5f39d1b3SJooyung Han         total += mult_as_int;
94*5f39d1b3SJooyung Han       }
95*5f39d1b3SJooyung Han       std::int32_t output =
96*5f39d1b3SJooyung Han           (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
97*5f39d1b3SJooyung Han       if (output > 255) {
98*5f39d1b3SJooyung Han         output = 255;
99*5f39d1b3SJooyung Han       }
100*5f39d1b3SJooyung Han       if (output < 0) {
101*5f39d1b3SJooyung Han         output = 0;
102*5f39d1b3SJooyung Han       }
103*5f39d1b3SJooyung Han       const int c_index = i * c_i_stride + j * c_j_stride;
104*5f39d1b3SJooyung Han       c[c_index] = static_cast<std::uint8_t>(output);
105*5f39d1b3SJooyung Han     }
106*5f39d1b3SJooyung Han   }
107*5f39d1b3SJooyung Han }
108*5f39d1b3SJooyung Han 
109*5f39d1b3SJooyung Han typedef VectorMap<const std::int32_t, VectorShape::Col> OffsetColMap;
110*5f39d1b3SJooyung Han typedef VectorMap<const std::int32_t, VectorShape::Row> OffsetRowMap;
111*5f39d1b3SJooyung Han typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
112*5f39d1b3SJooyung Han typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;
113*5f39d1b3SJooyung Han 
// The *GemmWrapper structs wrap the various Gemm functions in a uniform
// interface, so that the same testing code can be used to test all of them.
116*5f39d1b3SJooyung Han 
117*5f39d1b3SJooyung Han template <typename Kernel, typename Scalar, typename tBitDepthParams>
118*5f39d1b3SJooyung Han struct SingleThreadGemmWrapper {
119*5f39d1b3SJooyung Han   typedef tBitDepthParams BitDepthParams;
120*5f39d1b3SJooyung Han 
Namegemmlowp::SingleThreadGemmWrapper121*5f39d1b3SJooyung Han   static const char* Name() {
122*5f39d1b3SJooyung Han     static char buf[256];
123*5f39d1b3SJooyung Han     snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
124*5f39d1b3SJooyung Han     return buf;
125*5f39d1b3SJooyung Han   }
126*5f39d1b3SJooyung Han 
127*5f39d1b3SJooyung Han   typedef SingleThreadGemmContext Context;
128*5f39d1b3SJooyung Han 
129*5f39d1b3SJooyung Han   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::SingleThreadGemmWrapper130*5f39d1b3SJooyung Han   static bool Gemm(Context* context,
131*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, LhsOrder>& lhs,
132*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, RhsOrder>& rhs,
133*5f39d1b3SJooyung Han                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
134*5f39d1b3SJooyung Han                    int rhs_offset, int result_offset, int result_mult_int,
135*5f39d1b3SJooyung Han                    int result_shift) {
136*5f39d1b3SJooyung Han     ScopedProfilingLabel("SingleThreadGemmWrapper::Gemm");
137*5f39d1b3SJooyung Han     const int rows = lhs.rows();
138*5f39d1b3SJooyung Han     const int cols = rhs.cols();
139*5f39d1b3SJooyung Han     if (rows < cols) {
140*5f39d1b3SJooyung Han       // SingleThreadGemm is never called with rows < cols.
141*5f39d1b3SJooyung Han       // That case is handled earlier.
142*5f39d1b3SJooyung Han       return false;
143*5f39d1b3SJooyung Han     }
144*5f39d1b3SJooyung Han     const OffsetColDup lhs_offset_vector(lhs_offset, rows);
145*5f39d1b3SJooyung Han     const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
146*5f39d1b3SJooyung Han     SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
147*5f39d1b3SJooyung Han                      LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
148*5f39d1b3SJooyung Han                      OffsetRowDup>(
149*5f39d1b3SJooyung Han         context, Kernel(), lhs, rhs, result, lhs_offset_vector,
150*5f39d1b3SJooyung Han         rhs_offset_vector,
151*5f39d1b3SJooyung Han         MakeStandardOutputPipeline(result_offset, result_mult_int,
152*5f39d1b3SJooyung Han                                    result_shift));
153*5f39d1b3SJooyung Han     return true;
154*5f39d1b3SJooyung Han   }
155*5f39d1b3SJooyung Han };
156*5f39d1b3SJooyung Han 
157*5f39d1b3SJooyung Han template <typename Kernel, typename Scalar, typename tBitDepthParams>
158*5f39d1b3SJooyung Han struct MultiThreadGemmWrapper {
159*5f39d1b3SJooyung Han   typedef tBitDepthParams BitDepthParams;
160*5f39d1b3SJooyung Han 
Namegemmlowp::MultiThreadGemmWrapper161*5f39d1b3SJooyung Han   static const char* Name() {
162*5f39d1b3SJooyung Han     static char buf[256];
163*5f39d1b3SJooyung Han     snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
164*5f39d1b3SJooyung Han     return buf;
165*5f39d1b3SJooyung Han   }
166*5f39d1b3SJooyung Han 
167*5f39d1b3SJooyung Han   typedef MultiThreadGemmContext Context;
168*5f39d1b3SJooyung Han 
169*5f39d1b3SJooyung Han   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::MultiThreadGemmWrapper170*5f39d1b3SJooyung Han   static bool Gemm(Context* context,
171*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, LhsOrder>& lhs,
172*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, RhsOrder>& rhs,
173*5f39d1b3SJooyung Han                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
174*5f39d1b3SJooyung Han                    int rhs_offset, int result_offset, int result_mult_int,
175*5f39d1b3SJooyung Han                    int result_shift) {
176*5f39d1b3SJooyung Han     ScopedProfilingLabel("MultiThreadGemmWrapper::Gemm");
177*5f39d1b3SJooyung Han     context->set_max_num_threads(0);
178*5f39d1b3SJooyung Han     const int rows = lhs.rows();
179*5f39d1b3SJooyung Han     const int cols = rhs.cols();
180*5f39d1b3SJooyung Han     if (rows < cols) {
181*5f39d1b3SJooyung Han       // SingleThreadGemm is never called with rows < cols.
182*5f39d1b3SJooyung Han       // That case is handled earlier.
183*5f39d1b3SJooyung Han       return false;
184*5f39d1b3SJooyung Han     }
185*5f39d1b3SJooyung Han     const OffsetColDup lhs_offset_vector(lhs_offset, rows);
186*5f39d1b3SJooyung Han     const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
187*5f39d1b3SJooyung Han     MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
188*5f39d1b3SJooyung Han                     LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
189*5f39d1b3SJooyung Han                     OffsetRowDup>(
190*5f39d1b3SJooyung Han         context, Kernel(), lhs, rhs, result, lhs_offset_vector,
191*5f39d1b3SJooyung Han         rhs_offset_vector,
192*5f39d1b3SJooyung Han         MakeStandardOutputPipeline(result_offset, result_mult_int,
193*5f39d1b3SJooyung Han                                    result_shift));
194*5f39d1b3SJooyung Han     return true;
195*5f39d1b3SJooyung Han   }
196*5f39d1b3SJooyung Han };
197*5f39d1b3SJooyung Han 
198*5f39d1b3SJooyung Han template <typename Scalar, typename tBitDepthParams>
199*5f39d1b3SJooyung Han struct PublicGemmWrapper {
200*5f39d1b3SJooyung Han   typedef tBitDepthParams BitDepthParams;
201*5f39d1b3SJooyung Han 
Namegemmlowp::PublicGemmWrapper202*5f39d1b3SJooyung Han   static const char* Name() { return "public Gemm"; }
203*5f39d1b3SJooyung Han 
204*5f39d1b3SJooyung Han   typedef GemmContext Context;
205*5f39d1b3SJooyung Han 
206*5f39d1b3SJooyung Han   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::PublicGemmWrapper207*5f39d1b3SJooyung Han   static bool Gemm(Context* context,
208*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, LhsOrder>& lhs,
209*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, RhsOrder>& rhs,
210*5f39d1b3SJooyung Han                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
211*5f39d1b3SJooyung Han                    int rhs_offset, int result_offset, int result_mult_int,
212*5f39d1b3SJooyung Han                    int result_shift) {
213*5f39d1b3SJooyung Han     ScopedProfilingLabel("PublicGemmWrapper::Gemm");
214*5f39d1b3SJooyung Han     gemmlowp::Gemm<std::uint8_t, BitDepthParams, LhsOrder, RhsOrder,
215*5f39d1b3SJooyung Han                    ResultOrder>(context, lhs, rhs, result, lhs_offset,
216*5f39d1b3SJooyung Han                                 rhs_offset, result_offset, result_mult_int,
217*5f39d1b3SJooyung Han                                 result_shift);
218*5f39d1b3SJooyung Han     return true;
219*5f39d1b3SJooyung Han   }
220*5f39d1b3SJooyung Han };
221*5f39d1b3SJooyung Han 
222*5f39d1b3SJooyung Han template <eight_bit_int_gemm::BitDepthSetting BitDepth>
223*5f39d1b3SJooyung Han struct BitDepthParamsForSettings {};
224*5f39d1b3SJooyung Han 
225*5f39d1b3SJooyung Han template <>
226*5f39d1b3SJooyung Han struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
227*5f39d1b3SJooyung Han     : DefaultL8R8BitDepthParams {};
228*5f39d1b3SJooyung Han 
229*5f39d1b3SJooyung Han template <>
230*5f39d1b3SJooyung Han struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
231*5f39d1b3SJooyung Han     : DefaultL7R5BitDepthParams {};
232*5f39d1b3SJooyung Han 
233*5f39d1b3SJooyung Han template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
234*5f39d1b3SJooyung Han struct EightBitIntGemmWrapper {
235*5f39d1b3SJooyung Han   typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;
236*5f39d1b3SJooyung Han 
Namegemmlowp::EightBitIntGemmWrapper237*5f39d1b3SJooyung Han   static const char* Name() { return "EightBitIntGemm"; }
238*5f39d1b3SJooyung Han 
239*5f39d1b3SJooyung Han   typedef void Context;
240*5f39d1b3SJooyung Han 
241*5f39d1b3SJooyung Han   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::EightBitIntGemmWrapper242*5f39d1b3SJooyung Han   static bool Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
243*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, RhsOrder>& rhs,
244*5f39d1b3SJooyung Han                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
245*5f39d1b3SJooyung Han                    int rhs_offset, int result_offset, int result_mult_int,
246*5f39d1b3SJooyung Han                    int result_shift) {
247*5f39d1b3SJooyung Han     ScopedProfilingLabel("EightBitIntGemmWrapper::Gemm");
248*5f39d1b3SJooyung Han     const bool transpose_c = ResultOrder == MapOrder::RowMajor;
249*5f39d1b3SJooyung Han     const bool transpose_a = LhsOrder == MapOrder::RowMajor;
250*5f39d1b3SJooyung Han     const bool transpose_b = RhsOrder == MapOrder::RowMajor;
251*5f39d1b3SJooyung Han     eight_bit_int_gemm::EightBitIntGemm(
252*5f39d1b3SJooyung Han         transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
253*5f39d1b3SJooyung Han         lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
254*5f39d1b3SJooyung Han         rhs_offset, rhs.stride(), result->data(), result_offset,
255*5f39d1b3SJooyung Han         result_mult_int, result_shift, result->stride(), BitDepth);
256*5f39d1b3SJooyung Han     return true;
257*5f39d1b3SJooyung Han   }
258*5f39d1b3SJooyung Han };
259*5f39d1b3SJooyung Han 
260*5f39d1b3SJooyung Han template <typename Scalar>
261*5f39d1b3SJooyung Han struct ReferenceEightBitIntGemmWrapper {
262*5f39d1b3SJooyung Han   typedef DefaultL8R8BitDepthParams BitDepthParams;
263*5f39d1b3SJooyung Han 
Namegemmlowp::ReferenceEightBitIntGemmWrapper264*5f39d1b3SJooyung Han   static const char* Name() { return "ReferenceEightBitIntGemm"; }
265*5f39d1b3SJooyung Han 
266*5f39d1b3SJooyung Han   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::ReferenceEightBitIntGemmWrapper267*5f39d1b3SJooyung Han   static bool Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
268*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, LhsOrder>& lhs,
269*5f39d1b3SJooyung Han                    const MatrixMap<const Scalar, RhsOrder>& rhs,
270*5f39d1b3SJooyung Han                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
271*5f39d1b3SJooyung Han                    int rhs_offset, int result_offset, int result_mult_int,
272*5f39d1b3SJooyung Han                    int result_shift) {
273*5f39d1b3SJooyung Han     ScopedProfilingLabel("ReferenceEightBitIntGemmWrapper::Gemm");
274*5f39d1b3SJooyung Han     ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
275*5f39d1b3SJooyung Han                              rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
276*5f39d1b3SJooyung Han                              lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
277*5f39d1b3SJooyung Han                              result->data(), result_offset, result_mult_int,
278*5f39d1b3SJooyung Han                              result_shift, result->stride());
279*5f39d1b3SJooyung Han     return true;
280*5f39d1b3SJooyung Han   }
281*5f39d1b3SJooyung Han };
282*5f39d1b3SJooyung Han 
OrderName(MapOrder order)283*5f39d1b3SJooyung Han const char* OrderName(MapOrder order) {
284*5f39d1b3SJooyung Han   return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
285*5f39d1b3SJooyung Han }
286*5f39d1b3SJooyung Han 
// Statistics describing how an actual result buffer differs from an expected
// one. A default-constructed object has every scalar field equal to zero and
// an empty slice vector; GetResultStats fills the fields in.
struct ResultStats {
  // Number of entries that were compared.
  int count = 0;
  // Median of the actual values.
  int med_val = 0;
  // Mean of (actual - expected).
  float mean_signed_diff = 0;
  // Median of (actual - expected).
  int med_signed_diff = 0;
  // Median of |actual - expected|.
  int med_unsigned_diff = 0;
  // Largest |actual - expected|.
  int max_unsigned_diff = 0;

  // Histogram of |actual - expected| by power-of-two slice: entry 0 counts
  // exact matches, entry e >= 1 counts diffs in [2^(e-1), 2^e - 1].
  std::vector<int> count_diff_by_pot_slice;
};
305*5f39d1b3SJooyung Han 
GetResultStats(const std::uint8_t * actual,const std::uint8_t * expected,size_t count,ResultStats * stats)306*5f39d1b3SJooyung Han void GetResultStats(const std::uint8_t* actual, const std::uint8_t* expected,
307*5f39d1b3SJooyung Han                     size_t count, ResultStats* stats) {
308*5f39d1b3SJooyung Han   ScopedProfilingLabel("GetResultStats");
309*5f39d1b3SJooyung Han   std::vector<std::uint8_t> results;
310*5f39d1b3SJooyung Han   std::vector<std::int16_t> signed_diffs;
311*5f39d1b3SJooyung Han   std::vector<std::uint8_t> unsigned_diffs;
312*5f39d1b3SJooyung Han   std::int64_t signed_diffs_sum = 0;
313*5f39d1b3SJooyung Han   for (size_t i = 0; i < count; i++) {
314*5f39d1b3SJooyung Han     results.push_back(actual[i]);
315*5f39d1b3SJooyung Han     std::int16_t signed_diff = actual[i] - expected[i];
316*5f39d1b3SJooyung Han     signed_diffs.push_back(signed_diff);
317*5f39d1b3SJooyung Han     unsigned_diffs.push_back(std::abs(signed_diff));
318*5f39d1b3SJooyung Han     signed_diffs_sum += signed_diff;
319*5f39d1b3SJooyung Han   }
320*5f39d1b3SJooyung Han 
321*5f39d1b3SJooyung Han   std::sort(results.begin(), results.end());
322*5f39d1b3SJooyung Han   std::sort(signed_diffs.begin(), signed_diffs.end());
323*5f39d1b3SJooyung Han   std::sort(unsigned_diffs.begin(), unsigned_diffs.end());
324*5f39d1b3SJooyung Han 
325*5f39d1b3SJooyung Han   const size_t middle = count / 2;
326*5f39d1b3SJooyung Han 
327*5f39d1b3SJooyung Han   stats->count = count;
328*5f39d1b3SJooyung Han   stats->med_val = results[middle];
329*5f39d1b3SJooyung Han   stats->mean_signed_diff = float(signed_diffs_sum) / count;
330*5f39d1b3SJooyung Han   stats->med_signed_diff = signed_diffs[middle];
331*5f39d1b3SJooyung Han   stats->med_unsigned_diff = unsigned_diffs[middle];
332*5f39d1b3SJooyung Han   stats->max_unsigned_diff = unsigned_diffs.back();
333*5f39d1b3SJooyung Han 
334*5f39d1b3SJooyung Han   // Size 9 for 9 different POT values: 2^0, ..., 2^8
335*5f39d1b3SJooyung Han   stats->count_diff_by_pot_slice.resize(9);
336*5f39d1b3SJooyung Han   auto cur = unsigned_diffs.begin();
337*5f39d1b3SJooyung Han   size_t checksum = 0;
338*5f39d1b3SJooyung Han   for (int exponent = 0; exponent < 9; exponent++) {
339*5f39d1b3SJooyung Han     int pot = 1 << exponent;
340*5f39d1b3SJooyung Han     auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
341*5f39d1b3SJooyung Han     checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
342*5f39d1b3SJooyung Han     cur = next;
343*5f39d1b3SJooyung Han   }
344*5f39d1b3SJooyung Han   assert(checksum == count);
345*5f39d1b3SJooyung Han }
346*5f39d1b3SJooyung Han 
// Upper bounds that CheckResultStatsBounds applies to the matching fields of
// a ResultStats. A default-constructed object has every bound equal to zero.
struct ResultStatsBounds {
  // Bound on |mean of (actual - expected)|.
  float mean_signed_diff = 0;
  // Bound on |median of (actual - expected)|.
  int med_signed_diff = 0;
  // Bound on the median of |actual - expected|.
  int med_unsigned_diff = 0;
  // Bound on the largest |actual - expected|.
  int max_unsigned_diff = 0;
};
359*5f39d1b3SJooyung Han 
CheckResultStatsBounds(const ResultStats & stats,const ResultStatsBounds & bounds)360*5f39d1b3SJooyung Han bool CheckResultStatsBounds(const ResultStats& stats,
361*5f39d1b3SJooyung Han                             const ResultStatsBounds& bounds) {
362*5f39d1b3SJooyung Han   return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
363*5f39d1b3SJooyung Han          stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
364*5f39d1b3SJooyung Han          std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
365*5f39d1b3SJooyung Han          std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
366*5f39d1b3SJooyung Han }
367*5f39d1b3SJooyung Han 
ReportResultStats(const ResultStats & stats,const ResultStatsBounds & bounds)368*5f39d1b3SJooyung Han void ReportResultStats(const ResultStats& stats,
369*5f39d1b3SJooyung Han                        const ResultStatsBounds& bounds) {
370*5f39d1b3SJooyung Han   printf("    number of matrix entries: %d\n", stats.count);
371*5f39d1b3SJooyung Han   printf("    median value: %d\n", stats.med_val);
372*5f39d1b3SJooyung Han   printf("    median unsigned diff: %d (tolerating %d)\n",
373*5f39d1b3SJooyung Han          stats.med_unsigned_diff, bounds.med_unsigned_diff);
374*5f39d1b3SJooyung Han   printf("    max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
375*5f39d1b3SJooyung Han          bounds.max_unsigned_diff);
376*5f39d1b3SJooyung Han   printf("    median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
377*5f39d1b3SJooyung Han          bounds.med_signed_diff);
378*5f39d1b3SJooyung Han   printf("    mean signed diff: %.3g (tolerating %.3g)\n",
379*5f39d1b3SJooyung Han          stats.mean_signed_diff, bounds.mean_signed_diff);
380*5f39d1b3SJooyung Han 
381*5f39d1b3SJooyung Han   printf("No error: %.2f %% of entries\n",
382*5f39d1b3SJooyung Han          100.f * stats.count_diff_by_pot_slice[0] / stats.count);
383*5f39d1b3SJooyung Han   for (int exponent = 1; exponent < 9; exponent++) {
384*5f39d1b3SJooyung Han     printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
385*5f39d1b3SJooyung Han            (1 << exponent) - 1,
386*5f39d1b3SJooyung Han            100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
387*5f39d1b3SJooyung Han   }
388*5f39d1b3SJooyung Han }
389*5f39d1b3SJooyung Han 
390*5f39d1b3SJooyung Han // Our approach to choosing result_shift values for testing, is bisection.
391*5f39d1b3SJooyung Han // This function takes an interval, [result_shift_min .. result_shift_max].
392*5f39d1b3SJooyung Han // If too much saturation occurred in either direction, it bisects accordingly,
393*5f39d1b3SJooyung Han // recursing until the interval contains only one value.
394*5f39d1b3SJooyung Han // The primary reason why we prefer this over computing optimal shift values,
395*5f39d1b3SJooyung Han // is that we actually want to exercise some saturation, as there is nontrivial
396*5f39d1b3SJooyung Han // code handling that in gemmlowp.
397*5f39d1b3SJooyung Han // Secondarily, this is faster than computing optimal shifts, since in 90% of
398*5f39d1b3SJooyung Han // cases the first-tried shift value 16 turns out to be good enough.
399*5f39d1b3SJooyung Han template <typename GemmWrapper, typename LhsType, typename RhsType,
400*5f39d1b3SJooyung Han           typename ResultType>
test_gemm_impl(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int,int result_shift_min,int result_shift_max)401*5f39d1b3SJooyung Han void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
402*5f39d1b3SJooyung Han                     const RhsType& rhs, ResultType* result, int lhs_offset,
403*5f39d1b3SJooyung Han                     int rhs_offset, int result_offset, int result_mult_int,
404*5f39d1b3SJooyung Han                     int result_shift_min, int result_shift_max) {
405*5f39d1b3SJooyung Han   const int rows = lhs.rows();
406*5f39d1b3SJooyung Han   const int cols = rhs.cols();
407*5f39d1b3SJooyung Han   Check(lhs.cols() == rhs.rows());
408*5f39d1b3SJooyung Han   const int depth = lhs.cols();
409*5f39d1b3SJooyung Han 
410*5f39d1b3SJooyung Han   const int result_shift = (result_shift_min + result_shift_max) / 2;
411*5f39d1b3SJooyung Han 
412*5f39d1b3SJooyung Han   if (!GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
413*5f39d1b3SJooyung Han                          &result->map(), lhs_offset, rhs_offset, result_offset,
414*5f39d1b3SJooyung Han                          result_mult_int, result_shift)) {
415*5f39d1b3SJooyung Han     // Internal GEMM functions are not required to handle all cases
416*5f39d1b3SJooyung Han     // (e.g. rows < cols) as these are supposed to have been handled
417*5f39d1b3SJooyung Han     // ahead of them. Their test wrappers return false in that case.
418*5f39d1b3SJooyung Han     return;
419*5f39d1b3SJooyung Han   }
420*5f39d1b3SJooyung Han 
421*5f39d1b3SJooyung Han   typedef typename ResultType::Scalar Scalar;
422*5f39d1b3SJooyung Han   static const MapOrder kLhsOrder = LhsType::kOrder;
423*5f39d1b3SJooyung Han   static const MapOrder kRhsOrder = RhsType::kOrder;
424*5f39d1b3SJooyung Han   static const MapOrder kResultOrder = ResultType::kOrder;
425*5f39d1b3SJooyung Han   ResultType ref_result(rows, cols);
426*5f39d1b3SJooyung Han   const bool transpose_c = kResultOrder == MapOrder::RowMajor;
427*5f39d1b3SJooyung Han   const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
428*5f39d1b3SJooyung Han   const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
429*5f39d1b3SJooyung Han   ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
430*5f39d1b3SJooyung Han       transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
431*5f39d1b3SJooyung Han       &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
432*5f39d1b3SJooyung Han       result_shift);
433*5f39d1b3SJooyung Han 
434*5f39d1b3SJooyung Han   typedef typename GemmWrapper::BitDepthParams BitDepthParams;
435*5f39d1b3SJooyung Han 
436*5f39d1b3SJooyung Han   ResultStats stats;
437*5f39d1b3SJooyung Han   GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);
438*5f39d1b3SJooyung Han 
439*5f39d1b3SJooyung Han   // Adjust shifts until we get meaningful results
440*5f39d1b3SJooyung Han   int new_result_shift_min = result_shift_min;
441*5f39d1b3SJooyung Han   int new_result_shift_max = result_shift_max;
442*5f39d1b3SJooyung Han   bool retry = false;
443*5f39d1b3SJooyung Han 
444*5f39d1b3SJooyung Han   if (stats.med_val < 32) {
445*5f39d1b3SJooyung Han     new_result_shift_max = (result_shift_min + result_shift_max) / 2;
446*5f39d1b3SJooyung Han     retry = true;
447*5f39d1b3SJooyung Han   }
448*5f39d1b3SJooyung Han 
449*5f39d1b3SJooyung Han   if (stats.med_val > 224) {
450*5f39d1b3SJooyung Han     new_result_shift_min = (result_shift_min + result_shift_max) / 2;
451*5f39d1b3SJooyung Han     retry = true;
452*5f39d1b3SJooyung Han   }
453*5f39d1b3SJooyung Han 
454*5f39d1b3SJooyung Han   if (retry) {
455*5f39d1b3SJooyung Han     if (result_shift_min != result_shift_max) {
456*5f39d1b3SJooyung Han       test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
457*5f39d1b3SJooyung Han                                   rhs_offset, result_offset, result_mult_int,
458*5f39d1b3SJooyung Han                                   new_result_shift_min, new_result_shift_max);
459*5f39d1b3SJooyung Han     }
460*5f39d1b3SJooyung Han     return;
461*5f39d1b3SJooyung Han   }
462*5f39d1b3SJooyung Han 
463*5f39d1b3SJooyung Han   ResultStatsBounds bounds;
464*5f39d1b3SJooyung Han 
465*5f39d1b3SJooyung Han   // Check results
466*5f39d1b3SJooyung Han   const bool good = CheckResultStatsBounds(stats, bounds);
467*5f39d1b3SJooyung Han 
468*5f39d1b3SJooyung Han   printf(
469*5f39d1b3SJooyung Han       "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
470*5f39d1b3SJooyung Han       good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
471*5f39d1b3SJooyung Han       OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
472*5f39d1b3SJooyung Han       lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);
473*5f39d1b3SJooyung Han 
474*5f39d1b3SJooyung Han   if (!good) {
475*5f39d1b3SJooyung Han     ReportResultStats(stats, bounds);
476*5f39d1b3SJooyung Han 
477*5f39d1b3SJooyung Han     int bad_coeffs_printed = 0;
478*5f39d1b3SJooyung Han     for (int c = 0; c < result->cols() && bad_coeffs_printed < 200; c++) {
479*5f39d1b3SJooyung Han       for (int r = 0; r < result->rows() && bad_coeffs_printed < 200; r++) {
480*5f39d1b3SJooyung Han         if (ref_result(r, c) != (*result)(r, c)) {
481*5f39d1b3SJooyung Han           printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
482*5f39d1b3SJooyung Han                  ref_result(r, c), (*result)(r, c));
483*5f39d1b3SJooyung Han           bad_coeffs_printed++;
484*5f39d1b3SJooyung Han         }
485*5f39d1b3SJooyung Han       }
486*5f39d1b3SJooyung Han     }
487*5f39d1b3SJooyung Han   }
488*5f39d1b3SJooyung Han 
489*5f39d1b3SJooyung Han   Check(good);
490*5f39d1b3SJooyung Han }
491*5f39d1b3SJooyung Han 
492*5f39d1b3SJooyung Han template <typename GemmWrapper, typename LhsType, typename RhsType,
493*5f39d1b3SJooyung Han           typename ResultType>
test_gemm(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int)494*5f39d1b3SJooyung Han void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
495*5f39d1b3SJooyung Han                const RhsType& rhs, ResultType* result, int lhs_offset,
496*5f39d1b3SJooyung Han                int rhs_offset, int result_offset, int result_mult_int) {
497*5f39d1b3SJooyung Han   test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
498*5f39d1b3SJooyung Han                               result_offset, result_mult_int, 0, 32);
499*5f39d1b3SJooyung Han }
500*5f39d1b3SJooyung Han 
// Selects how many offset/multiplier parameter sets test_gemm runs for a
// given matrix shape.
enum class WhatParamsToTest {
  // Run the special-case parameter sets and then the generic one.
  All = 0,
  // Run only the single generic parameter set.
  OnlyGenericCase = 1,
};
505*5f39d1b3SJooyung Han 
506*5f39d1b3SJooyung Han template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
507*5f39d1b3SJooyung Han           MapOrder ResultOrder>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test)508*5f39d1b3SJooyung Han void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
509*5f39d1b3SJooyung Han                int cols, WhatParamsToTest params_to_test) {
510*5f39d1b3SJooyung Han   typedef std::uint8_t Scalar;
511*5f39d1b3SJooyung Han   typedef Matrix<Scalar, LhsOrder> LhsType;
512*5f39d1b3SJooyung Han   using BitDepthParams = typename GemmWrapper::BitDepthParams;
513*5f39d1b3SJooyung Han   LhsType lhs(rows, depth);
514*5f39d1b3SJooyung Han   MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
515*5f39d1b3SJooyung Han   typedef Matrix<Scalar, RhsOrder> RhsType;
516*5f39d1b3SJooyung Han   RhsType rhs(depth, cols);
517*5f39d1b3SJooyung Han   MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
518*5f39d1b3SJooyung Han   typedef Matrix<Scalar, ResultOrder> ResultType;
519*5f39d1b3SJooyung Han   ResultType result(rows, cols);
520*5f39d1b3SJooyung Han   MakeZero(&result);
521*5f39d1b3SJooyung Han 
522*5f39d1b3SJooyung Han   if (params_to_test == WhatParamsToTest::All) {
523*5f39d1b3SJooyung Han     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
524*5f39d1b3SJooyung Han     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
525*5f39d1b3SJooyung Han     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
526*5f39d1b3SJooyung Han     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
527*5f39d1b3SJooyung Han     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
528*5f39d1b3SJooyung Han     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
529*5f39d1b3SJooyung Han     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
530*5f39d1b3SJooyung Han   }
531*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
532*5f39d1b3SJooyung Han }
533*5f39d1b3SJooyung Han 
// Selects which lhs/rhs/result storage-order combinations test_gemm covers.
enum class WhatOrdersToTest {
  // Every combination of RowMajor/ColMajor for lhs, rhs and result.
  All = 0,
  // Only RowMajor lhs, ColMajor rhs, ColMajor result.
  OnlyRCC = 1,
};
535*5f39d1b3SJooyung Han 
536*5f39d1b3SJooyung Han template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test,WhatOrdersToTest orders_to_test)537*5f39d1b3SJooyung Han void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
538*5f39d1b3SJooyung Han                int cols, WhatParamsToTest params_to_test,
539*5f39d1b3SJooyung Han                WhatOrdersToTest orders_to_test) {
540*5f39d1b3SJooyung Han #define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
541*5f39d1b3SJooyung Han   do {                                                             \
542*5f39d1b3SJooyung Han     test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
543*5f39d1b3SJooyung Han               MapOrder::ResultOrder>(context, rows, depth, cols,   \
544*5f39d1b3SJooyung Han                                      params_to_test);              \
545*5f39d1b3SJooyung Han   } while (false)
546*5f39d1b3SJooyung Han 
547*5f39d1b3SJooyung Han   if (orders_to_test == WhatOrdersToTest::All) {
548*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
549*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
550*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
551*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);
552*5f39d1b3SJooyung Han 
553*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
554*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
555*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
556*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
557*5f39d1b3SJooyung Han   } else {
558*5f39d1b3SJooyung Han     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
559*5f39d1b3SJooyung Han   }
560*5f39d1b3SJooyung Han 
561*5f39d1b3SJooyung Han #undef GEMMLOWP_ONE_TEST
562*5f39d1b3SJooyung Han }
563*5f39d1b3SJooyung Han 
564*5f39d1b3SJooyung Han template <typename Kernel>
test_gemm_kernel(MultiThreadGemmContext * context)565*5f39d1b3SJooyung Han void test_gemm_kernel(MultiThreadGemmContext* context) {
566*5f39d1b3SJooyung Han   typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
567*5f39d1b3SJooyung Han                                  DefaultL8R8BitDepthParams>
568*5f39d1b3SJooyung Han       GemmWrapper;
569*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
570*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
571*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
572*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
573*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
574*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
575*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
576*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
577*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
578*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
579*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
580*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
581*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
582*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
583*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 200, 200, 200,
584*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
585*5f39d1b3SJooyung Han                          WhatOrdersToTest::All);
586*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 50, 5000, 50,
587*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
588*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
589*5f39d1b3SJooyung Han }
590*5f39d1b3SJooyung Han 
591*5f39d1b3SJooyung Han template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context)592*5f39d1b3SJooyung Han void test_gemm(typename GemmWrapper::Context* context) {
593*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
594*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
595*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
596*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
597*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
598*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
599*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
600*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
601*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
602*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
603*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
604*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
605*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
606*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
607*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
608*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
609*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
610*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
611*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
612*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
613*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
614*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
615*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
616*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
617*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
618*5f39d1b3SJooyung Han                          WhatOrdersToTest::All);
619*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
620*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
621*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
622*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
623*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
624*5f39d1b3SJooyung Han                          WhatOrdersToTest::All);
625*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
626*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
627*5f39d1b3SJooyung Han 
628*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
629*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
630*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
631*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
632*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
633*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
634*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
635*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
636*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
637*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
638*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
639*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
640*5f39d1b3SJooyung Han 
641*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 512, 512, 512,
642*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
643*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
644*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
645*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
646*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
647*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 567, 2345, 123,
648*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
649*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
650*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 100, 5000, 100,
651*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
652*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
653*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
654*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
655*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
656*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
657*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
658*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
659*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1, 1000, 1000,
660*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
661*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
662*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1000, 1, 1000,
663*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
664*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
665*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 1000, 1000, 1,
666*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
667*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
668*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 777, 3456, 1,
669*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
670*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
671*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 4567, 555, 1,
672*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
673*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
674*5f39d1b3SJooyung Han 
675*5f39d1b3SJooyung Han   // Test all storage orders
676*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
677*5f39d1b3SJooyung Han                          WhatOrdersToTest::All);
678*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 300, 400, 500,
679*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
680*5f39d1b3SJooyung Han                          WhatOrdersToTest::All);
681*5f39d1b3SJooyung Han }
682*5f39d1b3SJooyung Han 
683*5f39d1b3SJooyung Han template <typename GemmWrapper>
test_gemv(typename GemmWrapper::Context * context)684*5f39d1b3SJooyung Han void test_gemv(typename GemmWrapper::Context* context) {
685*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
686*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
687*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
688*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
689*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
690*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
691*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
692*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
693*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
694*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
695*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
696*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
697*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
698*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
699*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
700*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
701*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
702*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
703*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
704*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
705*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
706*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
707*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
708*5f39d1b3SJooyung Han                          WhatOrdersToTest::OnlyRCC);
709*5f39d1b3SJooyung Han 
710*5f39d1b3SJooyung Han   // Test all storage orders
711*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
712*5f39d1b3SJooyung Han                          WhatOrdersToTest::All);
713*5f39d1b3SJooyung Han   test_gemm<GemmWrapper>(context, 300, 400, 1,
714*5f39d1b3SJooyung Han                          WhatParamsToTest::OnlyGenericCase,
715*5f39d1b3SJooyung Han                          WhatOrdersToTest::All);
716*5f39d1b3SJooyung Han }
717*5f39d1b3SJooyung Han 
GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b)718*5f39d1b3SJooyung Han const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
719*5f39d1b3SJooyung Han   switch (b) {
720*5f39d1b3SJooyung Han     case eight_bit_int_gemm::BitDepthSetting::A8B8:
721*5f39d1b3SJooyung Han       return "Lhs: 8 bit, Rhs: 8 bit";
722*5f39d1b3SJooyung Han     case eight_bit_int_gemm::BitDepthSetting::A5B7:
723*5f39d1b3SJooyung Han       return "(legacy, no longer requantizing) Lhs: 7 bit, Rhs: 5 bit";
724*5f39d1b3SJooyung Han     default:
725*5f39d1b3SJooyung Han       abort();
726*5f39d1b3SJooyung Han       return nullptr;
727*5f39d1b3SJooyung Han   }
728*5f39d1b3SJooyung Han }
729*5f39d1b3SJooyung Han 
730*5f39d1b3SJooyung Han // Runs a small set of hand-picked data for per-channel quantized data.
731*5f39d1b3SJooyung Han // This test case comes from a set of 2 2x2 convolution filters run over a 3x3
732*5f39d1b3SJooyung Han // image.
TestWithSmallDataPerChannelQuantization()733*5f39d1b3SJooyung Han void TestWithSmallDataPerChannelQuantization() {
734*5f39d1b3SJooyung Han   const int m = 2;
735*5f39d1b3SJooyung Han   const int n = 9;
736*5f39d1b3SJooyung Han   const int k = 12;
737*5f39d1b3SJooyung Han 
738*5f39d1b3SJooyung Han   // 12 x 2, columnwise.
739*5f39d1b3SJooyung Han   const std::uint8_t a_data[] = {0,  0,   0,   0,   0,  0,   0,   0,
740*5f39d1b3SJooyung Han                                  0,  255, 255, 255, 64, 64,  64,  64,
741*5f39d1b3SJooyung Han                                  64, 64,  0,   0,   0,  255, 255, 255};
742*5f39d1b3SJooyung Han   const int lda = k;
743*5f39d1b3SJooyung Han   int a_offset[] = {0, -64};
744*5f39d1b3SJooyung Han   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
745*5f39d1b3SJooyung Han   const OffsetColMap lhs_offset(a_offset, m);
746*5f39d1b3SJooyung Han 
747*5f39d1b3SJooyung Han   // 12 x 9, columnwise.
748*5f39d1b3SJooyung Han   const std::uint8_t b_data[] = {
749*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,
750*5f39d1b3SJooyung Han       0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   127,
751*5f39d1b3SJooyung Han       127, 127, 0,   0,   0,   127, 127, 127, 0,   0,   0,   255, 255, 255,
752*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,
753*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   127, 127, 127, 0,   0,   0,   127,
754*5f39d1b3SJooyung Han       127, 127, 0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127,
755*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127, 0,   0,
756*5f39d1b3SJooyung Han       0,   127, 127, 127, 127, 127, 127, 127, 127, 127};
757*5f39d1b3SJooyung Han   const int ldb = k;
758*5f39d1b3SJooyung Han   int b_offset = -127;
759*5f39d1b3SJooyung Han   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
760*5f39d1b3SJooyung Han   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
761*5f39d1b3SJooyung Han 
762*5f39d1b3SJooyung Han   // 2 x 9, columnwise.
763*5f39d1b3SJooyung Han   const std::uint8_t expected_c_data[] = {255, 255, 0,   0,   127, 159,
764*5f39d1b3SJooyung Han                                           0,   64,  0,   64,  127, 159,
765*5f39d1b3SJooyung Han                                           127, 127, 127, 127, 127, 127};
766*5f39d1b3SJooyung Han   const int ldc = m;
767*5f39d1b3SJooyung Han   int c_offset[] = {97155, 97346};
768*5f39d1b3SJooyung Han   int c_mult_int[] = {2741, 2741};
769*5f39d1b3SJooyung Han   const int c_shift = 21;
770*5f39d1b3SJooyung Han 
771*5f39d1b3SJooyung Han   const int c_count = m * n;
772*5f39d1b3SJooyung Han   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
773*5f39d1b3SJooyung Han   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
774*5f39d1b3SJooyung Han                                                      ldc);
775*5f39d1b3SJooyung Han   const OffsetColMap result_offset(c_offset, m);
776*5f39d1b3SJooyung Han   const OffsetColMap result_mult_int(c_mult_int, m);
777*5f39d1b3SJooyung Han   const int result_shift = c_shift;
778*5f39d1b3SJooyung Han 
779*5f39d1b3SJooyung Han   GemmContext gemm_context;
780*5f39d1b3SJooyung Han   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
781*5f39d1b3SJooyung Han       result_offset, result_mult_int, result_shift);
782*5f39d1b3SJooyung Han   GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
783*5f39d1b3SJooyung Han                            DefaultL8R8BitDepthParams>(
784*5f39d1b3SJooyung Han       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
785*5f39d1b3SJooyung Han       output_pipeline);
786*5f39d1b3SJooyung Han 
787*5f39d1b3SJooyung Han   ResultStats stats;
788*5f39d1b3SJooyung Han   GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
789*5f39d1b3SJooyung Han 
790*5f39d1b3SJooyung Han   ResultStatsBounds bounds;
791*5f39d1b3SJooyung Han   const bool good = CheckResultStatsBounds(stats, bounds);
792*5f39d1b3SJooyung Han   printf("TestWithSmallDataPerChannelQuantization: %s\n",
793*5f39d1b3SJooyung Han          good ? "PASS" : "FAIL");
794*5f39d1b3SJooyung Han   ReportResultStats(stats, bounds);
795*5f39d1b3SJooyung Han   Check(good);
796*5f39d1b3SJooyung Han }
797*5f39d1b3SJooyung Han 
798*5f39d1b3SJooyung Han // Runs a larger set of hand-picked data for per-channel quantized data.
799*5f39d1b3SJooyung Han // This test case comes from a set of 22 3x3 convolution filters run over a 5x5
800*5f39d1b3SJooyung Han // image.  Right now, I have 7 different filters and 15 copies of the first
801*5f39d1b3SJooyung Han // filter to make sure NEON code path that processes 16 rows at a time is
802*5f39d1b3SJooyung Han // covered.
TestWithLargeDataPerChannelQuantization()803*5f39d1b3SJooyung Han void TestWithLargeDataPerChannelQuantization() {
804*5f39d1b3SJooyung Han   const int m = 22;
805*5f39d1b3SJooyung Han   const int n = 25;
806*5f39d1b3SJooyung Han   const int k = 27;
807*5f39d1b3SJooyung Han 
808*5f39d1b3SJooyung Han   // 27 x 22, column-wise.
809*5f39d1b3SJooyung Han   const std::uint8_t a_data[] = {
810*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
811*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
812*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   127, 127, 127, 255, 255, 255, 127, 127, 127,
813*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   127, 127, 127,
814*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
815*5f39d1b3SJooyung Han       127, 127, 127, 0,   0,   0,   51,  51,  51,  51,  51,  51,  51,  51,  51,
816*5f39d1b3SJooyung Han       0,   0,   0,   255, 255, 255, 0,   0,   0,   51,  51,  51,  51,  51,  51,
817*5f39d1b3SJooyung Han       51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,  51,  51,  51,
818*5f39d1b3SJooyung Han       255, 255, 255, 51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,
819*5f39d1b3SJooyung Han       0,   0,   0,   64,  64,  64,  0,   0,   0,   64,  64,  64,  255, 255, 255,
820*5f39d1b3SJooyung Han       64,  64,  64,  0,   0,   0,   64,  64,  64,  0,   0,   0,   36,  36,  36,
821*5f39d1b3SJooyung Han       0,   0,   0,   36,  36,  36,  0,   0,   0,   255, 255, 255, 0,   0,   0,
822*5f39d1b3SJooyung Han       36,  36,  36,  0,   0,   0,   36,  36,  36,  0,   0,   0,   0,   0,   0,
823*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
824*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
825*5f39d1b3SJooyung Han       0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
826*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
827*5f39d1b3SJooyung Han       255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
828*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
829*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
830*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
831*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
832*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
833*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
834*5f39d1b3SJooyung Han       0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
835*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
836*5f39d1b3SJooyung Han       255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
837*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
838*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
839*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
840*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
841*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
842*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
843*5f39d1b3SJooyung Han       0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
844*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
845*5f39d1b3SJooyung Han       255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
846*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
847*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
848*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
849*5f39d1b3SJooyung Han       0,   0,   0,   0,   0,   0,   0,   0,   0,
850*5f39d1b3SJooyung Han   };
851*5f39d1b3SJooyung Han   const int lda = k;
852*5f39d1b3SJooyung Han   int a_offset[] = {0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 0,
853*5f39d1b3SJooyung Han                     0, 0, 0, 0,   0,   0, 0,   0, 0, 0, 0};
854*5f39d1b3SJooyung Han   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
855*5f39d1b3SJooyung Han   const OffsetColMap lhs_offset(a_offset, m);
856*5f39d1b3SJooyung Han 
857*5f39d1b3SJooyung Han   // 27 x 25, column-wise.
858*5f39d1b3SJooyung Han   const std::uint8_t b_data[] = {
859*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119,
860*5f39d1b3SJooyung Han       119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
861*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
862*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127,
863*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
864*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
865*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
866*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
867*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
868*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119,
869*5f39d1b3SJooyung Han       119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
870*5f39d1b3SJooyung Han       127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
871*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
872*5f39d1b3SJooyung Han       119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
873*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
874*5f39d1b3SJooyung Han       136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
875*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
876*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
877*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
878*5f39d1b3SJooyung Han       119, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
879*5f39d1b3SJooyung Han       127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
880*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
881*5f39d1b3SJooyung Han       119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
882*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
883*5f39d1b3SJooyung Han       136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
884*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
885*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
886*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
887*5f39d1b3SJooyung Han       119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
888*5f39d1b3SJooyung Han       127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
889*5f39d1b3SJooyung Han       119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
890*5f39d1b3SJooyung Han       119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
891*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
892*5f39d1b3SJooyung Han       136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
893*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
894*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
895*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
896*5f39d1b3SJooyung Han       119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
897*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119, 119,
898*5f39d1b3SJooyung Han       119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127,
899*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
900*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
901*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
902*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
903*5f39d1b3SJooyung Han       127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119,
904*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127,
905*5f39d1b3SJooyung Han       127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
906*5f39d1b3SJooyung Han       119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
907*5f39d1b3SJooyung Han       127, 127, 127};
908*5f39d1b3SJooyung Han   const int ldb = k;
909*5f39d1b3SJooyung Han   int b_offset = -127;
910*5f39d1b3SJooyung Han   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
911*5f39d1b3SJooyung Han   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
912*5f39d1b3SJooyung Han 
913*5f39d1b3SJooyung Han   // 22 x 25, column-wise.
914*5f39d1b3SJooyung Han   const std::uint8_t expected_c_data[] = {
915*5f39d1b3SJooyung Han       7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,
916*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,
917*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
918*5f39d1b3SJooyung Han       7,   37,  87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,
919*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,
920*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
921*5f39d1b3SJooyung Han       37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
922*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,
923*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
924*5f39d1b3SJooyung Han       87,  87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
925*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   71,  87,  45,  41,  77,  7,   7,   7,   7,
926*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,
927*5f39d1b3SJooyung Han       87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
928*5f39d1b3SJooyung Han       7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,
929*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,
930*5f39d1b3SJooyung Han       23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
931*5f39d1b3SJooyung Han       7,   7,   7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,
932*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   255, 135, 135, 255, 255, 143,
933*5f39d1b3SJooyung Han       255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
934*5f39d1b3SJooyung Han       255, 7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,   7,
935*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,
936*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
937*5f39d1b3SJooyung Han       7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,
938*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,
939*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
940*5f39d1b3SJooyung Han       7,   71,  87,  45,  41,  77,  7,   7,   7,   7,   7,   7,   7,   7,   7,
941*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,   7,
942*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
943*5f39d1b3SJooyung Han       7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
944*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,
945*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
946*5f39d1b3SJooyung Han       87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
947*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,   7,   7,
948*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,
949*5f39d1b3SJooyung Han       67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
950*5f39d1b3SJooyung Han       7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,
951*5f39d1b3SJooyung Han       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   99,  99,  99,  99,  99,
952*5f39d1b3SJooyung Han       99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,
953*5f39d1b3SJooyung Han       99,  99,  111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
954*5f39d1b3SJooyung Han       111, 111, 111, 111, 111, 111, 111, 111, 111,
955*5f39d1b3SJooyung Han   };
956*5f39d1b3SJooyung Han   const int ldc = m;
957*5f39d1b3SJooyung Han   int c_offset[] = {
958*5f39d1b3SJooyung Han       6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 6477,
959*5f39d1b3SJooyung Han       6477, 6477,  6477,  6477, 6477, 6477,  6477, 6477, 6477, 6477, 6477,
960*5f39d1b3SJooyung Han   };
961*5f39d1b3SJooyung Han   int c_mult_int[] = {
962*5f39d1b3SJooyung Han       41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121,
963*5f39d1b3SJooyung Han       41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
964*5f39d1b3SJooyung Han       41121, 41121, 41121, 41121, 41121, 41121,
965*5f39d1b3SJooyung Han   };
966*5f39d1b3SJooyung Han   const int c_shift = 21;
967*5f39d1b3SJooyung Han 
968*5f39d1b3SJooyung Han   const int c_count = m * n;
969*5f39d1b3SJooyung Han   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
970*5f39d1b3SJooyung Han   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
971*5f39d1b3SJooyung Han                                                      ldc);
972*5f39d1b3SJooyung Han   const OffsetColMap result_offset(c_offset, m);
973*5f39d1b3SJooyung Han   const OffsetColMap result_mult_int(c_mult_int, m);
974*5f39d1b3SJooyung Han   const int result_shift = c_shift;
975*5f39d1b3SJooyung Han 
976*5f39d1b3SJooyung Han   GemmContext gemm_context;
977*5f39d1b3SJooyung Han   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
978*5f39d1b3SJooyung Han       result_offset, result_mult_int, result_shift);
979*5f39d1b3SJooyung Han   GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
980*5f39d1b3SJooyung Han                            DefaultL8R8BitDepthParams>(
981*5f39d1b3SJooyung Han       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
982*5f39d1b3SJooyung Han       output_pipeline);
983*5f39d1b3SJooyung Han 
984*5f39d1b3SJooyung Han   ResultStats stats;
985*5f39d1b3SJooyung Han   GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
986*5f39d1b3SJooyung Han 
987*5f39d1b3SJooyung Han   ResultStatsBounds bounds;
988*5f39d1b3SJooyung Han   const bool good = CheckResultStatsBounds(stats, bounds);
989*5f39d1b3SJooyung Han   printf("TestWithLargeDataPerChannelQuantization: %s\n",
990*5f39d1b3SJooyung Han          good ? "PASS" : "FAIL");
991*5f39d1b3SJooyung Han   ReportResultStats(stats, bounds);
992*5f39d1b3SJooyung Han   Check(good);
993*5f39d1b3SJooyung Han }
994*5f39d1b3SJooyung Han 
995*5f39d1b3SJooyung Han // Multithreading only activates when the result has more than 16 rows, and also
996*5f39d1b3SJooyung Han // (result rows) * (result cols) * depth >= 2 x 65 x 1024.  Size was selected
997*5f39d1b3SJooyung Han // to run in 3 threads.
998*5f39d1b3SJooyung Han //
999*5f39d1b3SJooyung Han // Based on the following floating point data:
1000*5f39d1b3SJooyung Han //   LHS: all zeros except 10.0, 20.0 at the beginning of first 16 rows;
1001*5f39d1b3SJooyung Han //     1.0, 2.0 at the beginning of next 16 rows; 0.1, 0.2 in next 16 rows;
1002*5f39d1b3SJooyung Han //     0.01, 0.02 in last 16 rows.
1003*5f39d1b3SJooyung Han //   RHS: all zeros except 1.0 in (0, 0) and 2.0 in (1, 0).
1004*5f39d1b3SJooyung Han //   Varying boundaries were used for each 16 rows block of LHS, to test for
1005*5f39d1b3SJooyung Han //     correct indexing into offsets.
1006*5f39d1b3SJooyung Han //   Expected result: all zeros, except 50.0 at the beginning of first 16 rows;
1007*5f39d1b3SJooyung Han //     5.0 at the beginning of next 16 rows; 0.5 in next 16 rows; 0.05 in last
1008*5f39d1b3SJooyung Han //     16 rows.
TestMultithreadedPerChannelQuantization()1009*5f39d1b3SJooyung Han void TestMultithreadedPerChannelQuantization() {
1010*5f39d1b3SJooyung Han   const int m = 64;
1011*5f39d1b3SJooyung Han   const int n = 20;
1012*5f39d1b3SJooyung Han   const int k = 160;
1013*5f39d1b3SJooyung Han 
1014*5f39d1b3SJooyung Han   // LHS, m x k.
1015*5f39d1b3SJooyung Han   const std::array<std::int32_t, 4> lhs_offsets_terse{{
1016*5f39d1b3SJooyung Han       0, -51, -85, -109,
1017*5f39d1b3SJooyung Han   }};
1018*5f39d1b3SJooyung Han   assert(lhs_offsets_terse.size() * 16 == m);
1019*5f39d1b3SJooyung Han   const std::array<std::uint8_t, 4> lhs_first_el{{
1020*5f39d1b3SJooyung Han       128, 153, 170, 182,
1021*5f39d1b3SJooyung Han   }};
1022*5f39d1b3SJooyung Han   assert(lhs_first_el.size() * 16 == m);
1023*5f39d1b3SJooyung Han 
1024*5f39d1b3SJooyung Han   // lhs_first_el at (i, 0) and 255 at (i, 1), other values are all -offset.
1025*5f39d1b3SJooyung Han   std::vector<std::uint8_t> a_data(m * k, 0);
1026*5f39d1b3SJooyung Han   for (int i = 0; i < m; ++i) {
1027*5f39d1b3SJooyung Han     a_data[i * k] = lhs_first_el[i / 16];
1028*5f39d1b3SJooyung Han     a_data[i * k + 1] = 255;
1029*5f39d1b3SJooyung Han     for (int j = 2; j < k; ++j) {
1030*5f39d1b3SJooyung Han       a_data[i * k + j] = std::uint8_t(-lhs_offsets_terse[i / 16]);
1031*5f39d1b3SJooyung Han     }
1032*5f39d1b3SJooyung Han   }
1033*5f39d1b3SJooyung Han 
1034*5f39d1b3SJooyung Han   const int lda = k;
1035*5f39d1b3SJooyung Han   // Given values at [i / 16].
1036*5f39d1b3SJooyung Han   std::vector<std::int32_t> a_offset(m, 0);
1037*5f39d1b3SJooyung Han   for (int i = 0; i < m; ++i) {
1038*5f39d1b3SJooyung Han     a_offset[i] = lhs_offsets_terse[i / 16];
1039*5f39d1b3SJooyung Han   }
1040*5f39d1b3SJooyung Han 
1041*5f39d1b3SJooyung Han   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(&a_data[0], m, k, lda);
1042*5f39d1b3SJooyung Han   const OffsetColMap lhs_offset(&a_offset[0], m);
1043*5f39d1b3SJooyung Han 
1044*5f39d1b3SJooyung Han   // RHS, k x n.
1045*5f39d1b3SJooyung Han   // All zeros, except 128 at (0, 0) and 255 at (1, 0).
1046*5f39d1b3SJooyung Han   std::vector<std::uint8_t> b_data(k * n, 0);
1047*5f39d1b3SJooyung Han   b_data[0] = 128;
1048*5f39d1b3SJooyung Han   b_data[1] = 255;
1049*5f39d1b3SJooyung Han 
1050*5f39d1b3SJooyung Han   const int ldb = k;
1051*5f39d1b3SJooyung Han   std::int32_t b_offset = 0;
1052*5f39d1b3SJooyung Han   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(&b_data[0], k, n, ldb);
1053*5f39d1b3SJooyung Han   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
1054*5f39d1b3SJooyung Han 
1055*5f39d1b3SJooyung Han   // Result, m x n.
1056*5f39d1b3SJooyung Han   // All zeros, except given values at (i / 16, 0).
1057*5f39d1b3SJooyung Han   const std::array<std::uint8_t, 4> expected_c_terse{{
1058*5f39d1b3SJooyung Han       142, 159, 182, 213,
1059*5f39d1b3SJooyung Han   }};
1060*5f39d1b3SJooyung Han   assert(expected_c_terse.size() * 16 == m);
1061*5f39d1b3SJooyung Han   std::vector<std::uint8_t> expected_c_data(m * n, 0);
1062*5f39d1b3SJooyung Han   for (int i = 0; i < m; ++i) {
1063*5f39d1b3SJooyung Han     expected_c_data[i] = expected_c_terse[i / 16];
1064*5f39d1b3SJooyung Han   }
1065*5f39d1b3SJooyung Han 
1066*5f39d1b3SJooyung Han   const int ldc = m;
1067*5f39d1b3SJooyung Han   // All zeros.
1068*5f39d1b3SJooyung Han   std::vector<std::int32_t> c_offset(m, 0);
1069*5f39d1b3SJooyung Han   // Given values at [i / 16].
1070*5f39d1b3SJooyung Han   const std::array<std::int32_t, 4> c_mult_int_terse{{
1071*5f39d1b3SJooyung Han       3655, 5140, 7049, 9595,
1072*5f39d1b3SJooyung Han   }};
1073*5f39d1b3SJooyung Han   assert(c_mult_int_terse.size() * 16 == m);
1074*5f39d1b3SJooyung Han   std::vector<std::int32_t> c_mult_int(m);
1075*5f39d1b3SJooyung Han   for (int i = 0; i < m; ++i) {
1076*5f39d1b3SJooyung Han     c_mult_int[i] = c_mult_int_terse[i / 16];
1077*5f39d1b3SJooyung Han   }
1078*5f39d1b3SJooyung Han 
1079*5f39d1b3SJooyung Han   const int c_shift = 21;
1080*5f39d1b3SJooyung Han 
1081*5f39d1b3SJooyung Han   const int c_count = m * n;
1082*5f39d1b3SJooyung Han   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
1083*5f39d1b3SJooyung Han   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
1084*5f39d1b3SJooyung Han                                                      ldc);
1085*5f39d1b3SJooyung Han   const OffsetColMap result_offset(&c_offset[0], m);
1086*5f39d1b3SJooyung Han   const OffsetColMap result_mult_int(&c_mult_int[0], m);
1087*5f39d1b3SJooyung Han   const int result_shift = c_shift;
1088*5f39d1b3SJooyung Han 
1089*5f39d1b3SJooyung Han   GemmContext gemm_context;
1090*5f39d1b3SJooyung Han   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
1091*5f39d1b3SJooyung Han       result_offset, result_mult_int, result_shift);
1092*5f39d1b3SJooyung Han   GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
1093*5f39d1b3SJooyung Han                            DefaultL8R8BitDepthParams>(
1094*5f39d1b3SJooyung Han       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
1095*5f39d1b3SJooyung Han       output_pipeline);
1096*5f39d1b3SJooyung Han 
1097*5f39d1b3SJooyung Han   ResultStats stats;
1098*5f39d1b3SJooyung Han   GetResultStats(output_data.get(), &expected_c_data[0], c_count, &stats);
1099*5f39d1b3SJooyung Han 
1100*5f39d1b3SJooyung Han   ResultStatsBounds bounds;
1101*5f39d1b3SJooyung Han   const bool good = CheckResultStatsBounds(stats, bounds);
1102*5f39d1b3SJooyung Han   printf("TestMultithreadedPerChannelQuantization: %s\n",
1103*5f39d1b3SJooyung Han          good ? "PASS" : "FAIL");
1104*5f39d1b3SJooyung Han   ReportResultStats(stats, bounds);
1105*5f39d1b3SJooyung Han   Check(good);
1106*5f39d1b3SJooyung Han }
1107*5f39d1b3SJooyung Han 
1108*5f39d1b3SJooyung Han // Runs a small set of hand-calculated data through the implementation.
TestWithSmallData()1109*5f39d1b3SJooyung Han void TestWithSmallData() {
1110*5f39d1b3SJooyung Han   const int m = 4;
1111*5f39d1b3SJooyung Han   const int n = 2;
1112*5f39d1b3SJooyung Han   const int k = 3;
1113*5f39d1b3SJooyung Han   // Matrix A (LHS) is:
1114*5f39d1b3SJooyung Han   // |  7 | 10 | 13 | 16 |
1115*5f39d1b3SJooyung Han   // |  8 | 11 | 14 | 17 |
1116*5f39d1b3SJooyung Han   // |  9 | 12 | 15 | 18 |
1117*5f39d1b3SJooyung Han   const std::uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
1118*5f39d1b3SJooyung Han   // Matrix B (RHS) is:
1119*5f39d1b3SJooyung Han   // |  1 |  3 |  5 |
1120*5f39d1b3SJooyung Han   // |  2 |  4 |  6 |
1121*5f39d1b3SJooyung Han   const std::uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
1122*5f39d1b3SJooyung Han   // Here are the results we expect, from hand calculations:
1123*5f39d1b3SJooyung Han   // (1 * 7) + (3 * 8) + (5 * 9) = 76
1124*5f39d1b3SJooyung Han   // (2 * 7) + (4 * 8) + (6 * 9) = 100
1125*5f39d1b3SJooyung Han   // (1 * 10) + (3 * 11) + (5 * 12) = 103
1126*5f39d1b3SJooyung Han   // (2 * 10) + (4 * 11) + (6 * 12) = 136
1127*5f39d1b3SJooyung Han   // (1 * 13) + (3 * 14) + (5 * 15) = 130
1128*5f39d1b3SJooyung Han   // (2 * 13) + (4 * 14) + (6 * 15) = 172
1129*5f39d1b3SJooyung Han   // (1 * 16) + (3 * 17) + (5 * 18) = 157
1130*5f39d1b3SJooyung Han   // (2 * 16) + (4 * 17) + (6 * 18) = 208
1131*5f39d1b3SJooyung Han   // That means matrix C should be:
1132*5f39d1b3SJooyung Han   // |  76 | 103 | 130 | 157 |
1133*5f39d1b3SJooyung Han   // | 100 | 136 | 172 | 208 |
1134*5f39d1b3SJooyung Han   const std::uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};
1135*5f39d1b3SJooyung Han 
1136*5f39d1b3SJooyung Han   const int c_count = m * n;
1137*5f39d1b3SJooyung Han   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
1138*5f39d1b3SJooyung Han 
1139*5f39d1b3SJooyung Han   const bool is_a_transposed = true;
1140*5f39d1b3SJooyung Han   const bool is_b_transposed = true;
1141*5f39d1b3SJooyung Han   const bool is_c_transposed = true;
1142*5f39d1b3SJooyung Han   const int lda = k;
1143*5f39d1b3SJooyung Han   const int ldb = n;
1144*5f39d1b3SJooyung Han   const int ldc = n;
1145*5f39d1b3SJooyung Han 
1146*5f39d1b3SJooyung Han   const int a_offset = 0;
1147*5f39d1b3SJooyung Han   const int b_offset = 0;
1148*5f39d1b3SJooyung Han   const int c_offset = 0;
1149*5f39d1b3SJooyung Han   const int c_mult = 1;
1150*5f39d1b3SJooyung Han   const int c_shift = 0;
1151*5f39d1b3SJooyung Han 
1152*5f39d1b3SJooyung Han   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1153*5f39d1b3SJooyung Han       is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
1154*5f39d1b3SJooyung Han       a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
1155*5f39d1b3SJooyung Han       c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);
1156*5f39d1b3SJooyung Han 
1157*5f39d1b3SJooyung Han   ResultStats stats;
1158*5f39d1b3SJooyung Han   GetResultStats(output_data.get(), expected_data, c_count, &stats);
1159*5f39d1b3SJooyung Han 
1160*5f39d1b3SJooyung Han   ResultStatsBounds bounds;
1161*5f39d1b3SJooyung Han   const bool good = CheckResultStatsBounds(stats, bounds);
1162*5f39d1b3SJooyung Han   printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
1163*5f39d1b3SJooyung Han   ReportResultStats(stats, bounds);
1164*5f39d1b3SJooyung Han   Check(good);
1165*5f39d1b3SJooyung Han }
1166*5f39d1b3SJooyung Han 
1167*5f39d1b3SJooyung Han // This is the most realistic test of how we'll be using the low-precision GEMM
1168*5f39d1b3SJooyung Han // function in applications. It takes in large input matrices that have been
1169*5f39d1b3SJooyung Han // captured from an actual neural network run.
TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,int tolerance_median,int tolerance_max)1170*5f39d1b3SJooyung Han void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
1171*5f39d1b3SJooyung Han                       int tolerance_median, int tolerance_max) {
1172*5f39d1b3SJooyung Han   std::unique_ptr<std::uint8_t[]> output_data(
1173*5f39d1b3SJooyung Han       new std::uint8_t[test_data::c_count]);
1174*5f39d1b3SJooyung Han   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1175*5f39d1b3SJooyung Han       test_data::is_a_transposed, test_data::is_b_transposed,
1176*5f39d1b3SJooyung Han       test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
1177*5f39d1b3SJooyung Han       test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
1178*5f39d1b3SJooyung Han       test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
1179*5f39d1b3SJooyung Han       test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);
1180*5f39d1b3SJooyung Han 
1181*5f39d1b3SJooyung Han   ResultStats stats;
1182*5f39d1b3SJooyung Han   GetResultStats(output_data.get(), test_data::expected_c_data,
1183*5f39d1b3SJooyung Han                  test_data::c_count, &stats);
1184*5f39d1b3SJooyung Han 
1185*5f39d1b3SJooyung Han   ResultStatsBounds bounds;
1186*5f39d1b3SJooyung Han   if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
1187*5f39d1b3SJooyung Han     bounds.med_unsigned_diff = tolerance_median;
1188*5f39d1b3SJooyung Han     bounds.max_unsigned_diff = tolerance_max;
1189*5f39d1b3SJooyung Han     bounds.med_signed_diff = 0;
1190*5f39d1b3SJooyung Han     bounds.mean_signed_diff = 0.2f;
1191*5f39d1b3SJooyung Han   }
1192*5f39d1b3SJooyung Han 
1193*5f39d1b3SJooyung Han   const bool good = CheckResultStatsBounds(stats, bounds);
1194*5f39d1b3SJooyung Han   printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
1195*5f39d1b3SJooyung Han          GetBitDepthName(BitDepth));
1196*5f39d1b3SJooyung Han   ReportResultStats(stats, bounds);
1197*5f39d1b3SJooyung Han   Check(good);
1198*5f39d1b3SJooyung Han }
1199*5f39d1b3SJooyung Han 
1200*5f39d1b3SJooyung Han template <typename BitDepthParams, MapOrder ResultOrder>
TestOutputStages(int rows,int depth,int cols,int result_offset,int result_mult_int,int result_shift)1201*5f39d1b3SJooyung Han void TestOutputStages(int rows, int depth, int cols, int result_offset,
1202*5f39d1b3SJooyung Han                       int result_mult_int, int result_shift) {
1203*5f39d1b3SJooyung Han   Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
1204*5f39d1b3SJooyung Han   Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
1205*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
1206*5f39d1b3SJooyung Han   MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
1207*5f39d1b3SJooyung Han   MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
1208*5f39d1b3SJooyung Han   const int lhs_offset = 12;
1209*5f39d1b3SJooyung Han   const int rhs_offset = -34;
1210*5f39d1b3SJooyung Han 
1211*5f39d1b3SJooyung Han   // Test an empty pipeline, i.e. returning raw int32 accumulators.
1212*5f39d1b3SJooyung Han   auto empty_pipeline = std::make_tuple();
1213*5f39d1b3SJooyung Han   GemmContext context;
1214*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1215*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
1216*5f39d1b3SJooyung Han       rhs_offset, empty_pipeline);
1217*5f39d1b3SJooyung Han 
1218*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1219*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1220*5f39d1b3SJooyung Han       std::int32_t expected = 0;
1221*5f39d1b3SJooyung Han       for (int d = 0; d < depth; d++) {
1222*5f39d1b3SJooyung Han         std::int32_t lhs_val =
1223*5f39d1b3SJooyung Han             static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
1224*5f39d1b3SJooyung Han         std::int32_t rhs_val =
1225*5f39d1b3SJooyung Han             static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
1226*5f39d1b3SJooyung Han         expected += lhs_val * rhs_val;
1227*5f39d1b3SJooyung Han       }
1228*5f39d1b3SJooyung Han       Check(expected == result_raw_int32(r, c));
1229*5f39d1b3SJooyung Han     }
1230*5f39d1b3SJooyung Han   }
1231*5f39d1b3SJooyung Han 
1232*5f39d1b3SJooyung Han   // Test a pipeline with only the quantize-down stage, still returning
1233*5f39d1b3SJooyung Han   // unclamped (but scaled) int32's
1234*5f39d1b3SJooyung Han   OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
1235*5f39d1b3SJooyung Han   quantize_down_stage.result_offset = result_offset;
1236*5f39d1b3SJooyung Han   quantize_down_stage.result_mult_int = result_mult_int;
1237*5f39d1b3SJooyung Han   quantize_down_stage.result_shift = result_shift;
1238*5f39d1b3SJooyung Han   auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
1239*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
1240*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1241*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
1242*5f39d1b3SJooyung Han       lhs_offset, rhs_offset, quantize_down_pipeline);
1243*5f39d1b3SJooyung Han 
1244*5f39d1b3SJooyung Han   std::int64_t sum = 0;
1245*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1246*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1247*5f39d1b3SJooyung Han       std::int32_t raw = result_raw_int32(r, c);
1248*5f39d1b3SJooyung Han       std::int32_t expected = RoundingDivideByPOT(
1249*5f39d1b3SJooyung Han           (raw + result_offset) * result_mult_int, result_shift);
1250*5f39d1b3SJooyung Han       Check(expected == result_quantized_down_int32(r, c));
1251*5f39d1b3SJooyung Han       sum += expected;
1252*5f39d1b3SJooyung Han     }
1253*5f39d1b3SJooyung Han   }
1254*5f39d1b3SJooyung Han   std::int64_t avg = sum / (rows * cols);
1255*5f39d1b3SJooyung Han   // Test that the average quantized-down value falls reasonably in the
1256*5f39d1b3SJooyung Han   // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
1257*5f39d1b3SJooyung Han   // adjusted.
1258*5f39d1b3SJooyung Han   Check(avg >= 64 && avg <= 192);
1259*5f39d1b3SJooyung Han 
1260*5f39d1b3SJooyung Han   // Test the familiar default pipeline consisting of quantize-down and
1261*5f39d1b3SJooyung Han   // clamp-and-cast-to-uint8.
1262*5f39d1b3SJooyung Han   OutputStageSaturatingCastToUint8 saturating_cast_stage;
1263*5f39d1b3SJooyung Han   auto quantize_down_and_saturating_cast_pipeline =
1264*5f39d1b3SJooyung Han       std::make_tuple(quantize_down_stage, saturating_cast_stage);
1265*5f39d1b3SJooyung Han   Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
1266*5f39d1b3SJooyung Han                                                                           cols);
1267*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1268*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(),
1269*5f39d1b3SJooyung Han       &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
1270*5f39d1b3SJooyung Han       quantize_down_and_saturating_cast_pipeline);
1271*5f39d1b3SJooyung Han 
1272*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1273*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1274*5f39d1b3SJooyung Han       std::int32_t quantized = result_quantized_down_int32(r, c);
1275*5f39d1b3SJooyung Han       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1276*5f39d1b3SJooyung Han       Check(expected == result_quantized_down_saturated_uint8(r, c));
1277*5f39d1b3SJooyung Han     }
1278*5f39d1b3SJooyung Han   }
1279*5f39d1b3SJooyung Han 
1280*5f39d1b3SJooyung Han   // Test a variant of the familiar default pipeline consisting of quantize-down
1281*5f39d1b3SJooyung Han   // and clamp-and-cast-to-int16.
1282*5f39d1b3SJooyung Han   OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
1283*5f39d1b3SJooyung Han   auto quantize_down_and_saturating_cast_int16_pipeline =
1284*5f39d1b3SJooyung Han       std::make_tuple(quantize_down_stage, saturating_cast_int16_stage);
1285*5f39d1b3SJooyung Han   Matrix<std::int16_t, ResultOrder> result_quantized_down_saturated_int16(rows,
1286*5f39d1b3SJooyung Han                                                                           cols);
1287*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int16_t, DefaultL8R8BitDepthParams>(
1288*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(),
1289*5f39d1b3SJooyung Han       &result_quantized_down_saturated_int16, lhs_offset, rhs_offset,
1290*5f39d1b3SJooyung Han       quantize_down_and_saturating_cast_int16_pipeline);
1291*5f39d1b3SJooyung Han 
1292*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1293*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1294*5f39d1b3SJooyung Han       std::int32_t quantized = result_quantized_down_int32(r, c);
1295*5f39d1b3SJooyung Han       std::int16_t expected = std::min(std::max(quantized, -32768), 32767);
1296*5f39d1b3SJooyung Han       Check(expected == result_quantized_down_saturated_int16(r, c));
1297*5f39d1b3SJooyung Han     }
1298*5f39d1b3SJooyung Han   }
1299*5f39d1b3SJooyung Han 
1300*5f39d1b3SJooyung Han #ifdef GEMMLOWP_MSA
1301*5f39d1b3SJooyung Han   // Test a pipeline consisting of quantize-down and truncating-cast-to-uint8.
1302*5f39d1b3SJooyung Han   OutputStageTruncatingCastToUint8 truncating_cast_stage;
1303*5f39d1b3SJooyung Han   auto quantize_down_and_truncating_cast_pipeline =
1304*5f39d1b3SJooyung Han       std::make_tuple(quantize_down_stage, truncating_cast_stage);
1305*5f39d1b3SJooyung Han   Matrix<std::uint8_t, ResultOrder> result_quantized_down_truncated_uint8(
1306*5f39d1b3SJooyung Han       rows, cols);
1307*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1308*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(),
1309*5f39d1b3SJooyung Han       &result_quantized_down_truncated_uint8, lhs_offset, rhs_offset,
1310*5f39d1b3SJooyung Han       quantize_down_and_truncating_cast_pipeline);
1311*5f39d1b3SJooyung Han 
1312*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1313*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1314*5f39d1b3SJooyung Han       std::int32_t quantized = result_quantized_down_int32(r, c);
1315*5f39d1b3SJooyung Han       std::uint8_t expected = quantized & 255;
1316*5f39d1b3SJooyung Han       Check(expected == result_quantized_down_truncated_uint8(r, c));
1317*5f39d1b3SJooyung Han     }
1318*5f39d1b3SJooyung Han   }
1319*5f39d1b3SJooyung Han #endif
1320*5f39d1b3SJooyung Han 
1321*5f39d1b3SJooyung Han   // Test a bias-addition with row-vector
1322*5f39d1b3SJooyung Han   std::vector<std::int32_t> row_vector_data(cols);
1323*5f39d1b3SJooyung Han   std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
1324*5f39d1b3SJooyung Han                                                                          500);
1325*5f39d1b3SJooyung Han   for (int i = 0; i < cols; i++) {
1326*5f39d1b3SJooyung Han     row_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
1327*5f39d1b3SJooyung Han   }
1328*5f39d1b3SJooyung Han   typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
1329*5f39d1b3SJooyung Han   RowVectorMap row_vector_map(row_vector_data.data(), cols);
1330*5f39d1b3SJooyung Han   OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
1331*5f39d1b3SJooyung Han   row_bias_addition_stage.bias_vector = row_vector_map;
1332*5f39d1b3SJooyung Han   auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
1333*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
1334*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1335*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
1336*5f39d1b3SJooyung Han       lhs_offset, rhs_offset, row_bias_addition_pipeline);
1337*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1338*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1339*5f39d1b3SJooyung Han       std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
1340*5f39d1b3SJooyung Han       Check(expected == result_of_row_bias_addition(r, c));
1341*5f39d1b3SJooyung Han     }
1342*5f39d1b3SJooyung Han   }
1343*5f39d1b3SJooyung Han 
1344*5f39d1b3SJooyung Han   // Test a bias-addition with column-vector
1345*5f39d1b3SJooyung Han   std::vector<std::int32_t> col_vector_data(rows);
1346*5f39d1b3SJooyung Han   for (int i = 0; i < rows; i++) {
1347*5f39d1b3SJooyung Han     col_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
1348*5f39d1b3SJooyung Han   }
1349*5f39d1b3SJooyung Han   typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
1350*5f39d1b3SJooyung Han   ColVectorMap col_vector_map(col_vector_data.data(), rows);
1351*5f39d1b3SJooyung Han   OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
1352*5f39d1b3SJooyung Han   col_bias_addition_stage.bias_vector = col_vector_map;
1353*5f39d1b3SJooyung Han   auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
1354*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
1355*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1356*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
1357*5f39d1b3SJooyung Han       lhs_offset, rhs_offset, col_bias_addition_pipeline);
1358*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1359*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1360*5f39d1b3SJooyung Han       std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
1361*5f39d1b3SJooyung Han       Check(expected == result_of_col_bias_addition(r, c));
1362*5f39d1b3SJooyung Han     }
1363*5f39d1b3SJooyung Han   }
1364*5f39d1b3SJooyung Han 
1365*5f39d1b3SJooyung Han   // Test a clamp
1366*5f39d1b3SJooyung Han   OutputStageClamp clamp_stage;
1367*5f39d1b3SJooyung Han   // Determine min and max of raw int32 accumulators
1368*5f39d1b3SJooyung Han   std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
1369*5f39d1b3SJooyung Han   std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
1370*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1371*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1372*5f39d1b3SJooyung Han       raw_min = std::min(raw_min, result_raw_int32(r, c));
1373*5f39d1b3SJooyung Han       raw_max = std::max(raw_max, result_raw_int32(r, c));
1374*5f39d1b3SJooyung Han     }
1375*5f39d1b3SJooyung Han   }
1376*5f39d1b3SJooyung Han   // Pick some interesting clamp min/max bounds
1377*5f39d1b3SJooyung Han   clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
1378*5f39d1b3SJooyung Han   clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
1379*5f39d1b3SJooyung Han   assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
1380*5f39d1b3SJooyung Han          clamp_stage.max <= raw_max);
1381*5f39d1b3SJooyung Han   auto clamp_pipeline = std::make_tuple(clamp_stage);
1382*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
1383*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1384*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
1385*5f39d1b3SJooyung Han       rhs_offset, clamp_pipeline);
1386*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1387*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1388*5f39d1b3SJooyung Han       std::int32_t raw = result_raw_int32(r, c);
1389*5f39d1b3SJooyung Han       std::int32_t expected =
1390*5f39d1b3SJooyung Han           std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
1391*5f39d1b3SJooyung Han       Check(expected == result_clamped(r, c));
1392*5f39d1b3SJooyung Han     }
1393*5f39d1b3SJooyung Han   }
1394*5f39d1b3SJooyung Han 
1395*5f39d1b3SJooyung Han   // Test tanh
1396*5f39d1b3SJooyung Han   OutputStageTanh tanh_stage;
1397*5f39d1b3SJooyung Han   const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
1398*5f39d1b3SJooyung Han   const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
1399*5f39d1b3SJooyung Han   tanh_stage.real_zero_as_int32 = real_zero_as_int32;
1400*5f39d1b3SJooyung Han   tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
1401*5f39d1b3SJooyung Han   auto tanh_pipeline = std::make_tuple(tanh_stage);
1402*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
1403*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1404*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
1405*5f39d1b3SJooyung Han       rhs_offset, tanh_pipeline);
1406*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1407*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1408*5f39d1b3SJooyung Han       std::int32_t raw = result_raw_int32(r, c);
1409*5f39d1b3SJooyung Han       double real_input =
1410*5f39d1b3SJooyung Han           double(raw - real_zero_as_int32) / real_amplitude_as_int32;
1411*5f39d1b3SJooyung Han       double expected = std::tanh(real_input);
1412*5f39d1b3SJooyung Han       std::int32_t actual_int32 = result_tanh(r, c);
1413*5f39d1b3SJooyung Han       double actual =
1414*5f39d1b3SJooyung Han           double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
1415*5f39d1b3SJooyung Han       Check(std::abs(expected - actual) < 2e-4);
1416*5f39d1b3SJooyung Han     }
1417*5f39d1b3SJooyung Han   }
1418*5f39d1b3SJooyung Han 
1419*5f39d1b3SJooyung Han   // Test a pipeline with bias and clamp
1420*5f39d1b3SJooyung Han   auto bias_clamp_pipeline =
1421*5f39d1b3SJooyung Han       std::make_tuple(col_bias_addition_stage, clamp_stage);
1422*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
1423*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1424*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
1425*5f39d1b3SJooyung Han       lhs_offset, rhs_offset, bias_clamp_pipeline);
1426*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1427*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1428*5f39d1b3SJooyung Han       std::int32_t raw = result_raw_int32(r, c);
1429*5f39d1b3SJooyung Han       std::int32_t biased = raw + col_vector_data[r];
1430*5f39d1b3SJooyung Han       std::int32_t expected =
1431*5f39d1b3SJooyung Han           std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
1432*5f39d1b3SJooyung Han       Check(expected == result_biased_clamped(r, c));
1433*5f39d1b3SJooyung Han     }
1434*5f39d1b3SJooyung Han   }
1435*5f39d1b3SJooyung Han 
1436*5f39d1b3SJooyung Han   // Test a full pipeline with bias and clamp and quantization down to 8bit
1437*5f39d1b3SJooyung Han   // result
1438*5f39d1b3SJooyung Han   auto bias_clamp_quantize_cast_pipeline =
1439*5f39d1b3SJooyung Han       std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
1440*5f39d1b3SJooyung Han                       saturating_cast_stage);
1441*5f39d1b3SJooyung Han   Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
1442*5f39d1b3SJooyung Han       rows, cols);
1443*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1444*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(),
1445*5f39d1b3SJooyung Han       &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
1446*5f39d1b3SJooyung Han       bias_clamp_quantize_cast_pipeline);
1447*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1448*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1449*5f39d1b3SJooyung Han       std::int32_t quantized = RoundingDivideByPOT(
1450*5f39d1b3SJooyung Han           (result_biased_clamped(r, c) + result_offset) * result_mult_int,
1451*5f39d1b3SJooyung Han           result_shift);
1452*5f39d1b3SJooyung Han       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1453*5f39d1b3SJooyung Han       Check(expected == result_biased_clamped_quantized_casted(r, c));
1454*5f39d1b3SJooyung Han     }
1455*5f39d1b3SJooyung Han   }
1456*5f39d1b3SJooyung Han 
1457*5f39d1b3SJooyung Han   // Test a pipeline with the fixed-point-multiplier variant stage for the
1458*5f39d1b3SJooyung Han   // quantizing down of 32bit accumulators.
1459*5f39d1b3SJooyung Han   //
1460*5f39d1b3SJooyung Han   // First, figure appropriate fixedpoint multiplier and shift values.
1461*5f39d1b3SJooyung Han   std::int32_t result_fixedpoint_multiplier = result_mult_int;
1462*5f39d1b3SJooyung Han   std::int32_t result_fixedpoint_shift = result_shift;
1463*5f39d1b3SJooyung Han   Check(result_mult_int > 0);
1464*5f39d1b3SJooyung Han   Check(result_shift > 0);
1465*5f39d1b3SJooyung Han   result_fixedpoint_multiplier = result_mult_int;
1466*5f39d1b3SJooyung Han   result_fixedpoint_shift = result_shift - 31;
1467*5f39d1b3SJooyung Han   while (result_fixedpoint_multiplier < (1 << 30)) {
1468*5f39d1b3SJooyung Han     result_fixedpoint_multiplier <<= 1;
1469*5f39d1b3SJooyung Han     result_fixedpoint_shift++;
1470*5f39d1b3SJooyung Han   }
1471*5f39d1b3SJooyung Han   Check(result_fixedpoint_shift >= 0);
1472*5f39d1b3SJooyung Han   // Now test OutputStageQuantizeDownInt32ByFixedPoint
1473*5f39d1b3SJooyung Han   OutputStageQuantizeDownInt32ByFixedPoint
1474*5f39d1b3SJooyung Han       quantize_down_by_fixedpoint_stage;
1475*5f39d1b3SJooyung Han   quantize_down_by_fixedpoint_stage.result_offset_after_shift =
1476*5f39d1b3SJooyung Han       static_cast<std::int32_t>(
1477*5f39d1b3SJooyung Han           round(static_cast<double>(result_offset * result_mult_int) /
1478*5f39d1b3SJooyung Han                 (1 << result_shift)));
1479*5f39d1b3SJooyung Han   quantize_down_by_fixedpoint_stage.result_fixedpoint_multiplier =
1480*5f39d1b3SJooyung Han       result_fixedpoint_multiplier;
1481*5f39d1b3SJooyung Han   quantize_down_by_fixedpoint_stage.result_shift = result_fixedpoint_shift;
1482*5f39d1b3SJooyung Han   auto quantize_down_by_fixedpoint_pipeline =
1483*5f39d1b3SJooyung Han       std::make_tuple(quantize_down_by_fixedpoint_stage);
1484*5f39d1b3SJooyung Han   Matrix<std::int32_t, ResultOrder> result_quantized_down_by_fixedpoint_int32(
1485*5f39d1b3SJooyung Han       rows, cols);
1486*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1487*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(),
1488*5f39d1b3SJooyung Han       &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
1489*5f39d1b3SJooyung Han       quantize_down_by_fixedpoint_pipeline);
1490*5f39d1b3SJooyung Han 
1491*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1492*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1493*5f39d1b3SJooyung Han       const std::int32_t actual =
1494*5f39d1b3SJooyung Han           result_quantized_down_by_fixedpoint_int32(r, c);
1495*5f39d1b3SJooyung Han       const std::int32_t raw = result_raw_int32(r, c);
1496*5f39d1b3SJooyung Han       const std::int32_t expected =
1497*5f39d1b3SJooyung Han           quantize_down_by_fixedpoint_stage.result_offset_after_shift +
1498*5f39d1b3SJooyung Han           RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
1499*5f39d1b3SJooyung Han                                   raw, result_fixedpoint_multiplier),
1500*5f39d1b3SJooyung Han                               result_fixedpoint_shift);
1501*5f39d1b3SJooyung Han       Check(actual == expected);
1502*5f39d1b3SJooyung Han     }
1503*5f39d1b3SJooyung Han   }
1504*5f39d1b3SJooyung Han 
1505*5f39d1b3SJooyung Han   // Test OutputStageScaleInt32ByFixedPointAndExponent
1506*5f39d1b3SJooyung Han   for (int exponent = -2; exponent <= 2; exponent++) {
1507*5f39d1b3SJooyung Han     OutputStageScaleInt32ByFixedPointAndExponent
1508*5f39d1b3SJooyung Han         scale_by_fixedpoint_and_exponent_stage;
1509*5f39d1b3SJooyung Han     scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift =
1510*5f39d1b3SJooyung Han         static_cast<std::int32_t>(round(static_cast<double>(
1511*5f39d1b3SJooyung Han             result_offset * result_mult_int * std::pow(2.0, exponent))));
1512*5f39d1b3SJooyung Han     scale_by_fixedpoint_and_exponent_stage.result_fixedpoint_multiplier =
1513*5f39d1b3SJooyung Han         result_fixedpoint_multiplier;
1514*5f39d1b3SJooyung Han     scale_by_fixedpoint_and_exponent_stage.result_exponent = exponent;
1515*5f39d1b3SJooyung Han     auto scale_by_fixedpoint_and_exponent_pipeline =
1516*5f39d1b3SJooyung Han         std::make_tuple(scale_by_fixedpoint_and_exponent_stage);
1517*5f39d1b3SJooyung Han     Matrix<std::int32_t, ResultOrder>
1518*5f39d1b3SJooyung Han         result_scaled_by_fixedpoint_and_exponent_int32(rows, cols);
1519*5f39d1b3SJooyung Han     GemmWithOutputPipeline<std::uint8_t, std::int32_t,
1520*5f39d1b3SJooyung Han                            DefaultL8R8BitDepthParams>(
1521*5f39d1b3SJooyung Han         &context, lhs.const_map(), rhs.const_map(),
1522*5f39d1b3SJooyung Han         &result_scaled_by_fixedpoint_and_exponent_int32, lhs_offset, rhs_offset,
1523*5f39d1b3SJooyung Han         scale_by_fixedpoint_and_exponent_pipeline);
1524*5f39d1b3SJooyung Han 
1525*5f39d1b3SJooyung Han     for (int r = 0; r < rows; r++) {
1526*5f39d1b3SJooyung Han       for (int c = 0; c < cols; c++) {
1527*5f39d1b3SJooyung Han         const std::int32_t actual =
1528*5f39d1b3SJooyung Han             result_scaled_by_fixedpoint_and_exponent_int32(r, c);
1529*5f39d1b3SJooyung Han         const std::int32_t raw = result_raw_int32(r, c);
1530*5f39d1b3SJooyung Han         int left_shift = std::max(0, exponent);
1531*5f39d1b3SJooyung Han         int right_shift = std::max(0, -exponent);
1532*5f39d1b3SJooyung Han         const std::int32_t expected =
1533*5f39d1b3SJooyung Han             scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift +
1534*5f39d1b3SJooyung Han             RoundingDivideByPOT(
1535*5f39d1b3SJooyung Han                 SaturatingRoundingDoublingHighMul((1 << left_shift) * raw,
1536*5f39d1b3SJooyung Han                                                   result_fixedpoint_multiplier),
1537*5f39d1b3SJooyung Han                 right_shift);
1538*5f39d1b3SJooyung Han         Check(actual == expected);
1539*5f39d1b3SJooyung Han       }
1540*5f39d1b3SJooyung Han     }
1541*5f39d1b3SJooyung Han   }
1542*5f39d1b3SJooyung Han 
1543*5f39d1b3SJooyung Han   // Test the variant of the familiar default pipeline consisting of
1544*5f39d1b3SJooyung Han   // quantize-down and
1545*5f39d1b3SJooyung Han   // clamp-and-cast-to-uint8, where we used fixedpoint multipliers for the
1546*5f39d1b3SJooyung Han   // downscaling.
1547*5f39d1b3SJooyung Han   auto quantize_down_by_fixedpoint_and_saturating_cast_pipeline =
1548*5f39d1b3SJooyung Han       std::make_tuple(quantize_down_by_fixedpoint_stage, saturating_cast_stage);
1549*5f39d1b3SJooyung Han   Matrix<std::uint8_t, ResultOrder>
1550*5f39d1b3SJooyung Han       result_quantized_down_by_fixedpoint_saturated_uint8(rows, cols);
1551*5f39d1b3SJooyung Han   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1552*5f39d1b3SJooyung Han       &context, lhs.const_map(), rhs.const_map(),
1553*5f39d1b3SJooyung Han       &result_quantized_down_by_fixedpoint_saturated_uint8, lhs_offset,
1554*5f39d1b3SJooyung Han       rhs_offset, quantize_down_by_fixedpoint_and_saturating_cast_pipeline);
1555*5f39d1b3SJooyung Han 
1556*5f39d1b3SJooyung Han   for (int r = 0; r < rows; r++) {
1557*5f39d1b3SJooyung Han     for (int c = 0; c < cols; c++) {
1558*5f39d1b3SJooyung Han       std::int32_t quantized = result_quantized_down_by_fixedpoint_int32(r, c);
1559*5f39d1b3SJooyung Han       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1560*5f39d1b3SJooyung Han       Check(expected ==
1561*5f39d1b3SJooyung Han             result_quantized_down_by_fixedpoint_saturated_uint8(r, c));
1562*5f39d1b3SJooyung Han     }
1563*5f39d1b3SJooyung Han   }
1564*5f39d1b3SJooyung Han 
1565*5f39d1b3SJooyung Han   printf("TestOutputStages: PASS with ResultOrder=%s\n",
1566*5f39d1b3SJooyung Han          OrderName(ResultOrder));
1567*5f39d1b3SJooyung Han }
1568*5f39d1b3SJooyung Han 
1569*5f39d1b3SJooyung Han #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1570*5f39d1b3SJooyung Han template <typename BitDepthParams>
TestExhaustively()1571*5f39d1b3SJooyung Han void TestExhaustively() {
1572*5f39d1b3SJooyung Han   GemmContext context;
1573*5f39d1b3SJooyung Han 
1574*5f39d1b3SJooyung Han   // Test the internal GEMM interfaces
1575*5f39d1b3SJooyung Han   test_gemm<
1576*5f39d1b3SJooyung Han       SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1577*5f39d1b3SJooyung Han                               std::uint8_t, BitDepthParams>>(&context);
1578*5f39d1b3SJooyung Han 
1579*5f39d1b3SJooyung Han   test_gemm<
1580*5f39d1b3SJooyung Han       MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1581*5f39d1b3SJooyung Han                              std::uint8_t, BitDepthParams>>(&context);
1582*5f39d1b3SJooyung Han 
1583*5f39d1b3SJooyung Han   // Test the public GEMM interfaces
1584*5f39d1b3SJooyung Han   test_gemm<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
1585*5f39d1b3SJooyung Han 
1586*5f39d1b3SJooyung Han   // Test GEMV cases (internal interfaces)
1587*5f39d1b3SJooyung Han   test_gemv<
1588*5f39d1b3SJooyung Han       SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1589*5f39d1b3SJooyung Han                               std::uint8_t, BitDepthParams>>(&context);
1590*5f39d1b3SJooyung Han 
1591*5f39d1b3SJooyung Han   test_gemv<
1592*5f39d1b3SJooyung Han       MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1593*5f39d1b3SJooyung Han                              std::uint8_t, BitDepthParams>>(&context);
1594*5f39d1b3SJooyung Han 
1595*5f39d1b3SJooyung Han   // Test GEMV cases (public interfaces)
1596*5f39d1b3SJooyung Han   test_gemv<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
1597*5f39d1b3SJooyung Han }
1598*5f39d1b3SJooyung Han 
1599*5f39d1b3SJooyung Han template <eight_bit_int_gemm::BitDepthSetting BitDepthSetting>
TestExhaustivelyEightBitIntGemm()1600*5f39d1b3SJooyung Han void TestExhaustivelyEightBitIntGemm() {
1601*5f39d1b3SJooyung Han   GemmContext context;
1602*5f39d1b3SJooyung Han   test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
1603*5f39d1b3SJooyung Han   test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
1604*5f39d1b3SJooyung Han   test_gemm<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
1605*5f39d1b3SJooyung Han }
1606*5f39d1b3SJooyung Han 
TestKernels()1607*5f39d1b3SJooyung Han void TestKernels() {
1608*5f39d1b3SJooyung Han   GemmContext context;
1609*5f39d1b3SJooyung Han 
1610*5f39d1b3SJooyung Han   // Test specific kernels with various different formats,
1611*5f39d1b3SJooyung Han   // to exercises corner cases especially in the packing code.
1612*5f39d1b3SJooyung Han   test_gemm_kernel<
1613*5f39d1b3SJooyung Han       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
1614*5f39d1b3SJooyung Han                                    KernelSideFormat<CellFormat<1, 1>, 1>>>>(
1615*5f39d1b3SJooyung Han       &context);
1616*5f39d1b3SJooyung Han 
1617*5f39d1b3SJooyung Han   test_gemm_kernel<
1618*5f39d1b3SJooyung Han       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
1619*5f39d1b3SJooyung Han                                    KernelSideFormat<CellFormat<4, 2>, 2>>>>(
1620*5f39d1b3SJooyung Han       &context);
1621*5f39d1b3SJooyung Han 
1622*5f39d1b3SJooyung Han   test_gemm_kernel<
1623*5f39d1b3SJooyung Han       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
1624*5f39d1b3SJooyung Han                                    KernelSideFormat<CellFormat<4, 2>, 5>>>>(
1625*5f39d1b3SJooyung Han       &context);
1626*5f39d1b3SJooyung Han 
1627*5f39d1b3SJooyung Han   test_gemm_kernel<ReferenceKernel<KernelFormat<
1628*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
1629*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);
1630*5f39d1b3SJooyung Han 
1631*5f39d1b3SJooyung Han   test_gemm_kernel<ReferenceKernel<KernelFormat<
1632*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
1633*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);
1634*5f39d1b3SJooyung Han 
1635*5f39d1b3SJooyung Han   test_gemm_kernel<ReferenceKernel<KernelFormat<
1636*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
1637*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);
1638*5f39d1b3SJooyung Han 
1639*5f39d1b3SJooyung Han   test_gemm_kernel<ReferenceKernel<KernelFormat<
1640*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
1641*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);
1642*5f39d1b3SJooyung Han 
1643*5f39d1b3SJooyung Han   test_gemm_kernel<ReferenceKernel<KernelFormat<
1644*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
1645*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);
1646*5f39d1b3SJooyung Han 
1647*5f39d1b3SJooyung Han   test_gemm_kernel<ReferenceKernel<KernelFormat<
1648*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
1649*5f39d1b3SJooyung Han       KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
1650*5f39d1b3SJooyung Han }
1651*5f39d1b3SJooyung Han 
1652*5f39d1b3SJooyung Han #endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1653*5f39d1b3SJooyung Han 
1654*5f39d1b3SJooyung Han template <typename BitDepthParams>
TestOutputStages()1655*5f39d1b3SJooyung Han void TestOutputStages() {
1656*5f39d1b3SJooyung Han   // Test non-default output pipelines with various combinations of
1657*5f39d1b3SJooyung Han   // output stages.
1658*5f39d1b3SJooyung Han   TestOutputStages<BitDepthParams, MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
1659*5f39d1b3SJooyung Han   TestOutputStages<BitDepthParams, MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
1660*5f39d1b3SJooyung Han   TestOutputStages<BitDepthParams, MapOrder::RowMajor>(630, 10, 1270, 5, 17,
1661*5f39d1b3SJooyung Han                                                        14);
1662*5f39d1b3SJooyung Han   TestOutputStages<BitDepthParams, MapOrder::ColMajor>(630, 10, 1270, 5, 17,
1663*5f39d1b3SJooyung Han                                                        14);
1664*5f39d1b3SJooyung Han }
1665*5f39d1b3SJooyung Han 
test()1666*5f39d1b3SJooyung Han void test() {
1667*5f39d1b3SJooyung Han #ifdef GEMMLOWP_TEST_PROFILE
1668*5f39d1b3SJooyung Han   RegisterCurrentThreadForProfiling();
1669*5f39d1b3SJooyung Han   StartProfiling();
1670*5f39d1b3SJooyung Han #endif
1671*5f39d1b3SJooyung Han 
1672*5f39d1b3SJooyung Han   // Run a first quick test against hand-calculated data.
1673*5f39d1b3SJooyung Han   TestWithSmallData();
1674*5f39d1b3SJooyung Han 
1675*5f39d1b3SJooyung Han #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1676*5f39d1b3SJooyung Han   TestExhaustively<DefaultL8R8BitDepthParams>();
1677*5f39d1b3SJooyung Han   TestExhaustively<L8R8WithLhsNonzeroBitDepthParams>();
1678*5f39d1b3SJooyung Han   TestExhaustively<DefaultL7R5BitDepthParams>();  // legacy, same as L8R8
1679*5f39d1b3SJooyung Han   TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A8B8>();
1680*5f39d1b3SJooyung Han   TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A5B7>();
1681*5f39d1b3SJooyung Han   TestKernels();
1682*5f39d1b3SJooyung Han #endif
1683*5f39d1b3SJooyung Han 
1684*5f39d1b3SJooyung Han   // Run against actual data from a network evaluation.
1685*5f39d1b3SJooyung Han   TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
1686*5f39d1b3SJooyung Han   TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);
1687*5f39d1b3SJooyung Han 
1688*5f39d1b3SJooyung Han   // Test non-default output pipelines with various combinations of
1689*5f39d1b3SJooyung Han   // output stages.
1690*5f39d1b3SJooyung Han   TestOutputStages<DefaultL8R8BitDepthParams>();
1691*5f39d1b3SJooyung Han   TestOutputStages<L8R8WithLhsNonzeroBitDepthParams>();
1692*5f39d1b3SJooyung Han 
1693*5f39d1b3SJooyung Han   // Test per channel quantization.
1694*5f39d1b3SJooyung Han   TestWithSmallDataPerChannelQuantization();
1695*5f39d1b3SJooyung Han   TestWithLargeDataPerChannelQuantization();
1696*5f39d1b3SJooyung Han   TestMultithreadedPerChannelQuantization();
1697*5f39d1b3SJooyung Han #ifdef GEMMLOWP_TEST_PROFILE
1698*5f39d1b3SJooyung Han   FinishProfiling();
1699*5f39d1b3SJooyung Han #endif
1700*5f39d1b3SJooyung Han 
1701*5f39d1b3SJooyung Han   std::cerr << "All tests passed." << std::endl;
1702*5f39d1b3SJooyung Han 
1703*5f39d1b3SJooyung Han   // We have been testing the eight_bit_int_gemm, so we should free its
1704*5f39d1b3SJooyung Han   // persistent
1705*5f39d1b3SJooyung Han   // resources now to avoid having leak-checking tools report leaks.
1706*5f39d1b3SJooyung Han   eight_bit_int_gemm::FreePersistentResources();
1707*5f39d1b3SJooyung Han }
1708*5f39d1b3SJooyung Han 
1709*5f39d1b3SJooyung Han }  // end namespace gemmlowp
1710*5f39d1b3SJooyung Han 
1711*5f39d1b3SJooyung Han // For iOS, we need to define our own main(), so skip it here.
1712*5f39d1b3SJooyung Han #if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
main()1713*5f39d1b3SJooyung Han int main() { gemmlowp::test(); }
1714*5f39d1b3SJooyung Han #endif
1715