/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include <gtest/gtest.h>
#include "ruy/matrix.h"  // from @ruy
#include "ruy/reference_mul.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

namespace tflite {

namespace {

using cpu_backend_gemm::Gemm;
using cpu_backend_gemm::GemmParams;
using cpu_backend_gemm::MatrixParams;
using cpu_backend_gemm::QuantizationFlavor;

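// Formats a vector of scalars (cast to double) as a string, for use in test
// failure messages.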
template <typename Scalar>
std::string ToString(const std::vector<Scalar>& vector) {
  std::stringstream s;
  if (vector.empty()) {
    s << "{}";
  } else {
    s << "{ " << static_cast<double>(vector[0]);
    for (int i = 1; i < vector.size(); i++) {
      s << ", " << static_cast<double>(vector[i]);
    }
    s << "}";
  }
  return s.str();
}

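// Fills `vector` with `size` deterministic pseudorandom values: roughly
// [-1, 1] for floating-point Scalar types, and [-256, 256] clipped to the
// representable range for integer Scalar types.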
template <typename Scalar>
void MakeDeterministicPseudoRandomVector(int size,
                                         std::vector<Scalar>* vector) {
  // Intentionally create a new local random_engine in each invocation,
  // so pseudorandom values don't depend on invocation order.
  // Otherwise, test results would be affected by e.g. test filtering.
  std::default_random_engine random_engine;
  (void)random_engine();
  // Do not use std::uniform*_distribution: the values that it
  // generates are implementation-defined.
  const double random_min = static_cast<double>(random_engine.min());
  const double random_max = static_cast<double>(random_engine.max());
  const double result_min =
      std::is_floating_point<Scalar>::value
          ? -1.0
          : std::max(-256., static_cast<double>(
                                std::numeric_limits<Scalar>::lowest()));
  const double result_max =
      std::is_floating_point<Scalar>::value
          ? 1.0
          : std::min(256.,
                     static_cast<double>(std::numeric_limits<Scalar>::max()));
  const double random_scale =
      (result_max - result_min) / (random_max - random_min);

  vector->resize(size);
  for (int i = 0; i < size; i++) {
    double val = random_scale * (random_engine() - random_min);
    val = std::max(val,
                   static_cast<double>(std::numeric_limits<Scalar>::lowest()));
    val =
        std::min(val, static_cast<double>(std::numeric_limits<Scalar>::max()));
    (*vector)[i] = static_cast<Scalar>(val);
  }
}

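// Fills `vector` with the consecutive integers 1, 2, ..., size. Used to build
// the inputs of the golden-value tests.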
template <typename Scalar>
void MakeVectorFilledWithConsecutiveInts(int size,
                                         std::vector<Scalar>* vector) {
  vector->resize(size);
  EXPECT_LE(size, std::numeric_limits<Scalar>::max());
  for (int i = 0; i < size; i++) {
    (*vector)[i] = static_cast<Scalar>(i + 1);
  }
}

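// Returns the median of `vector` (the upper median when the size is even).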
template <typename Scalar>
Scalar Median(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<Scalar> vector_copy = vector;
  std::sort(std::begin(vector_copy), std::end(vector_copy));
  return vector_copy[vector_copy.size() / 2];
}

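// Returns the median of the absolute values of the elements of `vector`.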
template <typename Scalar>
double MedianAbs(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<double> vector_abs;
  vector_abs.resize(vector.size());
  for (int i = 0; i < vector.size(); i++) {
    vector_abs[i] = std::abs(static_cast<double>(vector[i]));
  }
  std::sort(std::begin(vector_abs), std::end(vector_abs));
  return vector_abs[vector_abs.size() / 2];
}

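// Element-wise clamps `src` into [clamp_min, clamp_max], writing to `dst`.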
template <typename Scalar>
void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
           std::vector<Scalar>* dst) {
  dst->resize(src.size());
  for (int i = 0; i < src.size(); i++) {
    (*dst)[i] = std::max(std::min(src[i], clamp_max), clamp_min);
  }
}

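// Copies `src` GemmParams into `dst`, overriding only the clamp bounds.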
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
void Clamp(const GemmParams<AccumScalar, DstScalar, quantization_flavor>& src,
           DstScalar clamp_min, DstScalar clamp_max,
           GemmParams<AccumScalar, DstScalar, quantization_flavor>* dst) {
  *dst = src;
  dst->clamp_min = clamp_min;
  dst->clamp_max = clamp_max;
}

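// Summary statistics of the element-wise differences between an actual and an
// expected result vector. scale_factor is the maximum absolute expected
// value; it is used to turn relative tolerances into absolute ones.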
struct ErrorStats {
  int size;
  double scale_factor;
  double max_abs_diff;
  double mean_abs_diff;
  double abs_mean_diff;
};

template <typename Scalar>
void ComputeErrorStats(const std::vector<Scalar>& actual,
                       const std::vector<Scalar>& expected,
                       ErrorStats* error_stats) {
  double max_abs_diff = 0;
  double sum_abs_diff = 0;
  double sum_diff = 0;
  double max_abs_expected = 0;
  EXPECT_EQ(actual.size(), expected.size());
  for (int i = 0; i < actual.size(); i++) {
    double actual_val = static_cast<double>(actual[i]);
    double expected_val = static_cast<double>(expected[i]);
    double diff = actual_val - expected_val;
    max_abs_expected = std::max(max_abs_expected, std::abs(expected_val));
    sum_diff += diff;
    sum_abs_diff += std::abs(diff);
    max_abs_diff = std::max(max_abs_diff, std::abs(diff));
  }
  error_stats->scale_factor = max_abs_expected;
  error_stats->max_abs_diff = max_abs_diff;
  error_stats->mean_abs_diff = sum_abs_diff / actual.size();
  error_stats->abs_mean_diff = std::abs(sum_diff / actual.size());
  error_stats->size = actual.size();
}

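// Checks `error_stats` against tolerances derived from the accumulation depth
// and the scalar types. Returns true if all checks passed, in addition to
// reporting any failure via EXPECT_LE.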
template <typename AccumScalar, typename DstScalar>
bool CheckErrorStats(const ErrorStats& error_stats, int accumulation_depth) {
  double tolerated_relative_max_abs_diff = 0;
  double tolerated_relative_mean_abs_diff = 0;
  double tolerated_relative_abs_mean_diff = 0;

  double inverse_size = 1. / error_stats.size;

  if (std::is_floating_point<AccumScalar>::value) {
    // Somewhat naive requirement: the worst case should be epsilons
    // adding up towards the same direction, on values of the same magnitude.
    tolerated_relative_max_abs_diff =
        accumulation_depth * std::numeric_limits<DstScalar>::epsilon();
    // A naive interpretation of the Central Limit Theorem is the rationale
    // for the sqrt here. We haven't even worked out the correct scale factor,
    // or how applicable that theorem is here (the random variables being added
    // might not be mutually independent).
    tolerated_relative_mean_abs_diff =
        std::sqrt(static_cast<double>(accumulation_depth)) *
        std::numeric_limits<DstScalar>::epsilon();
    // Unbiasedness requirement: we require the bias, abs_mean_diff, to be much
    // smaller than the mean_abs_diff, except when there are very few values.
    tolerated_relative_abs_mean_diff =
        tolerated_relative_mean_abs_diff * std::sqrt(inverse_size);
  } else {
    // In quantized arithmetic, tolerate minor rounding differences, resulting
    // in off-by-one errors (tolerated_relative_max_abs_diff = 1), as long
    // as they are rare (tolerated_relative_mean_abs_diff) and unbiased
    // (tolerated_relative_abs_mean_diff).
    tolerated_relative_max_abs_diff = 1;
    // Naively require mean_abs_diff and abs_mean_diff to converge to zero
    // as size gets large. We don't know how quickly that convergence should
    // be: these values are based on trial and error, striking a compromise
    // between something that works and something simple enough not to feel
    // too ad hoc. As in the float path above, abs_mean_diff is subject to a
    // stricter requirement as it is a bias.
    tolerated_relative_mean_abs_diff = std::sqrt(inverse_size) * 0.5;
    tolerated_relative_abs_mean_diff = inverse_size * 2.;
  }

  double tolerated_max_abs_diff =
      tolerated_relative_max_abs_diff * error_stats.scale_factor;
  double tolerated_mean_abs_diff =
      tolerated_relative_mean_abs_diff * error_stats.scale_factor;
  double tolerated_abs_mean_diff =
      tolerated_relative_abs_mean_diff * error_stats.scale_factor;

  EXPECT_LE(error_stats.max_abs_diff, tolerated_max_abs_diff);
  EXPECT_LE(error_stats.mean_abs_diff, tolerated_mean_abs_diff);
  EXPECT_LE(error_stats.abs_mean_diff, tolerated_abs_mean_diff);

  return error_stats.max_abs_diff <= tolerated_max_abs_diff &&
         error_stats.mean_abs_diff <= tolerated_mean_abs_diff &&
         error_stats.abs_mean_diff <= tolerated_abs_mean_diff;
}

template <typename AccumScalar, typename DstScalar>
void CheckErrorForAccumulation(int accumulation_depth,
                               const std::vector<DstScalar>& actual,
                               const std::vector<DstScalar>& expected) {
  ErrorStats error_stats;
  ComputeErrorStats(actual, expected, &error_stats);
  bool success =
      CheckErrorStats<AccumScalar, DstScalar>(error_stats, accumulation_depth);
  EXPECT_TRUE(success) << "Actual vector\n"
                       << ToString(actual) << "\ndiffers from expected vector\n"
                       << ToString(expected) << "\n";
}

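// Runs Gemm with the given params and checks the result against `expected`.
// Then runs it twice more with the output clamped from above, respectively
// from below, at the median of `expected`, checking the result each time.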
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void PerformGemmThenCompareResultsThenAgainWithClamping(
    const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    const std::vector<DstScalar>& expected,
    CpuBackendContext* cpu_backend_context) {
  const int accumulation_depth = lhs_params.cols;
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected);
  DstScalar expected_median = Median(expected);
  std::vector<DstScalar> expected_with_clamp;
  GemmParams<AccumScalar, DstScalar, quantization_flavor> params_with_clamp;
  DstScalar clamp_min, clamp_max;

  clamp_min = std::numeric_limits<DstScalar>::lowest();
  clamp_max = expected_median;
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);

  clamp_min = expected_median;
  clamp_max = std::numeric_limits<DstScalar>::max();
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);
}

// When generating testcases for a quantized GEMM, it is not trivial to
// pick multiplier exponents: too low a value results in too many zeros, and
// too high a value results in too many large clamped values; either way,
// test coverage is harmed. Therefore, to ensure good test coverage, we must
// find a multiplier exponent that is just right. It would be possible to do
// so by analyzing the random distribution of values in the result matrix.
// That, however, would require some mathematical work that we haven't done
// so far. Until that is done, the best that we can do is to search for a
// good exponent value by trial and error. This is expensive, as each try
// requires computing a whole GEMM, and is thus probably a major contributor
// to the overall latency of this test. To partially mitigate that,
// we use a bisection to reduce the required number of tries.
//
// This function is recursive. The bisect_min and bisect_max arguments
// are the current bisection bounds. It performs a Gemm with the mid-point,
// named bisect_mid, as the multiplier exponent. Based on whether the values
// in the resulting matrix are rather too low or too large in absolute
// value, it then recurses into the corresponding half of the bisection range.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
int BisectReasonableMultiplierExponent(
    int bisect_min, int bisect_max, const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar>& params,
    CpuBackendContext* cpu_backend_context) {
  if (bisect_min == bisect_max) {
    return bisect_min;
  }
  // Compute the midpoint as the floor of the average of bisect_min and
  // bisect_max. As C++ integer division rounds towards zero and our values
  // may be of either sign, it is not trivial to implement this using only
  // integer arithmetic.
  int bisect_mid =
      static_cast<int>(std::floor(0.5 * (bisect_min + bisect_max)));
  GemmParams<AccumScalar, DstScalar> params_copy(params);
  params_copy.multiplier_exponent = bisect_mid;
  double clamp_abs = std::max(std::abs(static_cast<double>(params.clamp_min)),
                              std::abs(static_cast<double>(params.clamp_max)));
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_copy, cpu_backend_context);
  double median_abs = MedianAbs(*dst_data);
  if (median_abs < 0.25 * clamp_abs) {
    return BisectReasonableMultiplierExponent(
        bisect_mid + 1, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  } else {
    return BisectReasonableMultiplierExponent(
        bisect_min, bisect_mid, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  }
}

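// Computes the reference result with ruy::ReferenceMul, using the
// cpu_backend_gemm::detail helpers to convert the matrix and GEMM params to
// ruy types.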
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void ReferenceGemm(
    const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
    const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
    const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    CpuBackendContext* context) {
  ruy::Matrix<LhsScalar> ruy_lhs;
  ruy::Matrix<RhsScalar> ruy_rhs;
  ruy::Matrix<DstScalar> ruy_dst;
  cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &ruy_dst);

  ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;
  cpu_backend_gemm::detail::MakeRuyMulParams(params, &ruy_mul_params);

  ruy::ReferenceMul(ruy_lhs, ruy_rhs, ruy_mul_params, &ruy_dst);
}

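// The core test routine. Builds LHS/RHS/bias data (consecutive integers when
// a nonempty `golden` result is provided, deterministic pseudorandom values
// otherwise), picks quantization parameters, then runs Gemm and compares the
// result against `golden` or against ReferenceGemm, with and without
// clamping, and additionally with per-channel multipliers in the quantized
// random case.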
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void TestSomeGemm(int rows, int depth, int cols,
                  const std::vector<DstScalar>& golden) {
  CpuBackendContext cpu_backend_context;
  std::default_random_engine random_engine;
  cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
  bool use_caching = static_cast<bool>(random_engine() % 2);
  cpu_backend_context.SetUseCaching(use_caching);
  const bool use_golden = !golden.empty();

  std::vector<LhsScalar> lhs_data;
  std::vector<RhsScalar> rhs_data;
  std::vector<AccumScalar> bias_data;
  std::vector<DstScalar> dst_data;
  if (use_golden) {
    MakeVectorFilledWithConsecutiveInts(rows * depth, &lhs_data);
    MakeVectorFilledWithConsecutiveInts(depth * cols, &rhs_data);
    MakeVectorFilledWithConsecutiveInts(rows, &bias_data);
  } else {
    MakeDeterministicPseudoRandomVector(rows * depth, &lhs_data);
    MakeDeterministicPseudoRandomVector(depth * cols, &rhs_data);
    MakeDeterministicPseudoRandomVector(rows, &bias_data);
  }
  MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);

  auto random_order = [&]() {
    return random_engine() % 2 ? cpu_backend_gemm::Order::kRowMajor
                               : cpu_backend_gemm::Order::kColMajor;
  };
  MatrixParams<LhsScalar> lhs_params;
  lhs_params.order =
      use_golden ? cpu_backend_gemm::Order::kRowMajor : random_order();
  lhs_params.rows = rows;
  lhs_params.cols = depth;
  // 16x8 quantization is only supported by the Ruy path, and Ruy's 16x8 GEMM
  // requires zero_point to be 0, because nonzero zero points with int16 data
  // could overflow the int32 accumulators.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<LhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    lhs_params.zero_point = 1;
    if (!use_golden) {
      lhs_params.zero_point += random_engine() % 8;
    }
  }

  MatrixParams<RhsScalar> rhs_params;
  rhs_params.order =
      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
  rhs_params.rows = depth;
  rhs_params.cols = cols;
  // 16x8 quantization is only supported by the Ruy path, and Ruy's 16x8 GEMM
  // requires zero_point to be 0, because nonzero zero points with int16 data
  // could overflow the int32 accumulators.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<RhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    rhs_params.zero_point = 1;
    if (!use_golden) {
      rhs_params.zero_point += random_engine() % 8;
    }
  }

  MatrixParams<DstScalar> dst_params;
  dst_params.order =
      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
  dst_params.rows = rows;
  dst_params.cols = cols;
  // 16x8 quantization is only supported by the Ruy path, and Ruy's 16x8 GEMM
  // requires zero_point to be 0, because nonzero zero points with int16 data
  // could overflow the int32 accumulators.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<DstScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    dst_params.zero_point = 1;
    if (!use_golden) {
      dst_params.zero_point += random_engine() % 8;
    }
  }

  GemmParams<AccumScalar, DstScalar> params;
  if (use_golden || (random_engine() % 2)) {
    // cpu_backend_gemm supports a null bias only in the float path. Exercise
    // that in 50% of the float testcases.
    params.bias = bias_data.data();
  }
  static constexpr std::int32_t kMultiplierFixedpointMin = 1234567890;
  static constexpr std::int32_t kMultiplierFixedpointMax = 1987654321;
  if (!std::is_floating_point<AccumScalar>::value) {
    // Some large int32 value. Not being a multiple of a large
    // power of two helps test rounding behavior.
    params.multiplier_fixedpoint = kMultiplierFixedpointMin;
    // Now find a suitable value for multiplier_exponent.
    // It needs to be low enough that a substantial share of dst values
    // avoid getting clamped.
    int bisect_min = -8 * static_cast<int>(sizeof(AccumScalar));
    // We don't increase test coverage by using positive multipliers,
    // and using very large positive multipliers may at the moment
    // result in overflow in some paths.
    // TODO(benoitjacob): fix that.
    int bisect_max = 0;
    params.multiplier_exponent = BisectReasonableMultiplierExponent(
        bisect_min, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, &dst_data, params, &cpu_backend_context);
  }

  std::vector<DstScalar> expected;
  if (use_golden) {
    EXPECT_EQ(golden.size(), dst_data.size());
    expected = golden;
  } else {
    expected.resize(dst_data.size());
    ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
                  dst_params, expected.data(), params, &cpu_backend_context);
  }

  PerformGemmThenCompareResultsThenAgainWithClamping(
      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data, params,
      expected, &cpu_backend_context);

  if (!use_golden && !std::is_floating_point<AccumScalar>::value) {
    // Try with per-channel quantized multipliers.
    std::vector<AccumScalar> multiplier_fixedpoint_perchannel(rows);
    std::vector<int> multiplier_exponent_perchannel(rows);
    for (int i = 0; i < rows; i++) {
      multiplier_fixedpoint_perchannel[i] =
          kMultiplierFixedpointMin +
          (random_engine() %
           (kMultiplierFixedpointMax + 1 - kMultiplierFixedpointMin));
      const int exponent_min = params.multiplier_exponent - 2;
      const int exponent_max = params.multiplier_exponent + 2;
      multiplier_exponent_perchannel[i] =
          exponent_min + (random_engine() % (exponent_max + 1 - exponent_min));
    }
    static constexpr QuantizationFlavor perchannel_flavor =
        std::is_floating_point<AccumScalar>::value
            ? QuantizationFlavor::kFloatingPoint
            : QuantizationFlavor::kIntegerWithPerRowMultiplier;
    GemmParams<AccumScalar, DstScalar, perchannel_flavor> params_perchannel;
    params_perchannel.bias = params.bias;
    params_perchannel.clamp_min = params.clamp_min;
    params_perchannel.clamp_max = params.clamp_max;
    params_perchannel.multiplier_fixedpoint_perchannel =
        multiplier_fixedpoint_perchannel.data();
    params_perchannel.multiplier_exponent_perchannel =
        multiplier_exponent_perchannel.data();
    ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
                  dst_params, expected.data(), params_perchannel,
                  &cpu_backend_context);
    PerformGemmThenCompareResultsThenAgainWithClamping(
        lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data,
        params_perchannel, expected, &cpu_backend_context);
  }
}

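// Like TestSomeGemm, but takes independent (possibly inconsistent) LHS, RHS
// and destination dimensions, so that callers can exercise shapes that are
// expected to fail validation in debug builds.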
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void TestMaybeValidGemm(int lhs_rows, int lhs_cols, int rhs_rows, int rhs_cols,
                        int dst_rows, int dst_cols) {
  CpuBackendContext cpu_backend_context;
  std::default_random_engine random_engine;
  cpu_backend_context.SetMaxNumThreads(1 + (random_engine() % 8));
  bool use_caching = static_cast<bool>(random_engine() % 2);
  cpu_backend_context.SetUseCaching(use_caching);

  std::vector<LhsScalar> lhs_data;
  std::vector<RhsScalar> rhs_data;
  std::vector<AccumScalar> bias_data;
  std::vector<DstScalar> dst_data;
  MakeDeterministicPseudoRandomVector(lhs_rows * lhs_cols, &lhs_data);
  MakeDeterministicPseudoRandomVector(rhs_rows * rhs_cols, &rhs_data);
  MakeDeterministicPseudoRandomVector(dst_rows, &bias_data);
  MakeDeterministicPseudoRandomVector(dst_rows * dst_cols, &dst_data);

  MatrixParams<LhsScalar> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = lhs_cols;
  // 16x8 quantization is only supported by the Ruy path, and Ruy's 16x8 GEMM
  // requires zero_point to be 0, because nonzero zero points with int16 data
  // could overflow the int32 accumulators.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<LhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    lhs_params.zero_point = 1;
  }

  MatrixParams<RhsScalar> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = rhs_rows;
  rhs_params.cols = rhs_cols;
  // 16x8 quantization is only supported by the Ruy path, and Ruy's 16x8 GEMM
  // requires zero_point to be 0, because nonzero zero points with int16 data
  // could overflow the int32 accumulators.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<RhsScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    rhs_params.zero_point = 1;
  }

  MatrixParams<DstScalar> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = dst_rows;
  dst_params.cols = dst_cols;
  // 16x8 quantization is only supported by the Ruy path, and Ruy's 16x8 GEMM
  // requires zero_point to be 0, because nonzero zero points with int16 data
  // could overflow the int32 accumulators.
  // https://github.com/google/ruy/blob/master/ruy/validate.h#L53-L57
  if (!std::is_floating_point<DstScalar>::value &&
      (!std::is_same<LhsScalar, int8_t>::value &&
       !std::is_same<RhsScalar, int16_t>::value)) {
    dst_params.zero_point = 1;
  }

  GemmParams<AccumScalar, DstScalar> params;
  params.bias = bias_data.data();
  static constexpr std::int32_t kMultiplierFixedpointMin = 1234567890;
  if (!std::is_floating_point<AccumScalar>::value) {
    // Some large int32 value. Not being a multiple of a large
    // power of two helps test rounding behavior.
    params.multiplier_fixedpoint = kMultiplierFixedpointMin;
    // Now find a suitable value for multiplier_exponent.
    // It needs to be low enough that a substantial share of dst values
    // avoid getting clamped.
    int bisect_min = -8 * static_cast<int>(sizeof(AccumScalar));
    // We don't increase test coverage by using positive multipliers,
    // and using very large positive multipliers may at the moment
    // result in overflow in some paths.
    int bisect_max = 0;
    params.multiplier_exponent = BisectReasonableMultiplierExponent(
        bisect_min, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, &dst_data, params, &cpu_backend_context);
  }
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data.data(), params, &cpu_backend_context);
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Float) {
  TestSomeGemm<float, float, float, float>(2, 3, 4,
                                           {15, 34, 33, 79, 51, 124, 69, 169});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Uint8) {
  TestSomeGemm<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>(
      5, 2, 3, {2, 4, 6, 7, 9, 3, 10, 16, 22, 29, 4, 15, 26, 37, 48});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int8_t>(
      2, 6, 3, {13, 32, 31, 81, 50, 127});
}

TEST(CpuBackendGemmInvalidGemmTest, Float) {
  // A standard Gemm operation.
  TestMaybeValidGemm<float, float, float, float>(2, 3, 3, 4, 2, 4);
  // Invalid Gemms that will abort in debug mode.
#if !defined(TARGET_IPHONE_SIMULATOR) && !defined(TARGET_OS_IPHONE)
  ASSERT_DEBUG_DEATH(
      (TestMaybeValidGemm<float, float, float, float>(2, 3, 3, 0, 2, 4)), "");
  ASSERT_DEBUG_DEATH(
      (TestMaybeValidGemm<float, float, float, float>(2, 3, 9, 4, 2, 4)), "");
#endif
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8Int16) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int16_t>(
      3, 5, 4, {19, 48, 77, 48, 149, 250, 76, 249, 422, 105, 350, 595});
}

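// Groups the four scalar types of a GEMM (LHS, RHS, accumulator, destination)
// into a single type, so that the combination can be used as a gtest type
// parameter.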
template <typename tLhsScalar, typename tRhsScalar, typename tAccumScalar,
          typename tDstScalar>
struct TypesTuple {
  using LhsScalar = tLhsScalar;
  using RhsScalar = tRhsScalar;
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;
};

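// Runs TestSomeGemm, with an empty golden vector (i.e. checking against the
// reference GEMM), for each (rows, depth, cols) shape in `shapes`.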
template <typename TypesTupleType>
void TestRandomGemms(const std::vector<std::tuple<int, int, int>>& shapes) {
  using LhsScalar = typename TypesTupleType::LhsScalar;
  using RhsScalar = typename TypesTupleType::RhsScalar;
  using AccumScalar = typename TypesTupleType::AccumScalar;
  using DstScalar = typename TypesTupleType::DstScalar;
  for (const auto& shape : shapes) {
    int rows = std::get<0>(shape);
    int depth = std::get<1>(shape);
    int cols = std::get<2>(shape);
    TestSomeGemm<LhsScalar, RhsScalar, AccumScalar, DstScalar>(rows, depth,
                                                               cols, {});
  }
}

template <typename TypesTupleType>
class CpuBackendGemmTest : public testing::Test {};

TYPED_TEST_SUITE_P(CpuBackendGemmTest);

typedef ::testing::Types<
    TypesTuple<float, float, float, float>,
    TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int16_t>,
    TypesTuple<std::int8_t, std::int16_t, std::int32_t, std::int16_t>,
    TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::int8_t>>
    CpuBackendGemmTestInstantiations;

TYPED_TEST_SUITE(CpuBackendGemmTest, CpuBackendGemmTestInstantiations);

TYPED_TEST(CpuBackendGemmTest, Square) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, SquarePowerOfTwo) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 64; size <= 128; size *= 2) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesVector) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(size, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, VectorTimesMatrix) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesNarrow) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size, 2));
    shapes.push_back(std::make_tuple(size, size, 3));
    shapes.push_back(std::make_tuple(size, size, 4));
    shapes.push_back(std::make_tuple(size, size, 8));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, Rectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size + 5, size + 1));
    shapes.push_back(std::make_tuple(size + 10, size + 2, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, HighlyRectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size <= 1000; size *= 10) {
    shapes.push_back(std::make_tuple(size, 10, 10));
    shapes.push_back(std::make_tuple(10, size, 10));
    shapes.push_back(std::make_tuple(10, 10, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, InnerProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, OuterProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 100; size++) {
    shapes.push_back(std::make_tuple(size, 1, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

}  // namespace

}  // namespace tflite