xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/fully_connected_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite FULLY_CONNECTED op.
16 
#include "tensorflow/lite/kernels/fully_connected.h"

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <initializer_list>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_type.h"
41 
42 namespace tflite {
43 namespace {
44 
45 using ::testing::ElementsAre;
46 using ::testing::ElementsAreArray;
47 
// Shared fixture data for the black-box tests below: 8 batches x 16 inputs.
static float fully_connected_input[] = {
    0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653,
    0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390,
    0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314,
    0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550,
    0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112,
    0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999,
    0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142,
    0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494,
    0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081,
    0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552,
    0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320,
    0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458,
    0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115,
    0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771,
    0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582,
    0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962,
    0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202,
    0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691,
    0.921330, 0.139902};

// Expected post-RELU results, one group of 16 output units per batch
// (16 groups of 16 = 256 values).
static float fully_connected_golden_output[] = {
    0,        0.0732134,   0,        0,          0,         0.280859,
    0,        0.128927,    0,        0.0777251,  0,         0.270268,
    0.271435, 0.0173503,   0.335465, 0.235562,

    0,        0.0745866,   0,        0.051611,   0,         0.253876,
    0,        0.0814873,   0,        0.104104,   0,         0.248529,
    0.264194, 0,           0.302973, 0.166252,

    0,        0.0170409,   0,        0.0509851,  0,         0.212834,
    0,        0.0208326,   0,        0.129932,   0.203978,  0.103428,
    0.298051, 0,           0.332233, 0.00445903,

    0,        0.125246,    0,        0.0735336,  0,         0.0910256,
    0,        0,           0,        0.18933,    0.378111,  0.0712443,
    0.277298, 0.0123414,   0.267454, 0,

    0,        0.14687,     0,        0.155495,   0.0300215, 0.147256,
    0,        0,           0,        0.156412,   0.434914,  0.0461529,
    0.246508, 0,           0.363138, 0,

    0,        0,           0,        0.0212949,  0,         0.301708,
    0,        0.35497,     0,        0.406223,   0.0260211, 0.049195,
    0.197161, 0,           0.37316,  0,

    0,        0.221783,    0,        0,          0.0116515, 0.281945,
    0,        0,           0,        0,          0.285626,  0.181773,
    0.296401, 0.170452,    0.367135, 0.142597,

    0,        0,           0,        0,          0,         0.418886,
    0,        0.291063,    0,        0.227541,   0.0424759, 0.27589,
    0.398286, 0.177146,    0.40359,  0.121452,

    0,        0.0834884,   0,        0,          0,         0.287441,
    0,        0.0046838,   0,        0.0122087,  0,         0.217376,
    0.140183, 0.0948412,   0.436677, 0.0589876,

    0,        0.0289969,   0,        0.0921397,  0,         0.396802,
    0,        0.0126157,   0,        0.0968433,  0,         0.172271,
    0.173295, 0.0664741,   0.53645,  0.00915603,

    0,        0,           0,        0,          0,         0.147942,
    0,        0.263795,    0,        0.39782,    0,         0.382435,
    0.561072, 0.0579847,   0.145712, 0.13508,

    0,        0,           0,        0.16382,    0,         0.322294,
    0,        0.163798,    0,        0.405211,   0.367953,  0.076852,
    0.342473, 0.0834118,   0.377537, 0,

    0,        0.206,       0,        0,          0,         0.375769,
    0,        0,           0,        0,          0,         0.125165,
    0,        0.105591,    0.52055,  0.0536445,

    0,        0.259261,    0,        0,          0,         0.247707,
    0,        0,           0,        0,          0,         0.215862,
    0.149153, 0.224678,    0.359519, 0.129419,

    0,        0.17611,     0,        0.280895,   0,         0.576484,
    0,        0.000418848, 0,        0,          0,         0.151112,
    0.211902, 0,           0.566341, 0.106305,

    0,        0.0246284,   0,        0,          0,         0.196267,
    0,        0.0248624,   0,        0.265635,   0,         0.436199,
    0.408079, 0.134514,    0.328489, 0.411368};
133 
134 class BaseFullyConnectedOpModel : public SingleOpModel {
135  public:
136   // TODO(ahentz): test different activation types too.
BaseFullyConnectedOpModel(TfLiteRegistration * registration,int units,int batches,const TensorData & input,const TensorData & output={TensorType_FLOAT32},const TensorType & bias_type=TensorType_FLOAT32,bool keep_num_dims=false,bool bias_tensor_optional=false,ActivationFunctionType activation_func=ActivationFunctionType_RELU,FullyConnectedOptionsWeightsFormat weights_format=FullyConnectedOptionsWeightsFormat_DEFAULT,int input_size=-1,bool weights_per_channel_quantized=false,std::vector<float> per_channel_quantization_scales={})137   BaseFullyConnectedOpModel(
138       TfLiteRegistration* registration, int units, int batches,
139       const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
140       const TensorType& bias_type = TensorType_FLOAT32,
141       bool keep_num_dims = false, bool bias_tensor_optional = false,
142       ActivationFunctionType activation_func = ActivationFunctionType_RELU,
143       FullyConnectedOptionsWeightsFormat weights_format =
144           FullyConnectedOptionsWeightsFormat_DEFAULT,
145       int input_size = -1, bool weights_per_channel_quantized = false,
146       std::vector<float> per_channel_quantization_scales = {})
147       : batches_(batches),
148         units_(units),
149         input_size_(input_size),
150         bias_type_(bias_type) {
151     if (input_size_ == -1) {
152       // Calculate input_size_ from batch and input shape.
153       int total_input_size = 1;
154       for (size_t i = 0; i < input.shape.size(); ++i) {
155         total_input_size *= input.shape[i];
156       }
157       input_size_ = total_input_size / batches_;
158     }
159 
160     input_ = AddInput(input);
161     if (weights_per_channel_quantized) {
162       std::vector<int64_t> per_channel_quantization_offsets(
163           per_channel_quantization_scales.size(), 0);
164       if (input.type == TensorType_INT16) {
165         weights_ = AddInput({TensorType_INT8,
166                              {units_, input_size_},
167                              0,
168                              0,
169                              0,
170                              0,
171                              true,
172                              per_channel_quantization_scales,
173                              per_channel_quantization_offsets,
174                              0});
175       } else {
176         weights_ = AddInput({input.type,
177                              {units_, input_size_},
178                              0,
179                              0,
180                              0,
181                              0,
182                              true,
183                              per_channel_quantization_scales,
184                              per_channel_quantization_offsets,
185                              0});
186       }
187     } else {
188       if (input.type == TensorType_INT16) {
189         // Set min and max values that are used to calculate per-tensor scale
190         // and zero points.
191         weights_ = AddInput({TensorType_INT8,
192                              {units_, input_size_},
193                              /*min=*/-63.5,
194                              /*max=*/64});
195       } else {
196         weights_ =
197             AddInput({input.type, {units_, input_size_}, input.min, input.max});
198       }
199     }
200 
201     if (bias_tensor_optional) {
202       bias_ = AddNullInput();
203     } else if (bias_type == TensorType_FLOAT32) {
204       bias_ = AddInput({TensorType_FLOAT32, {units_}});
205     } else {
206       // This is a quantized version. The scale of 'bias' depends on the scales
207       // of input and filter. Supposedly this is correctly set during quantized
208       // training.
209       if (weights_per_channel_quantized) {
210         std::vector<float> bias_scales = per_channel_quantization_scales;
211         const float input_scale = GetScale(input_);
212         for (float& bias_scale : bias_scales) {
213           bias_scale *= input_scale;
214         }
215         std::vector<int64_t> bias_zero_points(
216             per_channel_quantization_scales.size(), 0);
217         TensorData bias{bias_type,   {units_},         0, 0, 0, 0, true,
218                         bias_scales, bias_zero_points, 0};
219         bias_ = AddInput(bias);
220       } else {
221         auto bias_scale = GetScale(input_) * GetScale(weights_);
222         TensorData bias{bias_type, {units_}, 0, 0, bias_scale};
223         bias_ = AddInput(bias);
224       }
225     }
226 
227     output_ = AddOutput(output);
228     if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) {
229       AddOutput({TensorType_UINT8, input.shape});
230     }
231 
232     SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
233                  BuiltinOptions_FullyConnectedOptions,
234                  CreateFullyConnectedOptions(builder_, activation_func,
235                                              weights_format, keep_num_dims)
236                      .Union());
237     resolver_ = std::make_unique<SingleOpResolver>(
238         BuiltinOperator_FULLY_CONNECTED, registration);
239     BuildInterpreter({GetShape(input_), GetShape(weights_),
240                       (bias_ == kTfLiteOptionalTensor) ? std::vector<int>()
241                                                        : GetShape(bias_)});
242   }
243 
input_size()244   int input_size() { return input_size_; }
num_units()245   int num_units() { return units_; }
num_batches()246   int num_batches() { return batches_; }
247 
248  protected:
249   int input_;
250   int weights_;
251   int bias_;
252   int output_;
253 
254   int batches_;
255   int units_;
256   int input_size_;
257   TensorType bias_type_;
258 };
259 
260 class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
261  public:
262   using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
263 
SetBias(const std::vector<float> & f)264   void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
265 
SetWeights(const std::vector<float> & f)266   void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
267 
SetInput(const std::vector<float> & data)268   void SetInput(const std::vector<float>& data) {
269     PopulateTensor(input_, data);
270   }
SetInput(int offset,float * begin,float * end)271   void SetInput(int offset, float* begin, float* end) {
272     PopulateTensor(input_, offset, begin, end);
273   }
274 
GetOutput()275   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()276   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
277 };
278 
279 class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
280  public:
281   using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
QuantizedFullyConnectedOpModel(TfLiteRegistration * registration,int units,int batches,const TensorData & input,const TensorData & output={TensorType_INT8},const TensorType & bias_type=TensorType_INT32,bool keep_num_dims=false,bool bias_tensor_optional=false,ActivationFunctionType activation_func=ActivationFunctionType_RELU,FullyConnectedOptionsWeightsFormat weights_format=FullyConnectedOptionsWeightsFormat_DEFAULT,int input_size=-1)282   QuantizedFullyConnectedOpModel(
283       TfLiteRegistration* registration, int units, int batches,
284       const TensorData& input, const TensorData& output = {TensorType_INT8},
285       const TensorType& bias_type = TensorType_INT32,
286       bool keep_num_dims = false, bool bias_tensor_optional = false,
287       ActivationFunctionType activation_func = ActivationFunctionType_RELU,
288       FullyConnectedOptionsWeightsFormat weights_format =
289           FullyConnectedOptionsWeightsFormat_DEFAULT,
290       int input_size = -1)
291       : BaseFullyConnectedOpModel(registration, units, batches, input, output,
292                                   bias_type, keep_num_dims,
293                                   bias_tensor_optional, activation_func,
294                                   weights_format, input_size) {}
295 
SetBias(const std::vector<float> & data)296   void SetBias(const std::vector<float>& data) {
297     if (bias_type_ == TensorType_INT32) {
298       QuantizeAndPopulate<int32_t>(bias_, data);
299     } else {
300       QuantizeAndPopulate<int64_t>(bias_, data);
301     }
302   }
303 
304   template <typename T>
SetWeights(const std::vector<float> & data)305   void SetWeights(const std::vector<float>& data) {
306     QuantizeAndPopulate<T>(weights_, data);
307   }
308 
309   template <typename T>
ShuffleAndSetWeights(const std::vector<float> & data,int input_depth,int output_depth)310   void ShuffleAndSetWeights(const std::vector<float>& data, int input_depth,
311                             int output_depth) {
312     std::vector<float> shuffled_data(data.size());
313     CHECK_EQ(input_depth % 16, 0);
314     CHECK_EQ(output_depth % 4, 0);
315     float* shuffled_data_ptr = shuffled_data.data();
316     for (int block_o = 0; block_o < output_depth; block_o += 4) {
317       for (int block_i = 0; block_i < input_depth; block_i += 16) {
318         for (int o = 0; o < 4; o++) {
319           for (int i = 0; i < 16; i++) {
320             *shuffled_data_ptr++ =
321                 data[(block_o + o) * input_depth + block_i + i];
322           }
323         }
324       }
325     }
326     TfLiteTensor* t = interpreter_->tensor(weights_);
327     auto quantized_data =
328         Quantize<T>(shuffled_data, t->params.scale, t->params.zero_point);
329     for (T& q : quantized_data) {
330       q ^= 0x80;
331     }
332     PopulateTensor(weights_, 0, quantized_data.data(),
333                    quantized_data.data() + quantized_data.size());
334   }
335 
336   template <typename T>
SetInput(const std::vector<float> & data)337   void SetInput(const std::vector<float>& data) {
338     QuantizeAndPopulate<T>(input_, data);
339   }
340 
341   template <typename T>
GetOutput()342   std::vector<T> GetOutput() {
343     return ExtractVector<T>(output_);
344   }
345 
346   template <typename T>
GetDequantizedOutput()347   std::vector<float> GetDequantizedOutput() {
348     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
349                          GetZeroPoint(output_));
350   }
351 };
352 
353 class PerChannelQuantizedFullyConnectedOpModel
354     : public BaseFullyConnectedOpModel {
355  public:
356   using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
PerChannelQuantizedFullyConnectedOpModel(TfLiteRegistration * registration,int units,int batches,const TensorData & input,const std::vector<float> & per_channel_quantization_scales,const TensorData & output={TensorType_INT8},const TensorType & bias_type=TensorType_INT32,bool keep_num_dims=false,bool bias_tensor_optional=false,ActivationFunctionType activation_func=ActivationFunctionType_RELU,FullyConnectedOptionsWeightsFormat weights_format=FullyConnectedOptionsWeightsFormat_DEFAULT,int input_size=-1)357   PerChannelQuantizedFullyConnectedOpModel(
358       TfLiteRegistration* registration, int units, int batches,
359       const TensorData& input,
360       const std::vector<float>& per_channel_quantization_scales,
361       const TensorData& output = {TensorType_INT8},
362       const TensorType& bias_type = TensorType_INT32,
363       bool keep_num_dims = false, bool bias_tensor_optional = false,
364       ActivationFunctionType activation_func = ActivationFunctionType_RELU,
365       FullyConnectedOptionsWeightsFormat weights_format =
366           FullyConnectedOptionsWeightsFormat_DEFAULT,
367       int input_size = -1)
368       : BaseFullyConnectedOpModel(
369             registration, units, batches, input, output, bias_type,
370             keep_num_dims, bias_tensor_optional, activation_func,
371             weights_format, input_size, true, per_channel_quantization_scales) {
372   }
373 
SetBias(const std::vector<float> & data)374   void SetBias(const std::vector<float>& data) {
375     PerChannelQuantizeBias(bias_, data);
376   }
377 
378   template <typename T>
SetWeights(const std::vector<float> & data)379   void SetWeights(const std::vector<float>& data) {
380     PerChannelSymmetricQuantizeAndPopulate(weights_, data);
381   }
382 
383   template <typename T>
SetInput(const std::vector<float> & data)384   void SetInput(const std::vector<float>& data) {
385     QuantizeAndPopulate<T>(input_, data);
386   }
387 
388   template <typename T>
GetOutput()389   std::vector<T> GetOutput() {
390     return ExtractVector<T>(output_);
391   }
392 
393   template <typename T>
GetDequantizedOutput()394   std::vector<float> GetDequantizedOutput() {
395     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
396                          GetZeroPoint(output_));
397   }
398 };
399 
400 // In the hybrid model the weights are quantized (to uint8). But the bias,
401 // input (and output) are expected to be in float precision.
402 class HybridFullyConnectedOpModel : public SingleOpModel {
403  public:
HybridFullyConnectedOpModel(int units,int batches,const TensorData & input,const TensorData & weights,const TensorData & output={TensorType_FLOAT32},bool asymmetric_inputs=false,int num_threads=1)404   HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
405                               const TensorData& weights,
406                               const TensorData& output = {TensorType_FLOAT32},
407                               bool asymmetric_inputs = false,
408                               int num_threads = 1)
409       : batches_(batches), units_(units) {
410     int total_input_size = 1;
411     for (size_t i = 0; i < input.shape.size(); ++i) {
412       total_input_size *= input.shape[i];
413     }
414     input_size_ = total_input_size / batches_;
415 
416     input_ = AddInput(input);
417     weights_ = AddInput(weights);
418 
419     TensorData bias{TensorType_FLOAT32, {units_}};
420     bias_ = AddInput(bias);
421 
422     output_ = AddOutput(output);
423 
424     auto options = CreateFullyConnectedOptions(
425                        builder_, ActivationFunctionType_RELU,
426                        tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
427                        false, asymmetric_inputs)
428                        .Union();
429     SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
430                  BuiltinOptions_FullyConnectedOptions, options);
431     resolver_ = std::make_unique<SingleOpResolver>(
432         BuiltinOperator_FULLY_CONNECTED,
433         ops::builtin::Register_FULLY_CONNECTED_PIE());
434     BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
435                      num_threads, /*allow_fp32_relax_to_fp16=*/false,
436                      /*apply_delegate=*/true);
437   }
SetBias(const std::vector<float> & f)438   void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
SetWeights(const std::vector<float> & data)439   void SetWeights(const std::vector<float>& data) {
440     SymmetricQuantizeAndPopulate(weights_, data);
441   }
442 
SetSignedWeights(std::initializer_list<float> f)443   void SetSignedWeights(std::initializer_list<float> f) {
444     SignedSymmetricQuantizeAndPopulate(weights_, f);
445   }
446 
SetInput(const std::vector<float> & f)447   void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
GetOutput()448   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()449   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
450 
input_size()451   int input_size() { return input_size_; }
num_units()452   int num_units() { return units_; }
num_batches()453   int num_batches() { return batches_; }
454 
455  protected:
456   int input_;
457   int weights_;
458   int bias_;
459   int output_;
460 
461   int batches_;
462   int units_;
463   int input_size_;
464 };
465 
466 const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
467     {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
468     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
469     {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
470 });
471 
472 class FloatFullyConnectedOpTest : public SingleOpTest {
473  protected:
GetKernelMap()474   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
475     return *kKernelMap;
476   }
477 };
478 
479 const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
480     {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
481     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
482 });
483 
484 class QuantizedFullyConnectedOpTest : public SingleOpTest {
485  protected:
GetKernelMap()486   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
487     return *kKernelMapNoPie;
488   }
489 };
490 
491 const auto kKernelMapHybrid = new std::map<string, TfLiteRegistration*>({
492     {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
493     // Only Pie supports the hybrid path, so the optimized kernel should fall
494     // back to the Pie path in such cases.
495     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
496 });
497 
498 // Hybrid mode is used by the Pie quantized kernel.
499 class HybridFullyConnectedOpTest : public SingleOpTest {
500  protected:
GetKernelMap()501   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
502     return *kKernelMapHybrid;
503   }
504 };
505 
506 // TODO(ahentz): add more small tests like this one, focused on making sure the
507 // calculations are correct.
TEST_P(FloatFullyConnectedOpTest,SimpleTest)508 TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
509   FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2,
510                                /*input=*/{TensorType_FLOAT32, {2, 10}});
511   m.SetWeights({
512       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
513       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
514       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
515   });
516   m.SetBias({1, 2, 3});
517 
518   m.SetInput({
519       1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
520       1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
521   });
522 
523   ASSERT_EQ(m.Invoke(), kTfLiteOk);
524 
525   EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
526   EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
527 }
528 
// Minimal 2x2 case: output = input . weights^T + bias.
TEST_P(FloatFullyConnectedOpTest, SimpleTest2) {
  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/1, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 2}});
  m.SetWeights({
      2, 4,  // u = 0
  });
  m.SetBias({1});

  m.SetInput({
      1, 2,  // b = 0
      2, 1,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
}
547 
// A zero-sized inner dimension should yield an all-zero output (plus bias=0).
TEST_P(FloatFullyConnectedOpTest, FilterWithZeroSecondDimension1) {
  if (SingleOpModel::GetForceUseNnapi()) {
    return;  // NNAPI does not handle zero-sized dimensions.
  }

  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/2, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 0}});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 2));
  EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 0));
}
560 
// Same as above but with a 3-D input and keep_num_dims=true, so the output
// keeps the leading dimensions.
TEST_P(FloatFullyConnectedOpTest, FilterWithZeroSecondDimension2) {
  if (SingleOpModel::GetForceUseNnapi()) {
    return;  // NNAPI does not handle zero-sized dimensions.
  }

  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/2, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 2, 0}},
                               /*output=*/{TensorType_FLOAT32},
                               /*bias_type=*/TensorType_FLOAT32,
                               /*keep_num_dims=*/true);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 2, 2));
  EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 0, 0, 0, 0, 0));
}
576 
// 3-D input with keep_num_dims=false (default): leading dims are flattened
// into the batch dimension of the output.
TEST_P(FloatFullyConnectedOpTest, FilterWithZeroSecondDimension3) {
  if (SingleOpModel::GetForceUseNnapi()) {
    return;  // NNAPI does not handle zero-sized dimensions.
  }

  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/2, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 2, 0}});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 2));
  EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 0, 0, 0, 0, 0));
}
589 
TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) {
  // The optimized kernel assumes that the bias is specified.
  FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
                               /*units=*/1, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 2}},
                               /*output=*/{TensorType_FLOAT32},
                               /*bias_type=*/TensorType_FLOAT32,
                               /*keep_num_dims=*/false,
                               /*bias_tensor_optional=*/true);
  m.SetWeights({
      2, 4,  // u = 0
  });

  m.SetInput({
      1, 2,  // b = 0
      2, 1,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
  // Same products as SimpleTest2 but without the +1 bias.
  EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8));
}
613 
// Zero batches: invocation must succeed and produce an empty (0 x units)
// output.  input_size must be passed explicitly since it cannot be derived
// from a zero-batch shape.
TEST(FloatFullyConnectedOpTest, SimpleTestEmptyOutput) {
  if (SingleOpModel::GetForceUseNnapi()) {
    return;
  }

  FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
                               /*units=*/1, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {0, 2}},
                               /*output=*/{TensorType_FLOAT32},
                               /*bias_type=*/TensorType_FLOAT32,
                               /*keep_num_dims=*/false,
                               /*bias_tensor_optional=*/true,
                               /*activation_func=*/ActivationFunctionType_RELU,
                               /*weights_format=*/
                               FullyConnectedOptionsWeightsFormat_DEFAULT,
                               /*input_size=*/2);
  m.SetWeights({
      2, 4,  // u = 0
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(0, 1));
}
638 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Dequantized results should match the float SimpleTest; the raw values are
  // offset by the output zero point (127).
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  //
                  58, 59, 60,  //
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(151, 152, 153, 185, 186, 187));
}
668 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8NoBias) {
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128},
      /*bias_type=*/TensorType_INT32,
      /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/true,
      /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU,
      /*FullyConnectedOptionsWeightsFormat weights_format =*/
      FullyConnectedOptionsWeightsFormat_DEFAULT);

  // input_product_scale < output_scale was not true.
  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Same as SimpleTestQuantizedUint8 but without the {1, 2, 3} bias.
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  23, 23, 23,  //
                  57, 57, 57,  //
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(150, 150, 150, 184, 184, 184));
}
702 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  // 3-unit, 2-batch fully-connected op with int8 input/output; checks both
  // the dequantized values and the raw quantized codes.
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAre(23, 24, 25, 57, 58, 59));
}
728 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt8) {
  // Like SimpleTestQuantizedInt8, but the weights carry per-channel
  // quantization scales (one scale per output unit).
  PerChannelQuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
      /*per_channel_quantization_scales=*/{0.2, 0.25, 0.5},
      /*output=*/{TensorType_INT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAre(23, 24, 25, 57, 58, 59));
}
755 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias32) {
  // int16 activations with an int32 bias tensor.
  const float scale = 128.0 / 65536;
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
      /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
      /*bias_type=*/TensorType_INT32);

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<int16_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(model.GetOutput<int16_t>(),
              ElementsAre(12288, 12800, 13312, 29696, 30208, 30720));
}
784 
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestPerChannelQuantizedInt16Bias32) {
  // int16 activations + int32 bias, with per-channel weight scales.
  const float scale = 128.0 / 65536;
  PerChannelQuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
      /*per_channel_quantization_scales=*/{0.2, 0.25, 0.5},
      /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
      /*bias_type=*/TensorType_INT32);

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<int16_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(model.GetOutput<int16_t>(),
              ElementsAre(12288, 12800, 13312, 29696, 30208, 30720));
}
815 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias64) {
  // Same as the Bias32 variant but the bias tensor is int64.
  const float scale = 128.0 / 65536;
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
      /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
      /*bias_type=*/TensorType_INT64);

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<int16_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60})));
  EXPECT_THAT(model.GetOutput<int16_t>(),
              ElementsAre(12288, 12800, 13312, 29696, 30208, 30720));
}
844 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8NoBias) {
  // int8 variant with the bias tensor marked optional; no SetBias() call.
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128},
      /*bias_type=*/TensorType_INT32,
      /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/true,
      /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU,
      /*FullyConnectedOptionsWeightsFormat weights_format =*/
      FullyConnectedOptionsWeightsFormat_DEFAULT);

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });

  model.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Without a bias every unit in a batch produces the same value.
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAre(22, 22, 22, 56, 56, 56));
}
874 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedOutputShape3DInt8) {
  // 3-D input {2, 2, 5} with keep_num_dims=true, so the output keeps the
  // leading dimensions instead of being flattened to {batches, units}.
  if (SingleOpModel::GetForceUseNnapi()) {
    return;
  }

  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 2, 5}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128},
      /*bias_type=*/TensorType_INT32,
      /*keep_num_dims=*/true, /*bias_tensor_optional=*/false,
      /*activation_func=*/ActivationFunctionType_RELU,
      /*weights_format=*/FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*input_size=*/5);

  // input_product_scale < output_scale was not true.
  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5,  // u = 0
      1, 2, 3, 4, 5,  // u = 1
      1, 2, 3, 4, 5,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<int8_t>({
      1, 2,  3,  4,  -5,  // b = 0, i = 0
      1, 2,  3,  -4, 5,   // b = 0, i = 1
      1, 2,  -3, 4,  5,   // b = 1, i = 0
      1, -2, 3,  4,  5,   // b = 1, i = 1
  });

  // Fix: assert on the invocation status like every other test in this file.
  // A bare m.Invoke() would silently ignore a kernel failure and let the
  // EXPECT_THATs below report confusing stale-output mismatches instead.
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  6, 7, 8,     // b = 0, i = 0
                  24, 25, 26,  // b = 0, i = 1
                  38, 39, 40,  // b = 1, i = 0
                  48, 49, 50   // b = 1, i = 1
              })));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(5, 6, 7,     // b = 0, i = 0
                                                 23, 24, 25,  // b = 0, i = 1
                                                 37, 38, 39,  // b = 1, i = 0
                                                 47, 48, 49   // b = 1, i = 1
                                                 ));
}
920 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedOutputShape3DInt16) {
  // 3-D int16 input {2, 2, 5} with keep_num_dims=true and an int64 bias.
  const float scale = 128.0 / 65536;
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT16, {2, 2, 5}, 0, 0, scale, 0},
      /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
      /*bias_type=*/TensorType_INT64,
      /*keep_num_dims=*/true, /*bias_tensor_optional=*/false,
      /*activation_func=*/ActivationFunctionType_RELU,
      /*weights_format=*/FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*input_size=*/5);

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5,  // u = 0
      1, 2, 3, 4, 5,  // u = 1
      1, 2, 3, 4, 5,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<int16_t>({
      1, 2,  3,  4,  -5,  // b = 0, i = 0
      1, 2,  3,  -4, 5,   // b = 0, i = 1
      1, 2,  -3, 4,  5,   // b = 1, i = 0
      1, -2, 3,  4,  5,   // b = 1, i = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear({
                  6, 7, 8,     // b = 0, i = 0
                  24, 25, 26,  // b = 0, i = 1
                  38, 39, 40,  // b = 1, i = 0
                  48, 49, 50   // b = 1, i = 1
              })));
  EXPECT_THAT(model.GetOutput<int16_t>(),
              ElementsAre(3072, 3584, 4096,     // b = 0, i = 0
                          12288, 12800, 13312,  // b = 0, i = 1
                          19456, 19968, 20480,  // b = 1, i = 0
                          24576, 25088, 25600   // b = 1, i = 1
                          ));
}
964 // Test the GEMV path.
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestSingleBatchQuantizedInt8) {
  // Single-batch case, which exercises the GEMV code path.
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/4, /*batches*/ 1,
      /*input=*/{TensorType_INT8, {1, 10}, -63.5, 64},
      /*output=*/{TensorType_INT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 3
  });
  model.SetBias({1, 2, 3, 4});

  model.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, -8, 9, -10  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({58, 59, 60, 61})));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAre(57, 58, 59, 60));
}
990 
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedOutputMultiplierGreaterThan1Uint8) {
  // Input range is wider than the output range, so the requantization
  // multiplier exceeds 1. real_multiplier = 2.
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  model.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              ElementsAre(175, 177, 179, 243, 245, 247));
}
1021 
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedOutputMultiplierGreaterThan1Int8) {
  // int8 counterpart of the multiplier-greater-than-1 test.
  // real_multiplier = 2.
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_INT8, {2, 10}, -127, 128},
      /*output=*/{TensorType_INT8, {}, -63.5, 64});

  model.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAre(47, 49, 51, 115, 117, 119));
}
1051 
// Shared driver for the uint8-weights / int16-output tests below. Builds a
// fully-connected model of the requested shape, fills weights, input and bias
// with pseudo-random values (fixed seed: default-constructed mt19937), runs
// the op, and compares the dequantized output against a float reference
// computed in this function.
void SimpleTestQuantizedInt16OutputCase(
    TfLiteRegistration* registration, int input_depth, int output_depth,
    int batches, FullyConnectedOptionsWeightsFormat weights_format) {
  const uint8_t kWeightsZeroPoint = 128;
  const float kWeightsScale = 1.f / 128.f;
  const uint8_t kInputZeroPoint = 128;
  const float kInputScale = 1.f / 128.f;
  const float kInputMin = (0 - kInputZeroPoint) * kInputScale;
  const float kInputMax = (255 - kInputZeroPoint) * kInputScale;
  // Output ranges in [-8..8] encoded as int16
  const float kOutputScale = 8.f / 32768.f;
  const float kOutputMin = -32768 * kOutputScale;
  const float kOutputMax = 32767 * kOutputScale;

  QuantizedFullyConnectedOpModel m(
      registration, output_depth, batches,
      /*input=*/
      {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax},
      /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax},
      /*bias_type=*/TensorType_INT32,
      /*keep_num_dims=*/false,
      /*bias_tensor_optional=*/false,
      /*activation_func=*/ActivationFunctionType_NONE, weights_format);

  std::mt19937 random_engine;
  // Some compilers don't support uint8_t for uniform_distribution.
  std::uniform_int_distribution<uint32_t> weights_dist(
      0, std::numeric_limits<uint8_t>::max());

  // Weights are generated as quantized codes and stored dequantized so the
  // float reference below can use them directly.
  std::vector<float> weights_data(input_depth * output_depth);
  for (auto& w : weights_data) {
    uint8_t q = static_cast<uint8_t>(weights_dist(random_engine));
    w = (q - kWeightsZeroPoint) * kWeightsScale;
  }

  // Based on weights_format, enforce any shape requirement for that format/path
  // and set the (possibly shuffled) weights.
  switch (weights_format) {
    case FullyConnectedOptionsWeightsFormat_DEFAULT:
      m.SetWeights<uint8_t>(weights_data);
      break;
    case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
      // The shuffled path currently supports only a restrictive subset of
      // shapes, described by the following assertions:
      CHECK_EQ(input_depth % 16, 0);
      CHECK_EQ(output_depth % 4, 0);
      CHECK(batches == 1 || batches == 4);
      m.ShuffleAndSetWeights<uint8_t>(weights_data, input_depth, output_depth);
      break;
    default:
      LOG(FATAL) << "Unhandled weights format";
  }

  // Some compilers don't support uint8_t for uniform_distribution.
  std::uniform_int_distribution<uint32_t> input_dist(
      0, std::numeric_limits<uint8_t>::max());
  std::vector<float> input_data(input_depth * batches);
  for (auto& i : input_data) {
    uint8_t q = static_cast<uint8_t>(input_dist(random_engine));
    i = (q - kInputZeroPoint) * kInputScale;
  }

  std::vector<float> bias_data(output_depth);
  // As the output ranges in [-8, 8], it's reasonable to have bias values
  // in [-1, 1], this won't result in too much saturation.
  std::uniform_real_distribution<float> bias_dist(-1.f, 1.f);
  for (auto& b : bias_data) {
    b = bias_dist(random_engine);
  }

  m.SetBias(bias_data);
  m.SetInput<uint8_t>(input_data);

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Float reference: for each (batch, unit), accum = bias[unit] +
  // dot(input row, weight row), clamped to the representable output range.
  std::vector<float> expected_output_data(output_depth * batches);
  for (int b = 0; b < batches; b++) {
    for (int o = 0; o < output_depth; o++) {
      float accum = bias_data[o];
      for (int i = 0; i < input_depth; i++) {
        accum +=
            input_data[b * input_depth + i] * weights_data[o * input_depth + i];
      }
      accum = std::min(accum, kOutputMax);
      accum = std::max(accum, kOutputMin);
      expected_output_data[b * output_depth + o] = accum;
    }
  }

  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f)));
}
1144 
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedInt16OutputDefaultWeights) {
  // Sweep a grid of shapes through the default-weights-format int16-output
  // path.
  for (int in_depth : {1, 3, 10, 100}) {
    for (int out_depth : {1, 3, 10, 100}) {
      for (int num_batches : {1, 3, 10, 100}) {
        SimpleTestQuantizedInt16OutputCase(
            GetRegistration(), in_depth, out_depth, num_batches,
            FullyConnectedOptionsWeightsFormat_DEFAULT);
      }
    }
  }
}
1157 
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) {
  // The shuffled weights block shape is 4x16. The shape of the weights matrix
  // is: rows = output_depth, cols = input_depth. It must be a multiple of 4x16.
  // This means that output_depth must be a multiple of 4, and input_depth must
  // be a multiple of 16.
  for (int input_depth_numblocks : {1, 3}) {
    for (int output_depth_numblocks : {1, 3}) {
      int input_depth = 16 * input_depth_numblocks;
      int output_depth = 4 * output_depth_numblocks;
      // The fast shuffled path is currently supporting only batch sizes of 1
      // and 4. The idea is that the whole point of that path is to go as fast
      // as possible for small batch size, which requires fully specializing
      // it for each batch size, and for larger batch sizes the generic
      // gemmlowp-based implementation is fast enough.
      for (int batch : {1, 4}) {
        SimpleTestQuantizedInt16OutputCase(
            GetRegistration(), input_depth, output_depth, batch,
            FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
      }
    }
  }
}
1181 
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  // Hybrid op: float input/output with uint8-quantized weights.
  HybridFullyConnectedOpModel model(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/
      {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid

  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Hybrid quantization is lossy, so allow a generous tolerance.
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         24, 25, 26,  //
                                         58, 59, 60,  //
                                     },
                                     /*max_abs_error=*/1.3f)));
}
1210 
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  // Hybrid op: float input/output with int8-quantized (signed) weights.
  HybridFullyConnectedOpModel model(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0});  // Hybrid

  model.SetSignedWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Hybrid quantization is lossy, so allow a generous tolerance.
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         24, 25, 26,  //
                                         58, 59, 60,  //
                                     },
                                     /*max_abs_error=*/1.3f)));
}
1238 
TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8MultiThreaded) {
  // Same hybrid int8 computation run with 1..4 interpreter threads; results
  // must be identical regardless of thread count.
  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
    HybridFullyConnectedOpModel model(
        /*units=*/3, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 10}},
        /*weights=*/
        {TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
        /*output=*/{TensorType_FLOAT32}, /*asymmetric_inputs=*/false,
        /*num_threads=*/num_threads);  // Hybrid

    model.SetSignedWeights({
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
    });
    model.SetBias({1, 2, 3});

    model.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 2
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 3
    });

    ASSERT_EQ(model.Invoke(), kTfLiteOk);

    EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3));
    EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                       {
                                           24, 25, 26,  //
                                           58, 59, 60,  //
                                           24, 25, 26,  //
                                           58, 59, 60,  //
                                       },
                                       /*max_abs_error=*/1.3f)));
  }
}
1276 
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
  // Hybrid op with asymmetric input quantization enabled; tighter tolerance
  // than the symmetric variant.
  HybridFullyConnectedOpModel model(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/
      {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32},
      /*asymmetric_quantize_input*/ true);  // Hybrid asymmetric

  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         24, 25, 26,  //
                                         58, 59, 60,  //
                                     },
                                     /*max_abs_error=*/0.64f)));
}
1306 
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) {
  // int8-weights hybrid op with asymmetric input quantization enabled.
  HybridFullyConnectedOpModel model(
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}},
      /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
      {TensorType_FLOAT32},
      /*asymmetric_quantize_input*/ true);

  model.SetSignedWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         24, 25, 26,  //
                                         58, 59, 60,  //
                                     },
                                     /*max_abs_error=*/1.3f)));
}
1336 
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  FloatFullyConnectedOpModel model(GetRegistration(),
                                   /*units=*/3, /*batches=*/2,
                                   /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // second batch
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // With keep_num_dims unset the output is flattened to {batches, units}.
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     24, 25, 26,  // first batch
                                     58, 59, 60,  // second batch
                                 }));
}
1364 
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput4DOutput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  FloatFullyConnectedOpModel model(GetRegistration(),
                                   /*units=*/3, /*batches=*/2,
                                   /*input=*/{TensorType_FLOAT32, {1, 2, 1, 10}},
                                   /*output=*/{TensorType_FLOAT32},
                                   /*bias_type=*/TensorType_FLOAT32,
                                   /*keep_num_dims=*/true);
  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // second batch
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // keep_num_dims=true preserves the leading input dimensions in the output.
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 1, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     24, 25, 26,  // first batch
                                     58, 59, 60,  // second batch
                                 }));
}
1395 
1396 #ifdef GTEST_HAS_DEATH_TEST
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInputInvalidShape) {
  // Note that it is not required that the first dimension be the number of
  // batches. But it is required that the last dimension is the 'input_dim'.
  //
  // For this particular test, it is required for the output to be reformattable
  // into a shape of form {4, 1, 5, ?} but since the output size (the product of
  // output dimensions: units times batches) is 6, this is not possible.
  //
  // Death test: model construction must abort during tensor allocation.
  EXPECT_DEATH(FloatFullyConnectedOpModel m(
                   GetRegistration(), /*units=*/3, /*batches=*/2,
                   /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}},
                   /*output=*/{TensorType_FLOAT32},
                   /*bias_type=*/TensorType_FLOAT32,
                   /*keep_num_dims=*/true),
               "Cannot allocate tensors");
}
1412 #endif
1413 
TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantizedUint8) {
  // 4-D uint8 input {4, 1, 5, 1}; flattened into 2 batches of 10 elements.
  QuantizedFullyConnectedOpModel model(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // input_product_scale < output_scale was not true.
  model.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  //
                  58, 59, 60,  //
              })));
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              ElementsAre(151, 152, 153, 185, 186, 187));
}
1443 
// Same 4D-input computation, but the input/output scales are swapped relative
// to the test above so the requantization multiplier exceeds 1.
TEST_P(QuantizedFullyConnectedOpTest,
       SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1Uint8) {
  // real_multiplier = 2.
  QuantizedFullyConnectedOpModel m(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Float results match the previous test; the raw uint8 values differ
  // because of the different output scale.
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({
                  24, 25, 26,  // first batch
                  58, 59, 60,  // second batch
              })));
  EXPECT_THAT(m.GetOutput<uint8_t>(),
              ElementsAre(175, 177, 179, 243, 245, 247));
}
1474 
// Instantiate the float suite for every registered kernel, and the quantized
// suite for every kernel except the PIE implementation.
INSTANTIATE_TEST_SUITE_P(
    FloatFullyConnectedOpTest, FloatFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));

INSTANTIATE_TEST_SUITE_P(
    QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
1482 
// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
// to debug errors and doesn't necessarily test all the important details.
//
// Feeds batches from the file-static fully_connected_input array and compares
// against the file-static fully_connected_golden_output array.
TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/16, /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {2, 8}});
  m.SetWeights(
      {0.091327,  0.103366,  -0.316505, -0.083120, 0.149366,  -0.196636,
       -0.123672, 0.062800,  0.063031,  0.191670,  -0.062001, -0.061504,
       -0.275581, 0.059388,  -0.118497, -0.079224, 0.109758,  0.008307,
       -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650,
       0.266455,  0.051517,  -0.123448, 0.322464,  0.043282,  -0.173782,
       -0.190381, 0.002013,  0.096086,  0.131157,  0.031164,  0.100638,
       -0.312191, -0.080923, -0.101318, -0.116614, 0.142238,  0.086540,
       -0.139154, 0.174268,  -0.073161, 0.080072,  0.006874,  0.229382,
       -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824,
       -0.025021, 0.074460,  -0.252595, -0.161750, -0.136403, 0.008308,
       0.005710,  0.096600,  0.289839,  0.218816,  -0.304651, -0.070958,
       0.054598,  0.147113,  -0.139112, -0.072798, -0.163335, -0.167863,
       -0.128762, -0.035780, 0.117262,  0.017177,  0.263335,  -0.176612,
       0.262961,  -0.093654, -0.339283, 0.333071,  0.180827,  0.287583,
       0.066350,  -0.197947, -0.114449, -0.236035, 0.103532,  -0.034284,
       0.093299,  -0.145361, 0.054001,  0.250570,  0.157010,  -0.143480,
       -0.139061, -0.048873, 0.067557,  0.139038,  0.324106,  0.227041,
       0.037793,  -0.225747, -0.241619, 0.357835,  0.135762,  -0.306764,
       -0.125982, 0.091916,  0.266587,  0.030135,  0.265148,  0.141627,
       0.020120,  0.083815,  -0.124556, -0.100124, -0.048159, 0.181172,
       0.302309,  -0.041084, 0.146334,  -0.061511, -0.232605, 0.281324,
       0.145408,  -0.221897});
  m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860,
             0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804,
             0.048478, -0.032270, 0.175688, -0.085662});

  // Number of whole (input_size x num_batches) chunks in the static data.
  const int input_sequence_size = sizeof(fully_connected_input) /
                                  sizeof(float) /
                                  (m.input_size() * m.num_batches());
  for (int i = 0; i < input_sequence_size; i++) {
    // TODO(ahentz): This is what the original test was doing: two equal
    // batches per invocation. We could instead use two different batches.
    float* batch_start = fully_connected_input + i * m.input_size();
    float* batch_end = batch_start + m.input_size();
    m.SetInput(0, batch_start, batch_end);
    m.SetInput(m.input_size(), batch_start, batch_end);

    ASSERT_EQ(m.Invoke(), kTfLiteOk);

    // Duplicate the golden row: both batches received identical input.
    float* golden_start = fully_connected_golden_output + i * m.num_units();
    float* golden_end = golden_start + m.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
1537 
// Model wrapper for FULLY_CONNECTED with a constant *sparse* weight tensor.
// T is the element type used for the input/bias/output tensors.
template <typename T>
class SparseFullyConnectedOpModel : public SingleOpModel {
 public:
  SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
                              int batches, const TensorData& input,
                              const TensorData& weights,
                              const std::vector<T>& weights_data,
                              const TensorData& output = {TensorType_FLOAT32},
                              bool bias_tensor_optional = false,
                              int num_threads = 1,
                              bool symmetric_quantize_weights = false,
                              bool asymmetric_quantize_inputs = false)
      : batches_(batches), units_(units) {
    // Per-batch flattened element count, derived from the full input shape.
    int total_input_size = 1;
    for (size_t i = 0; i < input.shape.size(); ++i) {
      total_input_size *= input.shape[i];
    }
    input_size_ = total_input_size / batches_;

    input_ = AddInput(input);
    // Weights are a constant sparse tensor, optionally symmetrically
    // quantized from the dense weights_data.
    weights_ =
        AddConstSparseInput(weights, weights_data, symmetric_quantize_weights);

    if (bias_tensor_optional) {
      bias_ = AddNullInput();
    } else if (input.type == TensorType_INT8) {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter.
      auto bias_scale = GetScale(input_) * GetScale(weights_);
      TensorData bias = {TensorType_INT32, {units_}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    } else {
      bias_ = AddInput({input.type, {units_}});
    }

    output_ = AddOutput(output);

    // The op is built with RELU activation and the default weights format.
    SetBuiltinOp(
        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU,
                                    FullyConnectedOptionsWeightsFormat_DEFAULT,
                                    /*keep_num_dims=*/false,
                                    asymmetric_quantize_inputs)
            .Union());
    resolver_ = std::make_unique<SingleOpResolver>(
        BuiltinOperator_FULLY_CONNECTED, registration);
    std::vector<std::vector<int>> inputs = {GetShape(input_),
                                            GetShape(weights_)};
    // An optional (null) bias contributes an empty shape entry.
    inputs.push_back((bias_ == kTfLiteOptionalTensor) ? std::vector<int>()
                                                      : GetShape(bias_));
    // Delegates are not applied so the registered kernel itself is exercised.
    BuildInterpreter(inputs, num_threads, /*allow_fp32_relax_to_fp16=*/false,
                     /*apply_delegate=*/false);
  }
  void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
  void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }
  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  int input_;    // tensor index of the activation input
  int weights_;  // tensor index of the constant sparse weights
  int bias_;     // tensor index of the bias (null tensor when optional)
  int output_;   // tensor index of the output

  int batches_;
  int units_;
  int input_size_;
};
1610 
// Fixture for sparse float tests; parameterized over non-PIE kernels only.
class SparseFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMapNoPie;
  }
};
1617 
// Test parameter for the hybrid sparse suite: which kernel to run, and
// whether asymmetric input quantization is enabled.
struct SparseTestParam {
  std::string kernel_tag;
  bool asymmetric_quantize_input;
};
1622 
// Fixture for hybrid (float input, quantized weights) sparse tests,
// parameterized over (kernel tag, asymmetric quantization) pairs.
class SparseHybridFullyConnectedOpTest
    : public ::testing::TestWithParam<SparseTestParam> {
 public:
  // Collects the registered kernel tags (map keys) into a vector.
  static std::vector<string> GetKernelTags(
      const std::map<string, TfLiteRegistration*>& kernel_map) {
    std::vector<string> tags;
    tags.reserve(kernel_map.size());
    for (const auto& it : kernel_map) {
      tags.push_back(it.first);
    }
    return tags;
  }

 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() {
    return *kKernelMapNoPie;
  }
  // Resolves the kernel registration selected by the current test parameter.
  TfLiteRegistration* GetRegistration() {
    return GetKernelMap().at(GetParam().kernel_tag);
  }
};
1644 
// Basic sparse float run: a dense-CSR 3x10 weight matrix whose three rows are
// identical, so every unit produces the same dot product plus its bias.
TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
  TensorData weight_spec{};
  weight_spec.type = TensorType_FLOAT32;
  weight_spec.shape = {3, 10};
  weight_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight_spec.traversal_order = {0, 1};
  std::initializer_list<float> dense_values = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}}, weight_spec, dense_values);
  model.SetBias({1, 2, 3});
  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
1671 
// Same sparse setup as SimpleTest but with the bias tensor omitted
// (bias_tensor_optional=true), so the expected outputs drop by the bias.
TEST_P(SparseFullyConnectedOpTest, SimpleTestNoBias) {
  TensorData weight_spec{};
  weight_spec.type = TensorType_FLOAT32;
  weight_spec.shape = {3, 10};
  weight_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight_spec.traversal_order = {0, 1};
  std::initializer_list<float> dense_values = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(), /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 10}}, weight_spec, dense_values,
      /*output=*/{TensorType_FLOAT32},
      /*bias_tensor_optional=*/true);

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAre(23, 23, 23, 57, 57, 57));
}
1699 
// Minimal sparse case: a single 1x2 weight row applied to two 2-element
// batches.
TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
  TensorData weight_spec{};
  weight_spec.type = TensorType_FLOAT32;
  weight_spec.shape = {1, 2};
  weight_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight_spec.traversal_order = {0, 1};
  std::initializer_list<float> dense_values = {
      2, 4  // u = 0
  };
  SparseFullyConnectedOpModel<float> model(
      GetRegistration(), /*units=*/1, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 2}}, weight_spec, dense_values);
  model.SetBias({1});
  model.SetInput({
      1, 2,  // b = 0
      2, 1   // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 1));
  EXPECT_THAT(model.GetOutput(), ElementsAre(11, 9));
}
1724 
// Block-sparse variant: block_map={1}/block_size={4} declares 1x4 blocking
// along the inner dimension of the 3x12 weight matrix.
TEST_P(SparseFullyConnectedOpTest, Simple1x4Test) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  SparseFullyConnectedOpModel<float> m(GetRegistration(),
                                       /*units=*/3, /*batches=*/2,
                                       /*input=*/{TensorType_FLOAT32, {2, 12}},
                                       weight, weight_data);
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
}
1754 
// Same 1x4 block-sparse setup as Simple1x4Test, but without a bias tensor.
TEST_P(SparseFullyConnectedOpTest, Simple1x4TestNoBias) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  SparseFullyConnectedOpModel<float> m(GetRegistration(),
                                       /*units=*/3, /*batches=*/2,
                                       /*input=*/{TensorType_FLOAT32, {2, 12}},
                                       weight, weight_data,
                                       /*output=*/{TensorType_FLOAT32},
                                       /*bias_tensor_optional=*/true);
  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(288, 288, 288, 80, 80, 80));
}
1784 
// Runs the 1x4 block-sparse case with 1..4 interpreter threads; results must
// be identical to the single-threaded expectation at every thread count.
TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreaded) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  for (int num_threads = 1; num_threads <= 4; num_threads++) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/3, /*batches=*/2,
        /*input=*/{TensorType_FLOAT32, {2, 12}}, weight, weight_data,
        /*output=*/{TensorType_FLOAT32},
        /*bias_tensor_optional=*/false, /*num_threads=*/num_threads);
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
    });

    ASSERT_EQ(m.Invoke(), kTfLiteOk);

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
  }
}
1818 
// Like Simple1x4TestMultiThreaded, but with 6 batches (alternating the two
// input rows) so the batch dimension can be split across threads.
TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreadedMoreBatches) {
  std::initializer_list<float> weight_data = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {3, 12};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {4};
  for (int num_threads = 1; num_threads <= 4; num_threads++) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/3, /*batches=*/6,
        /*input=*/{TensorType_FLOAT32, {6, 12}}, weight, weight_data,
        /*output=*/{TensorType_FLOAT32},
        /*bias_tensor_optional=*/false, /*num_threads=*/num_threads);
    m.SetBias({1, 2, 3});

    m.SetInput({
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 2
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 3
        1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 4
        1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 5
    });

    ASSERT_EQ(m.Invoke(), kTfLiteOk);

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(6, 3));
    EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291,  // b = 0
                                           81, 82, 83,     // b = 1
                                           289, 290, 291,  // b = 2
                                           81, 82, 83,     // b = 3
                                           289, 290, 291,  // b = 4
                                           81, 82, 83      // b = 5
                                           ));
  }
}
1862 
// Hybrid 1x16 block-sparse test: float input, symmetrically-quantized sparse
// weights. Expected values differ slightly when asymmetric input
// quantization is enabled (selected by the test parameter).
TEST_P(SparseHybridFullyConnectedOpTest, SparseHybrid1x16Test) {
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  SparseFullyConnectedOpModel<float> m(
      GetRegistration(),
      /*units=*/4, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 48}}, weight, weight_data,
      /*output=*/{TensorType_FLOAT32},
      /*bias_tensor_optional=*/false, /*num_threads=*/1,
      /*symmetric_quantize_weights=*/true,
      /*asymmetric_quantize_inputs=*/GetParam().asymmetric_quantize_input);
  m.SetBias({1, 2, 3, 4});
  m.SetInput({
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
      1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 0
      2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
      -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
      0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
      -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
      1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 4));
  // Zeros in the expectation come from the RELU activation the model builds
  // into the op.
  std::vector<float> expected = {0,      7.4715, 85.8359, 0,
                                 5.9655, 3.0520, 1.9480,  0};
  if (GetParam().asymmetric_quantize_input) {
    expected = {0, 7.4500, 85.5111, 0, 5.9750, 2.8856, 2.1144, 0};
  }
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-3)));
}
1924 
// Multi-threaded version of SparseHybrid1x16Test: 4 batches (the two hybrid
// input rows duplicated) run with 1..4 threads; per-thread-count results must
// match the single-threaded expectations.
TEST_P(SparseHybridFullyConnectedOpTest, SparseHybrid1x16TestMultiThreaded) {
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
    SparseFullyConnectedOpModel<float> m(
        GetRegistration(),
        /*units=*/4, /*batches=*/4,
        /*input=*/{TensorType_FLOAT32, {4, 48}}, weight, weight_data,
        /*output=*/{TensorType_FLOAT32},
        /*bias_tensor_optional=*/false, /*num_threads=*/num_threads,
        /*symmetric_quantize_weights=*/true,
        /*asymmetric_quantize_inputs=*/GetParam().asymmetric_quantize_input);
    m.SetBias({1, 2, 3, 4});
    m.SetInput({
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 0
        2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
        -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
        0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
        -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
        1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 1
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,
        1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0, 1.0,  -1.0,  // b = 2
        2.5,  0.0,  -2.1, 0.0,  3.0,  0.0,  -1.3, 0.0,  1.3,  0.0,
        -1.1, 0.0,  2.0,  0.0,  -1.7, 0.0,  1.9,  0.0,  -1.5, 0.0,
        0.5,  0.0,  -0.7, 0.0,  0.8,  0.0,  -0.3, 0.0,  2.8,  0.0,
        -2.8, 0.0,  1.1,  -2.3, 1.9,  -1.9, 2.1,  -0.5, 2.4,  -0.1,
        1.0,  -2.5, 0.7,  -1.9, 0.2,  0.1,  0.2,  0.3,  // b = 3
    });

    ASSERT_EQ(m.Invoke(), kTfLiteOk);

    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 4));
    std::vector<float> expected = {
        0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0,
        0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0};
    if (GetParam().asymmetric_quantize_input) {
      expected = {
          0, 7.4500, 85.5111, 0, 5.9750, 2.8856, 2.1144, 0,
          0, 7.4500, 85.5111, 0, 5.9750, 2.8856, 2.1144, 0,
      };
    }
    EXPECT_THAT(m.GetOutput(),
                ElementsAreArray(ArrayFloatNear(expected, 1e-3)));
  }
}
2003 // TODO(b/148391360): Add tests for unsupported sparsity format.
2004 // TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
2005 
// Run the sparse float tests against every non-PIE kernel implementation.
INSTANTIATE_TEST_SUITE_P(
    SparseFullyConnectedOpTest, SparseFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
2009 
GenerateSparseTestParam(std::vector<std::string> kernel_tags)2010 std::vector<SparseTestParam> GenerateSparseTestParam(
2011     std::vector<std::string> kernel_tags) {
2012   std::vector<SparseTestParam> test_params;
2013   for (const std::string& kernel_tag : kernel_tags) {
2014     test_params.push_back({kernel_tag, false});
2015     test_params.push_back({kernel_tag, true});
2016   }
2017   return test_params;
2018 }
2019 
// Run each hybrid sparse test twice per non-PIE kernel: with and without
// asymmetric input quantization.
INSTANTIATE_TEST_SUITE_P(SparseHybridFullyConnectedOpTest,
                         SparseHybridFullyConnectedOpTest,
                         ::testing::ValuesIn(GenerateSparseTestParam(
                             SingleOpTest::GetKernelTags(*kKernelMapNoPie))));
2024 
// Quantized variant of the sparse model: float values are quantized on the
// way in, and the raw int8 output values are read back out.
class SparseQuantizedFullyConnectedOpModel
    : public SparseFullyConnectedOpModel<float> {
 public:
  using SparseFullyConnectedOpModel::SparseFullyConnectedOpModel;
  // Quantizes the float bias into the int32 bias tensor.
  void SetBias(const std::vector<float>& data) {
    QuantizeAndPopulate<int32_t>(bias_, data);
  }
  // Quantizes the float input into the int8 input tensor.
  void SetInput(const std::vector<float>& data) {
    QuantizeAndPopulate<int8_t>(input_, data);
  }
  // Returns the raw (still-quantized) int8 output values.
  std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
};
2037 
// Fixture for sparse quantized tests; parameterized over non-PIE kernels.
class SparseQuantizedFullyConnectedOpTest : public SingleOpTest {
 protected:
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kKernelMapNoPie;
  }
};
2044 
// int8-quantized 1x16 block-sparse run (unit scales, zero offsets); the
// middle all-zero weight row yields pruned-away blocks.
TEST_P(SparseQuantizedFullyConnectedOpTest, Simple1x16Test) {
  std::vector<float> weight_data = {
      1,  2,  3,  4,  -1, -2, -3, -4, 1,  2,  3,  4, -4, -3, -2, -1,  // u = 0
      0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 0,  0,  0,  0,   // u = 1
      -1, -2, -3, -4, 4,  3,  2,  1,  -1, -2, -3, 4, 1,  2,  3,  4,   // u = 2
  };
  TensorData weight = {TensorType_INT8, {3, 16}, 0, 0, 1};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  SparseQuantizedFullyConnectedOpModel m(
      GetRegistration(),
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_INT8, {2, 16}, 0, 0, 1}, weight, weight_data,
      /*output=*/{TensorType_INT8, {}, 0, 0, 1});

  m.SetBias({1, 2, 3});
  m.SetInput({
      1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,  // b = 0
      4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 2, 25, 0, 2, 21));
}
2073 
// Same int8 1x16 block-sparse setup as Simple1x16Test, with the bias omitted.
TEST_P(SparseQuantizedFullyConnectedOpTest, Simple1x16TestNoBias) {
  std::vector<float> weight_data = {
      1,  2,  3,  4,  -1, -2, -3, -4, 1,  2,  3,  4, -4, -3, -2, -1,  // u = 0
      0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 0,  0,  0,  0,   // u = 1
      -1, -2, -3, -4, 4,  3,  2,  1,  -1, -2, -3, 4, 1,  2,  3,  4,   // u = 2
  };
  TensorData weight = {TensorType_INT8, {3, 16}, 0, 0, 1};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  SparseQuantizedFullyConnectedOpModel m(
      GetRegistration(),
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_INT8, {2, 16}, 0, 0, 1}, weight, weight_data,
      /*output=*/{TensorType_INT8, {}, 0, 0, 1}, /*bias_tensor_optional=*/true);

  m.SetInput({
      1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,  // b = 0
      4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(10, 0, 22, 0, 0, 18));
}
2101 
TEST_P(SparseQuantizedFullyConnectedOpTest, Simple1x16TestScaledInputOutput) {
  // Dense view of the weights: 3 output units x 48 inputs. Unit 0 is all
  // zeros and units 1 and 2 have trailing zero runs, so several 1x16 blocks
  // are empty in the sparse layout. Stored as std::vector<float> for
  // consistency with the other sparse quantized tests in this file
  // (previously std::initializer_list; the model ctor accepts either).
  std::vector<float> weight_data = {
      0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,  // u = 0
      0.28,  0.27,  0.40,  0.38,  -0.16, -0.14, -0.12, 0.03,  0.11,  0.22,
      0.02,  0.27,  0.22,  -0.39, 0.09,  -0.27, 0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,  // u = 1
      0.06,  0.43,  -0.03, -0.30, -0.09, 0.49,  0.11,  0.24,  -0.21, 0.14,
      -0.18, 0.84,  0.10,  -0.20, -0.51, -0.12, 0.11,  0.02,  -0.09, -0.01,
      -0.31, 0.28,  -0.08, 0.32,  0.77,  0.69,  0.45,  -0.20, 0.21,  -0.07,
      -0.46, -0.20, 0,     0,     0,     0,     0,     0,     0,     0,
      0,     0,     0,     0,     0,     0,     0,     0,  // u = 2
  };
  // Weight tensor: dense along dim 0, block-sparse CSR along dim 1 with
  // 1x16 blocks; non-unit quantization scale.
  TensorData weight = {TensorType_INT8, {3, 48}, 0, 0, 0.014362592250108719};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  // Input and output use non-trivial scales and zero points, exercising the
  // requantization path end to end.
  SparseQuantizedFullyConnectedOpModel m(
      GetRegistration(),
      /*units=*/3, /*batches=*/1,
      /*input=*/{TensorType_INT8, {1, 48}, 0, 0, 0.01739450730383396, -128},
      weight, weight_data,
      /*output=*/{TensorType_INT8, {}, 0, 0, 0.08671142160892487, -52});
  m.SetBias({-0.21742193, -0.38303897, -0.2735016});
  m.SetInput(
      {0.15919347, 0.7385435,  0.01092399, 2.1284404,  0.39123753, 0.01069902,
       0.6752592,  0.15486322, 0.,         0.,         0.16048427, 0.33702788,
       0.,         1.1263783,  0.,         0.,         0.,         0.,
       0.5067856,  0.,         0.,         0.,         0.01031927, 0.,
       0.,         0.07268289, 0.02804407, 0.710703,   0.35505712, 0.15339729,
       0.,         0.,         0.5485122,  0.10860074, 0.01710763, 0.08116849,
       0.05225316, 0.03152719, 0.8149394,  0.6554623,  0.0311714,  0.02122466,
       0.995122,   0.06201557, 0.16699032, 0.,         0.,         0.06638951});

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Expected quantized outputs (int8, zero point -52) for the single batch.
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 3));
  EXPECT_THAT(m.GetOutput(), ElementsAre(-52, -50, -52));
}
2147 
// Instantiate every SparseQuantizedFullyConnectedOpTest case once per kernel
// tag listed in kKernelMapNoPie (presumably the non-PIE kernel variants —
// the map itself is defined earlier in this file).
INSTANTIATE_TEST_SUITE_P(
    SparseQuantizedFullyConnectedOpTest, SparseQuantizedFullyConnectedOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
2151 
2152 }  // namespace
2153 }  // namespace tflite
2154