/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include <gtest/gtest.h>
#include "ruy/matrix.h"  // from @ruy
#include "ruy/reference_mul.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

namespace tflite {

namespace {

using cpu_backend_gemm::Gemm;
using cpu_backend_gemm::GemmParams;
using cpu_backend_gemm::MatrixParams;
using cpu_backend_gemm::QuantizationFlavor;

template <typename Scalar>
std::string ToString(const std::vector<Scalar>& vector) {
  std::stringstream s;
  if (vector.empty()) {
    s << "{}";
  } else {
    s << "{ " << static_cast<double>(vector[0]);
    for (int i = 1; i < vector.size(); i++) {
      s << ", " << static_cast<double>(vector[i]);
    }
    s << "}";
  }
  return s.str();
}

template <typename Scalar>
void MakeDeterministicPseudoRandomVector(int size,
                                         std::vector<Scalar>* vector) {
  // Intentionally create a new local random_engine in each invocation,
  // so pseudorandom values don't depend on invocation order.
  // Otherwise, test results would be affected by e.g. test filtering.
  std::default_random_engine random_engine;
  (void)random_engine();
  // Do not use std::uniform*_distribution: the values that it
  // generates are implementation-defined.
  const double random_min = static_cast<double>(random_engine.min());
  const double random_max = static_cast<double>(random_engine.max());
  const double result_min =
      std::is_floating_point<Scalar>::value
          ? -1.0
          : std::max(-256., static_cast<double>(
                                std::numeric_limits<Scalar>::lowest()));
  const double result_max =
      std::is_floating_point<Scalar>::value
          ? 1.0
          : std::min(256.,
                     static_cast<double>(std::numeric_limits<Scalar>::max()));
  const double random_scale =
      (result_max - result_min) / (random_max - random_min);

  vector->resize(size);
  for (int i = 0; i < size; i++) {
    double val = random_scale * (random_engine() - random_min);
    val = std::max(val,
                   static_cast<double>(std::numeric_limits<Scalar>::lowest()));
    val =
        std::min(val, static_cast<double>(std::numeric_limits<Scalar>::max()));
    (*vector)[i] = static_cast<Scalar>(val);
  }
}

template <typename Scalar>
void MakeVectorFilledWithConsecutiveInts(int size,
                                         std::vector<Scalar>* vector) {
  vector->resize(size);
  EXPECT_LE(size, std::numeric_limits<Scalar>::max());
  for (int i = 0; i < size; i++) {
    (*vector)[i] = static_cast<Scalar>(i + 1);
  }
}

template <typename Scalar>
Scalar Median(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<Scalar> vector_copy = vector;
  std::sort(std::begin(vector_copy), std::end(vector_copy));
  return vector_copy[vector_copy.size() / 2];
}

template <typename Scalar>
double MedianAbs(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<double> vector_abs;
  vector_abs.resize(vector.size());
  for (int i = 0; i < vector.size(); i++) {
    vector_abs[i] = std::abs(static_cast<double>(vector[i]));
  }
  std::sort(std::begin(vector_abs), std::end(vector_abs));
  return vector_abs[vector_abs.size() / 2];
}

template <typename Scalar>
void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
           std::vector<Scalar>* dst) {
  dst->resize(src.size());
  for (int i = 0; i < src.size(); i++) {
    (*dst)[i] = std::max(std::min(src[i], clamp_max), clamp_min);
  }
}

template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
void Clamp(const GemmParams<AccumScalar, DstScalar, quantization_flavor>& src,
           DstScalar clamp_min, DstScalar clamp_max,
           GemmParams<AccumScalar, DstScalar, quantization_flavor>* dst) {
  *dst = src;
  dst->clamp_min = clamp_min;
  dst->clamp_max = clamp_max;
}

struct ErrorStats {
  int size;
  double scale_factor;
  double max_abs_diff;
  double mean_abs_diff;
  double abs_mean_diff;
};

template <typename Scalar>
void ComputeErrorStats(const std::vector<Scalar>& actual,
                       const std::vector<Scalar>& expected,
                       ErrorStats* error_stats) {
  double max_abs_diff = 0;
  double sum_abs_diff = 0;
  double sum_diff = 0;
  double max_abs_expected = 0;
  EXPECT_EQ(actual.size(), expected.size());
  for (int i = 0; i < actual.size(); i++) {
    double actual_val = static_cast<double>(actual[i]);
    double expected_val = static_cast<double>(expected[i]);
    double diff = actual_val - expected_val;
    max_abs_expected = std::max(max_abs_expected, std::abs(expected_val));
    sum_diff += diff;
    sum_abs_diff += std::abs(diff);
    max_abs_diff = std::max(max_abs_diff, std::abs(diff));
  }
  error_stats->scale_factor = max_abs_expected;
  error_stats->max_abs_diff = max_abs_diff;
  error_stats->mean_abs_diff = sum_abs_diff / actual.size();
  error_stats->abs_mean_diff = std::abs(sum_diff / actual.size());
  error_stats->size = actual.size();
}

template <typename AccumScalar, typename DstScalar>
bool CheckErrorStats(const ErrorStats& error_stats, int accumulation_depth) {
  double tolerated_relative_max_abs_diff = 0;
  double tolerated_relative_mean_abs_diff = 0;
  double tolerated_relative_abs_mean_diff = 0;

  double inverse_size = 1. / error_stats.size;

  if (std::is_floating_point<AccumScalar>::value) {
    // Somewhat naive requirement: the worst case should be epsilons
    // adding up in the same direction, on values of the same magnitude.
    tolerated_relative_max_abs_diff =
        accumulation_depth * std::numeric_limits<DstScalar>::epsilon();
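    // For instance, with float DstScalar (epsilon ~1.2e-7) and an
    // accumulation depth of 1000, this tolerates a relative max_abs_diff
    // of roughly 1.2e-4.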
    // A naive interpretation of the Central Limit Theorem is the rationale
    // for the sqrt here. We haven't even worked out the correct scale factor,
    // or how applicable that theorem is here (the random variables being added
    // might not be mutually independent).
    tolerated_relative_mean_abs_diff =
        std::sqrt(static_cast<double>(accumulation_depth)) *
        std::numeric_limits<DstScalar>::epsilon();
    // Unbiasing requirement: we require the bias, abs_mean_diff, to be much
    // smaller than the mean_abs_diff, except when there are very few values.
    tolerated_relative_abs_mean_diff =
        tolerated_relative_mean_abs_diff * std::sqrt(inverse_size);
  } else {
    // In quantized arithmetic, tolerate minor rounding differences, resulting
    // in off-by-one errors (tolerated_relative_max_abs_diff = 1), as long
    // as they are rare (tolerated_relative_mean_abs_diff) and unbiased
    // (tolerated_relative_abs_mean_diff).
    tolerated_relative_max_abs_diff = 1;
    // Naively require mean_abs_diff and abs_mean_diff to converge to zero
    // as size gets large. We don't know at all how quick that convergence
    // should be: this is just based on trial and error, striking a compromise
    // between something that works and code simple enough not to feel too
    // ad hoc. As in the float path above, abs_mean_diff is subject to a
    // stricter requirement as it is a bias.
    tolerated_relative_mean_abs_diff = std::sqrt(inverse_size) * 0.5;
    tolerated_relative_abs_mean_diff = inverse_size * 2.;
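    // For example, with 10000 result values this tolerates a mean_abs_diff of
    // up to 0.5% of the scale factor and an abs_mean_diff of up to 0.02%.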
  }

  double tolerated_max_abs_diff =
      tolerated_relative_max_abs_diff * error_stats.scale_factor;
  double tolerated_mean_abs_diff =
      tolerated_relative_mean_abs_diff * error_stats.scale_factor;
  double tolerated_abs_mean_diff =
      tolerated_relative_abs_mean_diff * error_stats.scale_factor;

  EXPECT_LE(error_stats.max_abs_diff, tolerated_max_abs_diff);
  EXPECT_LE(error_stats.mean_abs_diff, tolerated_mean_abs_diff);
  EXPECT_LE(error_stats.abs_mean_diff, tolerated_abs_mean_diff);

  return error_stats.max_abs_diff <= tolerated_max_abs_diff &&
         error_stats.mean_abs_diff <= tolerated_mean_abs_diff &&
         error_stats.abs_mean_diff <= tolerated_abs_mean_diff;
}

template <typename AccumScalar, typename DstScalar>
void CheckErrorForAccumulation(int accumulation_depth,
                               const std::vector<DstScalar>& actual,
                               const std::vector<DstScalar>& expected) {
  ErrorStats error_stats;
  ComputeErrorStats(actual, expected, &error_stats);
  bool success =
      CheckErrorStats<AccumScalar, DstScalar>(error_stats, accumulation_depth);
  EXPECT_TRUE(success) << "Actual vector\n"
                       << ToString(actual) << "\ndiffers from expected vector\n"
                       << ToString(expected) << "\n";
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void PerformGemmThenCompareResultsThenAgainWithClamping(
    const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    const std::vector<DstScalar>& expected,
    CpuBackendContext* cpu_backend_context) {
  const int accumulation_depth = lhs_params.cols;
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected);
  DstScalar expected_median = Median(expected);
  std::vector<DstScalar> expected_with_clamp;
  GemmParams<AccumScalar, DstScalar, quantization_flavor> params_with_clamp;
  DstScalar clamp_min, clamp_max;

  clamp_min = std::numeric_limits<DstScalar>::lowest();
  clamp_max = expected_median;
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);

  clamp_min = expected_median;
  clamp_max = std::numeric_limits<DstScalar>::max();
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);
}

// When generating testcases for a quantized GEMM, it's not trivial to
// pick multiplier exponents: too low a value results in too many zeros, and
// too high a value results in too many large clamped values; in both
// cases testing coverage is harmed. Therefore, to ensure good testing coverage,
// we must find a multiplier exponent that's just right. It would be possible
// to do so by analyzing the random distribution of values in the result
// matrix. That, however, would require some mathematical work that we haven't
// done so far. Until that is done, the best that we can do is to search for
// a good exponent value by trial and error. This is expensive, as each try
// requires computing a whole GEMM, and is thus probably a major contributor
// to the overall latency of this test. To partially mitigate that,
// we use a bisection to reduce the required number of tries.
//
// This function is recursive. The bisect_min and bisect_max arguments
// are the current bisection bounds. It performs a Gemm with the midpoint,
// named bisect_mid, as the multiplier exponent. Based on whether the values
// in the resulting matrix are rather too low or too large in absolute
// value, it then recurses into the corresponding half of the bisection range.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
int BisectReasonableMultiplierExponent(
    int bisect_min, int bisect_max, const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar>& params,
    CpuBackendContext* cpu_backend_context) {
  if (bisect_min == bisect_max) {
    return bisect_min;
  }
  // Compute the midpoint as the floor of the average of bisect_min and
  // bisect_max. As C++ integer division rounds towards zero and our values
  // may be of either sign, it is not trivial to implement this using only
  // integer arithmetic.
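  // For example, with bisect_min = -5 and bisect_max = -2 the average is -3.5
  // and the desired floor is -4, whereas (-5 + -2) / 2 evaluates to -3 in C++.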
  int bisect_mid =
      static_cast<int>(std::floor(0.5 * (bisect_min + bisect_max)));
  GemmParams<AccumScalar, DstScalar> params_copy(params);
  params_copy.multiplier_exponent = bisect_mid;
  double clamp_abs = std::max(std::abs(static_cast<double>(params.clamp_min)),
                              std::abs(static_cast<double>(params.clamp_max)));
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_copy, cpu_backend_context);
  double median_abs = MedianAbs(*dst_data);
  if (median_abs < 0.25 * clamp_abs) {
    return BisectReasonableMultiplierExponent(
        bisect_mid + 1, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  } else {
    return BisectReasonableMultiplierExponent(
        bisect_min, bisect_mid, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  }
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void ReferenceGemm(
    const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
    const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
    const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    CpuBackendContext* context) {
  ruy::Matrix<LhsScalar> ruy_lhs;
  ruy::Matrix<RhsScalar> ruy_rhs;
  ruy::Matrix<DstScalar> ruy_dst;
  cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &ruy_dst);

  ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;
  cpu_backend_gemm::detail::MakeRuyMulParams(params, &ruy_mul_params);

  ruy::ReferenceMul(ruy_lhs, ruy_rhs, ruy_mul_params, &ruy_dst);
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void TestSomeGemm(int rows, int depth, int cols,
                  const std::vector<DstScalar>& golden) {
  CpuBackendContext cpu_backend_context;
  std::default_random_engine random_engine;
  cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
  bool use_caching = static_cast<bool>(random_engine() % 2);
  cpu_backend_context.SetUseCaching(use_caching);
  const bool use_golden = !golden.empty();

  std::vector<LhsScalar> lhs_data;
  std::vector<RhsScalar> rhs_data;
  std::vector<AccumScalar> bias_data;
  std::vector<DstScalar> dst_data;
  if (use_golden) {
    MakeVectorFilledWithConsecutiveInts(rows * depth, &lhs_data);
    MakeVectorFilledWithConsecutiveInts(depth * cols, &rhs_data);
    MakeVectorFilledWithConsecutiveInts(rows, &bias_data);
  } else {
    MakeDeterministicPseudoRandomVector(rows * depth, &lhs_data);
    MakeDeterministicPseudoRandomVector(depth * cols, &rhs_data);
    MakeDeterministicPseudoRandomVector(rows, &bias_data);
  }
  MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);

  auto random_order = [&]() {
    return random_engine() % 2 ? cpu_backend_gemm::Order::kRowMajor
                               : cpu_backend_gemm::Order::kColMajor;
  };
  MatrixParams<LhsScalar> lhs_params;
  lhs_params.order =
      use_golden ? cpu_backend_gemm::Order::kRowMajor : random_order();
  lhs_params.rows = rows;
  lhs_params.cols = depth;
  // 16x8 quantization is only supported via the Ruy path. Ruy's 16x8 gemm
  // restricts zero_point to 0 because int16 might otherwise cause overflow
  // in the int32 accumulator.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<LhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    lhs_params.zero_point = 1;
    if (!use_golden) {
      lhs_params.zero_point += random_engine() % 8;
    }
  }

  MatrixParams<RhsScalar> rhs_params;
  rhs_params.order =
      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
  rhs_params.rows = depth;
  rhs_params.cols = cols;
  // 16x8 quantization is only supported via the Ruy path. Ruy's 16x8 gemm
  // restricts zero_point to 0 because int16 might otherwise cause overflow
  // in the int32 accumulator.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<RhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    rhs_params.zero_point = 1;
    if (!use_golden) {
      rhs_params.zero_point += random_engine() % 8;
    }
  }

  MatrixParams<DstScalar> dst_params;
  dst_params.order =
      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
  dst_params.rows = rows;
  dst_params.cols = cols;
  // 16x8 quantization is only supported via the Ruy path. Ruy's 16x8 gemm
  // restricts zero_point to 0 because int16 might otherwise cause overflow
  // in the int32 accumulator.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<DstScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    dst_params.zero_point = 1;
    if (!use_golden) {
      dst_params.zero_point += random_engine() % 8;
    }
  }

  GemmParams<AccumScalar, DstScalar> params;
  if (use_golden || (random_engine() % 2)) {
    // cpu_backend_gemm supports bias=null only in the float path. Test that
    // in 50% of float testcases.
    params.bias = bias_data.data();
  }
  static constexpr std::int32_t kMultiplierFixedpointMin = 1234567890;
  static constexpr std::int32_t kMultiplierFixedpointMax = 1987654321;
  if (!std::is_floating_point<AccumScalar>::value) {
    // Some large int32 value. Not being a multiple of a large
    // power of two helps test rounding behavior.
    params.multiplier_fixedpoint = kMultiplierFixedpointMin;
    // Now find a suitable value for multiplier_exponent.
    // It needs to be low enough that a substantial share of dst values
    // avoid getting clamped.
    int bisect_min = -8 * static_cast<int>(sizeof(AccumScalar));
    // We don't increase test coverage by using positive multipliers,
    // and using very large positive multipliers may at the moment
    // result in overflow in some paths.
    // TODO(benoitjacob): fix that.
    int bisect_max = 0;
    params.multiplier_exponent = BisectReasonableMultiplierExponent(
        bisect_min, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, &dst_data, params, &cpu_backend_context);
  }

  std::vector<DstScalar> expected;
  if (use_golden) {
    EXPECT_EQ(golden.size(), dst_data.size());
    expected = golden;
  } else {
    expected.resize(dst_data.size());
    ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
                  dst_params, expected.data(), params, &cpu_backend_context);
  }

  PerformGemmThenCompareResultsThenAgainWithClamping(
      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data, params,
      expected, &cpu_backend_context);

  if (!use_golden && !std::is_floating_point<AccumScalar>::value) {
    // Try with per-channel quantized multipliers.
    std::vector<AccumScalar> multiplier_fixedpoint_perchannel(rows);
    std::vector<int> multiplier_exponent_perchannel(rows);
    for (int i = 0; i < rows; i++) {
      multiplier_fixedpoint_perchannel[i] =
          kMultiplierFixedpointMin +
          (random_engine() %
           (kMultiplierFixedpointMax + 1 - kMultiplierFixedpointMin));
      const int exponent_min = params.multiplier_exponent - 2;
      const int exponent_max = params.multiplier_exponent + 2;
      multiplier_exponent_perchannel[i] =
          exponent_min + (random_engine() % (exponent_max + 1 - exponent_min));
    }
    static constexpr QuantizationFlavor perchannel_flavor =
        std::is_floating_point<AccumScalar>::value
            ? QuantizationFlavor::kFloatingPoint
            : QuantizationFlavor::kIntegerWithPerRowMultiplier;
    GemmParams<AccumScalar, DstScalar, perchannel_flavor> params_perchannel;
    params_perchannel.bias = params.bias;
    params_perchannel.clamp_min = params.clamp_min;
    params_perchannel.clamp_max = params.clamp_max;
    params_perchannel.multiplier_fixedpoint_perchannel =
        multiplier_fixedpoint_perchannel.data();
    params_perchannel.multiplier_exponent_perchannel =
        multiplier_exponent_perchannel.data();
    ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
                  dst_params, expected.data(), params_perchannel,
                  &cpu_backend_context);
    PerformGemmThenCompareResultsThenAgainWithClamping(
        lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data,
        params_perchannel, expected, &cpu_backend_context);
  }
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void TestMaybeValidGemm(int lhs_rows, int lhs_cols, int rhs_rows, int rhs_cols,
                        int dst_rows, int dst_cols) {
  CpuBackendContext cpu_backend_context;
  std::default_random_engine random_engine;
  cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
  bool use_caching = static_cast<bool>(random_engine() % 2);
  cpu_backend_context.SetUseCaching(use_caching);

  std::vector<LhsScalar> lhs_data;
  std::vector<RhsScalar> rhs_data;
  std::vector<AccumScalar> bias_data;
  std::vector<DstScalar> dst_data;
  MakeDeterministicPseudoRandomVector(lhs_rows * lhs_cols, &lhs_data);
  MakeDeterministicPseudoRandomVector(rhs_rows * rhs_cols, &rhs_data);
  MakeDeterministicPseudoRandomVector(dst_rows, &bias_data);
  MakeDeterministicPseudoRandomVector(dst_rows * dst_cols, &dst_data);

  MatrixParams<LhsScalar> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = lhs_cols;
  // 16x8 quantization is only supported via the Ruy path. Ruy's 16x8 gemm
  // restricts zero_point to 0 because int16 might otherwise cause overflow
  // in the int32 accumulator.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<LhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    lhs_params.zero_point = 1;
  }

  MatrixParams<RhsScalar> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = rhs_rows;
  rhs_params.cols = rhs_cols;
  // 16x8 quantization is only supported via the Ruy path. Ruy's 16x8 gemm
  // restricts zero_point to 0 because int16 might otherwise cause overflow
  // in the int32 accumulator.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<RhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    rhs_params.zero_point = 1;
  }

  MatrixParams<DstScalar> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = dst_rows;
  dst_params.cols = dst_cols;
  // 16x8 quantization is only supported via the Ruy path. Ruy's 16x8 gemm
  // restricts zero_point to 0 because int16 might otherwise cause overflow
  // in the int32 accumulator.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<DstScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    dst_params.zero_point = 1;
  }

  GemmParams<AccumScalar, DstScalar> params;
  params.bias = bias_data.data();
  static constexpr std::int32_t kMultiplierFixedpointMin = 1234567890;
  if (!std::is_floating_point<AccumScalar>::value) {
    // Some large int32 value. Not being a multiple of a large
    // power of two helps test rounding behavior.
    params.multiplier_fixedpoint = kMultiplierFixedpointMin;
    // Now find a suitable value for multiplier_exponent.
    // It needs to be low enough that a substantial share of dst values
    // avoid getting clamped.
    int bisect_min = -8 * static_cast<int>(sizeof(AccumScalar));
    // We don't increase test coverage by using positive multipliers,
    // and using very large positive multipliers may at the moment
    // result in overflow in some paths.
    int bisect_max = 0;
    params.multiplier_exponent = BisectReasonableMultiplierExponent(
        bisect_min, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, &dst_data, params, &cpu_backend_context);
  }
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data.data(), params, &cpu_backend_context);
}

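// In the golden-value tests below, TestSomeGemm fills the LHS and RHS with
// consecutive integers starting at 1 (row-major LHS, column-major RHS) and the
// bias with 1..rows. In the float case the column-major golden values are
// therefore just LHS * RHS + bias, e.g. for the 2x3x4 shape below,
// dst(0,0) = 1*1 + 2*2 + 3*3 + 1 = 15 and dst(1,0) = 4*1 + 5*2 + 6*3 + 2 = 34;
// the quantized cases additionally go through the fixed-point multiplier and
// clamping.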
TEST(CpuBackendGemmSimpleTestAgainstGolden, Float) {
  TestSomeGemm<float, float, float, float>(2, 3, 4,
                                           {15, 34, 33, 79, 51, 124, 69, 169});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Uint8) {
  TestSomeGemm<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>(
      5, 2, 3, {2, 4, 6, 7, 9, 3, 10, 16, 22, 29, 4, 15, 26, 37, 48});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int8_t>(
      2, 6, 3, {13, 32, 31, 81, 50, 127});
}

TEST(CpuBackendGemmInvalidGemmTest, Float) {
  // A standard Gemm operation.
  TestMaybeValidGemm<float, float, float, float>(2, 3, 3, 4, 2, 4);
  // An invalid Gemm that will abort in debug mode.
#if !defined(TARGET_IPHONE_SIMULATOR) && !defined(TARGET_OS_IPHONE)
  ASSERT_DEBUG_DEATH(
      (TestMaybeValidGemm<float, float, float, float>(2, 3, 3, 0, 2, 4)), "");
  ASSERT_DEBUG_DEATH(
      (TestMaybeValidGemm<float, float, float, float>(2, 3, 9, 4, 2, 4)), "");
#endif
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8Int16) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int16_t>(
      3, 5, 4, {19, 48, 77, 48, 149, 250, 76, 249, 422, 105, 350, 595});
}

template <typename tLhsScalar, typename tRhsScalar, typename tAccumScalar,
          typename tDstScalar>
struct TypesTuple {
  using LhsScalar = tLhsScalar;
  using RhsScalar = tRhsScalar;
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;
};

template <typename TypesTupleType>
void TestRandomGemms(const std::vector<std::tuple<int, int, int>>& shapes) {
  using LhsScalar = typename TypesTupleType::LhsScalar;
  using RhsScalar = typename TypesTupleType::RhsScalar;
  using AccumScalar = typename TypesTupleType::AccumScalar;
  using DstScalar = typename TypesTupleType::DstScalar;
  for (const auto& shape : shapes) {
    int rows = std::get<0>(shape);
    int depth = std::get<1>(shape);
    int cols = std::get<2>(shape);
    TestSomeGemm<LhsScalar, RhsScalar, AccumScalar, DstScalar>(rows, depth,
                                                               cols, {});
  }
}

template <typename TypesTupleType>
class CpuBackendGemmTest : public testing::Test {};

TYPED_TEST_SUITE_P(CpuBackendGemmTest);

typedef ::testing::Types<
    TypesTuple<float, float, float, float>,
    TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int16_t>,
    TypesTuple<std::int8_t, std::int16_t, std::int32_t, std::int16_t>,
    TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::int8_t>>
    CpuBackendGemmTestInstantiations;

TYPED_TEST_SUITE(CpuBackendGemmTest, CpuBackendGemmTestInstantiations);

TYPED_TEST(CpuBackendGemmTest, Square) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, SquarePowerOfTwo) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 64; size <= 128; size *= 2) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesVector) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(size, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, VectorTimesMatrix) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesNarrow) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size, 2));
    shapes.push_back(std::make_tuple(size, size, 3));
    shapes.push_back(std::make_tuple(size, size, 4));
    shapes.push_back(std::make_tuple(size, size, 8));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, Rectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size + 5, size + 1));
    shapes.push_back(std::make_tuple(size + 10, size + 2, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, HighlyRectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size <= 1000; size *= 10) {
    shapes.push_back(std::make_tuple(size, 10, 10));
    shapes.push_back(std::make_tuple(10, size, 10));
    shapes.push_back(std::make_tuple(10, 10, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, InnerProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, OuterProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 100; size++) {
    shapes.push_back(std::make_tuple(size, 1, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

}  // namespace

}  // namespace tflite