/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantize_model.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/tools/optimize/test_util.h"

// Note: More rigorous model tests can be found in subgraph_quantizer_test.cc

namespace {
tensorflow::string* g_test_model_dir = nullptr;
}  // namespace

namespace tflite {
namespace optimize {
namespace {

std::unique_ptr<FlatBufferModel> ReadModel(const string& model_name) {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

template <typename T>
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
  return std::vector<T>(vec->begin(), vec->end());
}

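// Verifies that the quantized scale matches the asymmetric scale implied by
// the float tensor's min/max: the range is widened to include zero and
// divided into 255 steps.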
void VerifyAsymmetricQuantizationScale(
    const QuantizationParameters& float_quant_params,
    const QuantizationParametersT& quantized_quant_params) {
  const float eps = 1e-7;
  ASSERT_EQ(float_quant_params.min()->size(), 1);
  ASSERT_EQ(float_quant_params.max()->size(), 1);
  float float_min = std::min(0.f, float_quant_params.min()->Get(0));
  float float_max = std::max(0.f, float_quant_params.max()->Get(0));

  ASSERT_EQ(quantized_quant_params.scale.size(), 1);
  ASSERT_EQ(quantized_quant_params.zero_point.size(), 1);

  float scale = (float_max - float_min) / 255;
  EXPECT_NEAR(scale, quantized_quant_params.scale[0], eps);
}

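// Bias tensors are quantized more widely than activations: 16-bit activations
// pair with 64-bit biases, while 8-bit activations pair with 32-bit biases.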
TensorType GetBiasTensorType(TensorType& activation_type) {
  return activation_type == TensorType_INT16 ? TensorType_INT64
                                             : TensorType_INT32;
}

class QuantizeModelTest : public testing::Test {
 protected:
  QuantizeModelTest() {
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  std::unique_ptr<FlatBufferModel> input_model_;
  const Model* readonly_model_;
  tflite::ModelT model_;
  flatbuffers::FlatBufferBuilder builder_;
  internal::FailOnErrorReporter error_reporter_;
};

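// Checks that two unpacked models agree tensor-by-tensor (buffer index,
// shape, name, type, and quantization parameters) and buffer-by-buffer.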
void ExpectSameModels(const ModelT& model, const ModelT& expected_model) {
  ASSERT_EQ(model.subgraphs.size(), expected_model.subgraphs.size());
  for (size_t subgraph_idx = 0; subgraph_idx < model.subgraphs.size();
       subgraph_idx++) {
    const auto graph = model.subgraphs[subgraph_idx].get();
    const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
    ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size());
    for (size_t i = 0; i < graph->tensors.size(); i++) {
      const auto tensor = graph->tensors[i].get();
      const auto expected_tensor = expected_graph->tensors[i].get();
      EXPECT_EQ(tensor->buffer, expected_tensor->buffer);
      EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
      EXPECT_EQ(tensor->shape, expected_tensor->shape);
      EXPECT_EQ(tensor->name, expected_tensor->name);
      EXPECT_EQ(tensor->type, expected_tensor->type);
      const auto quantization_params = tensor->quantization.get();
      const auto expected_quantization_params =
          expected_tensor->quantization.get();
      if (quantization_params != nullptr ||
          expected_quantization_params != nullptr) {
        EXPECT_NE(quantization_params, nullptr);
        EXPECT_NE(expected_quantization_params, nullptr);
        EXPECT_EQ(quantization_params->scale,
                  expected_quantization_params->scale);
        EXPECT_EQ(quantization_params->zero_point,
                  expected_quantization_params->zero_point);
      }
    }
  }
  ASSERT_EQ(model.buffers.size(), expected_model.buffers.size());
  for (size_t buffer_idx = 0; buffer_idx < model.buffers.size(); ++buffer_idx) {
    const auto buffer = model.buffers[buffer_idx].get()->data;
    const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data;
    EXPECT_EQ(buffer, expected_buffer);
  }
  // TODO(jianlijianli): Compare operators as well.
}

class QuantizeConvModelTest : public QuantizeModelTest,
                              public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModelTest() {
    tensor_type_ = GetParam();
    bias_type_ = GetBiasTensorType(tensor_type_);
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
  TensorType tensor_type_;
  TensorType bias_type_;
};

INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
                         testing::ValuesIn({TensorType_INT8,
                                            TensorType_INT16}));

TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
  auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
                                          tensor_type_, false, tensor_type_,
                                          bias_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
  auto status =
      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
                    /*allow_float=*/true, {}, TensorType_FLOAT32,
                    TensorType_FLOAT32, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
  // The resulting model should be the same.
  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
       subgraph_idx++) {
    const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
    const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors.size(), float_graph->tensors()->size());
    for (size_t i = 0; i < quantized_graph->tensors.size(); i++) {
      const auto quant_tensor = quantized_graph->tensors[i].get();
      const auto float_tensor = float_graph->tensors()->Get(i);
      EXPECT_EQ(quant_tensor->buffer, float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable, float_tensor->is_variable());
      EXPECT_EQ(quant_tensor->shape, GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name, float_tensor->name()->str());
      EXPECT_EQ(quant_tensor->type, float_tensor->type());
    }
  }
}

TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
  auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
                                          tensor_type_, false, tensor_type_,
                                          bias_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
       subgraph_idx++) {
    const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
    const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->tensors.size(), float_graph->tensors()->size());
    for (size_t i = 0; i < quantized_graph->tensors.size(); i++) {
      const auto quant_tensor = quantized_graph->tensors[i].get();
      const auto float_tensor = float_graph->tensors()->Get(i);
      EXPECT_EQ(quant_tensor->buffer, float_tensor->buffer());
      EXPECT_EQ(quant_tensor->is_variable, float_tensor->is_variable());
      EXPECT_EQ(quant_tensor->shape, GetAsVector(float_tensor->shape()));
      EXPECT_EQ(quant_tensor->name, float_tensor->name()->str());
    }
  }
  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) {
  auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
                                          tensor_type_, false, tensor_type_,
                                          bias_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  ASSERT_EQ(model_.operator_codes.size(),
            readonly_model_->operator_codes()->size());
  for (size_t i = 0; i < model_.operator_codes.size(); i++) {
    const auto float_model_op = readonly_model_->operator_codes()->Get(i);
    EXPECT_EQ(GetBuiltinCode(model_.operator_codes[i].get()),
              GetBuiltinCode(float_model_op));
    if (GetBuiltinCode(model_.operator_codes[i].get()) ==
        BuiltinOperator_CONV_2D) {
      EXPECT_EQ(model_.operator_codes[i]->version, 3);
    } else {
      EXPECT_EQ(model_.operator_codes[i]->version, 2);
    }
  }

  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
  for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
       subgraph_idx++) {
    const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
    const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
    ASSERT_EQ(quantized_graph->operators.size(),
              float_graph->operators()->size());
    for (size_t i = 0; i < quantized_graph->operators.size(); i++) {
      const auto quant_op = quantized_graph->operators[i].get();
      const auto float_op = float_graph->operators()->Get(i);
      EXPECT_EQ(quant_op->inputs, GetAsVector(float_op->inputs()));
      EXPECT_EQ(quant_op->outputs, GetAsVector(float_op->outputs()));
      EXPECT_EQ(quant_op->opcode_index, float_op->opcode_index());
    }
  }
}

TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_,
      /*allow_float*/ false, tensor_type_, bias_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  for (const auto& subgraph : model_.subgraphs) {
    for (const auto& tensor : subgraph->tensors) {
      if (tensor_type_ == TensorType_INT8) {
        EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                    tensor->type == TensorType_INT8);
      } else if (tensor_type_ == TensorType_INT16) {
        EXPECT_TRUE(tensor->type == TensorType_INT64 ||  // bias
                    tensor->type == TensorType_INT8 ||   // weights
                    tensor->type == TensorType_INT16);   // activations
      }
    }
  }
}

TEST_P(QuantizeConvModelTest, FloatInputAndOutput) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float*/ false, tensor_type_, bias_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
       ++subgraph_idx) {
    const auto& subgraph = model_.subgraphs[subgraph_idx];
    const auto& readonly_subgraph =
        readonly_model_->subgraphs()->Get(subgraph_idx);
    // The model has one input and output, so the converted model should have
    // two extra ops, a Quantize and Dequantize.
    EXPECT_EQ(subgraph->operators.size(),
              readonly_subgraph->operators()->size() + 2);
    // Check that the first op is Quantize and the last is Dequant.
    const auto& quant_op = subgraph->operators[0];
    const auto& dequant_op =
        subgraph->operators[subgraph->operators.size() - 1];
    const int32_t quant_idx = quant_op->opcode_index;
    const int32_t dequant_idx = dequant_op->opcode_index;
    EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()),
              BuiltinOperator_QUANTIZE);
    EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()),
              BuiltinOperator_DEQUANTIZE);
    // The model should only have one input and output.
    EXPECT_EQ(subgraph->inputs.size(), 1);
    EXPECT_EQ(subgraph->outputs.size(), 1);
    const int32_t input_idx = subgraph->inputs[0];
    const int32_t output_idx = subgraph->outputs[0];
    // Ensure: new input -> Quant -> old input.
    EXPECT_EQ(quant_op->inputs[0], input_idx);
    EXPECT_EQ(quant_op->outputs[0], readonly_subgraph->inputs()->Get(0));
    // Ensure: old output -> dequant -> new output.
    EXPECT_EQ(dequant_op->inputs[0], readonly_subgraph->outputs()->Get(0));
    EXPECT_EQ(dequant_op->outputs[0], output_idx);
    // The input and output types should be float.
    EXPECT_EQ(subgraph->tensors[input_idx]->type, TensorType_FLOAT32);
    EXPECT_EQ(subgraph->tensors[input_idx]->name, "input");
    EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_FLOAT32);
    EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
    // The original input and output have been renamed.
    std::string control_suffix =
        (tensor_type_ == TensorType_INT16) ? "int16" : "int8";
    EXPECT_EQ(subgraph->tensors[quant_op->outputs[0]]->name,
              "input_" + control_suffix);
    EXPECT_EQ(subgraph->tensors[dequant_op->inputs[0]]->name,
              "output_" + control_suffix);
    for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
         ++tensor_idx) {
      const auto& tensor = subgraph->tensors[tensor_idx];
      if (input_idx != tensor_idx && output_idx != tensor_idx) {
        if (tensor_type_ == TensorType_INT8) {
          EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                      tensor->type == TensorType_INT8);
        } else if (tensor_type_ == TensorType_INT16) {
          EXPECT_TRUE(tensor->type == TensorType_INT64 ||  // bias
                      tensor->type == TensorType_INT8 ||   // weights
                      tensor->type == TensorType_INT16);   // activations
        }
      }
    }
  }
}

TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_UINT8, TensorType_UINT8, false,
      TensorType_INT8, TensorType_INT32, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  for (int32_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
       ++subgraph_idx) {
    const auto& subgraph = model_.subgraphs[subgraph_idx];
    const auto& readonly_subgraph =
        readonly_model_->subgraphs()->Get(subgraph_idx);
    // The model has one input and one output, so the converted model should
    // have two extra Quantize ops at the boundaries (uint8 <-> int8).
    EXPECT_EQ(subgraph->operators.size(),
              readonly_subgraph->operators()->size() + 2);
    // Check that the first op requantizes uint8 to int8 and the last
    // requantizes int8 back to uint8.
    const auto& quant_op_uint8_int8 = subgraph->operators[0];
    const auto& quant_op_int8_uint8 =
        subgraph->operators[subgraph->operators.size() - 1];
    const int32_t quant_op_uint8_int8_idx = quant_op_uint8_int8->opcode_index;
    const int32_t quant_op_int8_uint8_idx = quant_op_int8_uint8->opcode_index;
    EXPECT_EQ(
        GetBuiltinCode(model_.operator_codes[quant_op_uint8_int8_idx].get()),
        BuiltinOperator_QUANTIZE);
    EXPECT_EQ(
        GetBuiltinCode(model_.operator_codes[quant_op_int8_uint8_idx].get()),
        BuiltinOperator_QUANTIZE);
    // The model should only have one input and output.
    EXPECT_EQ(subgraph->inputs.size(), 1);
    EXPECT_EQ(subgraph->outputs.size(), 1);
    const int32_t input_idx = subgraph->inputs[0];
    const int32_t output_idx = subgraph->outputs[0];
    // Ensure: new input -> Quant -> old input.
    EXPECT_EQ(quant_op_uint8_int8->inputs[0], input_idx);
    EXPECT_EQ(quant_op_uint8_int8->outputs[0],
              readonly_subgraph->inputs()->Get(0));
    // Ensure: old output -> dequant -> new output.
    EXPECT_EQ(quant_op_int8_uint8->inputs[0],
              readonly_subgraph->outputs()->Get(0));
    EXPECT_EQ(quant_op_int8_uint8->outputs[0], output_idx);
    // The input and output types should be uint8.
    EXPECT_EQ(subgraph->tensors[input_idx]->type, TensorType_UINT8);
    EXPECT_EQ(subgraph->tensors[input_idx]->name, "input");
    EXPECT_EQ(subgraph->tensors[input_idx]->quantization->scale.size(), 1);
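    // 0.0392156877f is approximately 10/255, consistent with quantizing a
    // [0, 10] float range onto 255 uint8 steps with a zero point of 0.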
    EXPECT_FLOAT_EQ(subgraph->tensors[input_idx]->quantization->scale[0],
                    0.0392156877);
    EXPECT_EQ(subgraph->tensors[input_idx]->quantization->zero_point.size(), 1);
    EXPECT_EQ(subgraph->tensors[input_idx]->quantization->zero_point[0], 0);
    EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_UINT8);
    EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
    EXPECT_EQ(subgraph->tensors[output_idx]->quantization->scale.size(), 1);
    EXPECT_FLOAT_EQ(subgraph->tensors[output_idx]->quantization->scale[0],
                    0.0392156877);
    EXPECT_EQ(subgraph->tensors[output_idx]->quantization->zero_point.size(),
              1);
    EXPECT_EQ(subgraph->tensors[output_idx]->quantization->zero_point[0], 0);
    // The original input and output have been renamed.
    EXPECT_EQ(subgraph->tensors[quant_op_uint8_int8->outputs[0]]->name,
              "input_int8");
    EXPECT_EQ(subgraph->tensors[quant_op_int8_uint8->inputs[0]]->name,
              "output_int8");
    for (int tensor_idx = 0; tensor_idx < subgraph->tensors.size();
         ++tensor_idx) {
      const auto& tensor = subgraph->tensors[tensor_idx];
      if (input_idx != tensor_idx && output_idx != tensor_idx) {
        EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                    tensor->type == TensorType_INT8);
      }
    }
  }
}

class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
 protected:
  QuantizeConvNoBiasModelTest() {
    input_model_ = ReadModel(internal::kConvModelWithNoBias);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
      TensorType_INT8, TensorType_INT32, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

class QuantizeConcatModelTest : public QuantizeModelTest,
                                public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConcatModelTest() {
    input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  void SetUp() override {
    tensor_type_ = GetParam();
    bias_type_ = GetBiasTensorType(tensor_type_);
  }

  TensorType tensor_type_;
  TensorType bias_type_;
};

// There are two inputs for concat, "input0" and "input1". "input0" has [0, 5]
// as min/max and "input1" has [0, 10] as min/max. The output "output" for
// concat has [0, 10] as min/max.
// After applying QuantizeModel(), "input0" will have a requant op added, along
// with a tensor "input0_requant" that has [0, 10] as min/max. So the topology
// becomes:
// input0 -> requant -> input0_requant \
//                                       concat - output
//                              input1 /
TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) {
  auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
                                          tensor_type_, false, tensor_type_,
                                          bias_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
  const int32_t subgraph_idx = 0;
  const auto& subgraph = model_.subgraphs[subgraph_idx];
  const auto& readonly_subgraph =
      readonly_model_->subgraphs()->Get(subgraph_idx);

  // There should be two ops: quant and concat.
  EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
  EXPECT_EQ(subgraph->operators.size(), 2);
  const auto& requant = subgraph->operators[0];
  const auto& concat = subgraph->operators[1];
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[requant->opcode_index].get()),
            BuiltinOperator_QUANTIZE);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[concat->opcode_index].get()),
            BuiltinOperator_CONCATENATION);

  auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
  /*
     input0_scale_control
        INT8: (5-0) / (2^8 - 1)
        INT16: (5-0) / (2^16 / 2 - 1)
     input1_scale
        INT8: (10-0) / (2^8 - 1)
        INT16: (10-0) / (2^16 / 2 - 1)
  */
  auto input0_scale_control =
      tensor_type_ == TensorType_INT8 ? 0.019607844 : 0.00015259254;
  auto input1_scale =
      tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;

  // There should be 4 tensors: input0, input1, input0_requantized, output.
  EXPECT_EQ(subgraph->tensors.size(), 4);
  EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[0]->name, "input0");
  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
                  input0_scale_control);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
                  zero_point_control);
  EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[1]->name, "input1");
  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
                  zero_point_control);
  EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[2]->name, "output");
  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
                  zero_point_control);
  EXPECT_EQ(subgraph->tensors[3]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[3]->name, "input0_requantized");
  EXPECT_EQ(subgraph->tensors[3]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[3]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->scale[0], input1_scale);
  EXPECT_FLOAT_EQ(subgraph->tensors[3]->quantization->zero_point[0],
                  zero_point_control);

  // The connection should be what is described in the comment.
  EXPECT_EQ(requant->inputs.size(), 1);
  EXPECT_EQ(requant->outputs.size(), 1);
  EXPECT_EQ(requant->inputs[0], 0);
  EXPECT_EQ(requant->outputs[0], 3);
  EXPECT_EQ(concat->inputs.size(), 2);
  EXPECT_EQ(concat->outputs.size(), 1);
  EXPECT_EQ(concat->inputs[0], 3);
  EXPECT_EQ(concat->inputs[1], 1);
  EXPECT_EQ(concat->outputs[0], 2);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONCATENATION);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
            BuiltinOperator_QUANTIZE);
  EXPECT_EQ(model_.operator_codes[1]->version, 2);
}
INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
                         testing::ValuesIn({TensorType_INT8,
                                            TensorType_INT16}));
class QuantizeSplitModelTest : public QuantizeModelTest {
 protected:
  QuantizeSplitModelTest() {
    input_model_ = ReadModel(internal::kModelSplit);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

// There are two outputs for split with different scales; the resulting model
// should have both scales hardcoded to the input scale value.
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
      TensorType_INT8, TensorType_INT32, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
  const int32_t subgraph_idx = 0;
  const auto& subgraph = model_.subgraphs[subgraph_idx];
  const auto& readonly_subgraph =
      readonly_model_->subgraphs()->Get(subgraph_idx);

  // There should be two ops: the split and add in the original model.
  EXPECT_EQ(readonly_subgraph->operators()->size(), 2);
  EXPECT_EQ(subgraph->operators.size(), 2);
  const auto& split = subgraph->operators[0];
  const auto& add = subgraph->operators[1];
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[split->opcode_index].get()),
            BuiltinOperator_SPLIT);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[add->opcode_index].get()),
            BuiltinOperator_ADD);

  // There should be 5 tensors: input, output, split, split/split_dim, split:1.
  EXPECT_EQ(subgraph->tensors.size(), 5);

  EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[0]->name, "input");
  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[1]->name, "output");
  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
  EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[2]->name, "split");
  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
  EXPECT_EQ(subgraph->tensors[4]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[4]->name, "split:1");
  EXPECT_EQ(subgraph->tensors[4]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[4]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->zero_point[0], -128);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
            BuiltinOperator_SPLIT);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeConvModel1Test : public QuantizeModelTest {
 protected:
  QuantizeConvModel1Test() {
    input_model_ = ReadModel(internal::kConvModelWithMinus128Plus127Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
      TensorType_INT8, TensorType_INT32, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const auto& subgraph = model_.subgraphs[0];

  auto conv_op = subgraph->operators[0].get();
  const int input_tensor_idx = 0;
  const int weights_tensor_idx = 1;
  const int bias_tensor_index = 2;
  const int output_tensor_idx = 0;
  const auto bias_tensor =
      subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
  const auto input_tensor =
      subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
  const auto weights_tensor =
      subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, TensorType_INT32);
  EXPECT_EQ(input_tensor->type, TensorType_INT8);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
  const int out_channel_size = weights_tensor->shape[0];
  ASSERT_TRUE(bias_tensor->quantization);
  ASSERT_TRUE(weights_tensor->quantization);
  const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
  const std::vector<float>& weights_scales =
      weights_tensor->quantization->scale;

  const std::vector<int64_t>& weights_zero_points =
      weights_tensor->quantization->zero_point;

  ASSERT_EQ(bias_scales.size(), out_channel_size);
  ASSERT_EQ(weights_scales.size(), out_channel_size);
  ASSERT_EQ(weights_zero_points.size(), out_channel_size);
  ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
  ASSERT_EQ(output_tensor->quantization->scale.size(), 1);

  for (size_t i = 0; i < out_channel_size; i++) {
    EXPECT_EQ(weights_scales[i], 1);
    EXPECT_EQ(bias_scales[i], 1);
    EXPECT_EQ(weights_zero_points[i], 0);
  }

  EXPECT_EQ(input_tensor->quantization->scale[0], 1);
  EXPECT_EQ(output_tensor->quantization->scale[0], 1);

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  ASSERT_EQ(bias_buffer->data.size(), sizeof(int32_t) * bias_tensor->shape[0]);
  const int32_t* bias_values =
      reinterpret_cast<int32_t*>(bias_buffer->data.data());
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(bias_tensor->buffer);
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  const float eps = 1e-7;
  for (size_t i = 0; i < bias_tensor->shape[0]; i++) {
    const float bias_scale =
        input_tensor->quantization->scale[0] * weights_scales[i];
    auto dequantized_value = bias_values[i] * bias_scale;
    EXPECT_NEAR(dequantized_value, bias_float_buffer[i], eps);
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
  const auto original_weights_buffer =
      readonly_model_->buffers()->Get(weights_tensor->buffer);
  const int8_t* weight_values =
      reinterpret_cast<int8_t*>(weights_buffer->data.data());
  const float* weights_float_buffer =
      reinterpret_cast<const float*>(original_weights_buffer->data()->data());
  ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
            original_weights_buffer->data()->size());
  int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
  for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
    for (size_t j = 0; j < num_values_in_channel; j++) {
      size_t element_idx = channel_idx * out_channel_size + j;
      auto dequantized_value =
          weight_values[element_idx] * weights_scales[channel_idx];
      EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx], eps);
    }
  }

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

class QuantizeConvModel2Test : public QuantizeModelTest,
                               public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModel2Test() {
    tensor_type_ = GetParam();
    bias_type_ = GetBiasTensorType(tensor_type_);
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  TensorType tensor_type_;
  TensorType bias_type_;
};
INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
                         testing::ValuesIn({TensorType_INT8,
                                            TensorType_INT16}));

TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
  auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
                                          tensor_type_, false, tensor_type_,
                                          bias_type_, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto conv_op = subgraph->operators[0].get();
  const int input_tensor_idx = 0;
  const int weights_tensor_idx = 1;
  const int bias_tensor_index = 2;
  const int output_tensor_idx = 0;
  const auto bias_tensor =
      subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
  const auto input_tensor =
      subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
  const auto weights_tensor =
      subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
                                   ? TensorType_INT32
                                   : TensorType_INT64);
  EXPECT_EQ(input_tensor->type, tensor_type_);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
  ASSERT_TRUE(bias_tensor->quantization);
  ASSERT_TRUE(weights_tensor->quantization);
  const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
  const std::vector<float>& weights_scales =
      weights_tensor->quantization->scale;
  const std::vector<int64_t>& weights_zero_points =
      weights_tensor->quantization->zero_point;
  const int out_channel_size = weights_tensor->shape[0];
  ASSERT_EQ(bias_scales.size(), out_channel_size);
  ASSERT_EQ(weights_scales.size(), out_channel_size);
  ASSERT_EQ(weights_zero_points.size(), out_channel_size);
  ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
  ASSERT_EQ(output_tensor->quantization->scale.size(), 1);

  const float eps = 1e-7;

  // Bias scale should be input * per_channel_weight_scale.
  for (size_t i = 0; i < out_channel_size; i++) {
    EXPECT_NEAR(bias_scales[i],
                input_tensor->quantization->scale[0] * weights_scales[i], eps);
  }

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  auto control_size = tensor_type_ == TensorType_INT8
                          ? sizeof(int32_t) * bias_tensor->shape[0]
                          : sizeof(int64_t) * bias_tensor->shape[0];

  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(bias_tensor->buffer);
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  if (tensor_type_ == TensorType_INT8) {
    int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  } else if (tensor_type_ == TensorType_INT16) {
    int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
  const auto original_weights_buffer =
      readonly_model_->buffers()->Get(weights_tensor->buffer);
  const int8_t* weight_values =
      reinterpret_cast<int8_t*>(weights_buffer->data.data());
  const float* weights_float_buffer =
      reinterpret_cast<const float*>(original_weights_buffer->data()->data());
  ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
            original_weights_buffer->data()->size());
  int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
  for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
    for (size_t j = 0; j < num_values_in_channel; j++) {
      size_t element_idx = channel_idx * out_channel_size + j;
      auto scale = weights_scales[channel_idx];
      auto zero_point = weights_zero_points[channel_idx];
      auto dequantized_value = weight_values[element_idx] * scale;
      EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx],
                  scale / 2);
      EXPECT_EQ(zero_point, 0);
    }
  }

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) {
  auto status =
      QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
                                false, tensor_type_, bias_type_,
                                /*disable_per_channel=*/true, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto conv_op = subgraph->operators[0].get();
  const int input_tensor_idx = 0;
  const int weights_tensor_idx = 1;
  const int bias_tensor_index = 2;
  const int output_tensor_idx = 0;
  const auto bias_tensor =
      subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
  const auto input_tensor =
      subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
  const auto weights_tensor =
      subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
                                   ? TensorType_INT32
                                   : TensorType_INT64);
  EXPECT_EQ(input_tensor->type, tensor_type_);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
  ASSERT_TRUE(bias_tensor->quantization);
  ASSERT_TRUE(weights_tensor->quantization);
  const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
  const std::vector<float>& weights_scales =
      weights_tensor->quantization->scale;
  const std::vector<int64_t>& weights_zero_points =
      weights_tensor->quantization->zero_point;

  const int out_channel_size = 1;
  ASSERT_EQ(bias_scales.size(), out_channel_size);
  ASSERT_EQ(weights_scales.size(), out_channel_size);
  ASSERT_EQ(weights_zero_points.size(), out_channel_size);
  ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
  ASSERT_EQ(output_tensor->quantization->scale.size(), 1);

  const float eps = 1e-7;

  // Bias scale should be input * per_channel_weight_scale.
  for (size_t i = 0; i < out_channel_size; i++) {
    EXPECT_NEAR(bias_scales[i],
                input_tensor->quantization->scale[0] * weights_scales[i], eps);
  }

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  auto control_size = tensor_type_ == TensorType_INT8
                          ? sizeof(int32_t) * bias_tensor->shape[0]
                          : sizeof(int64_t) * bias_tensor->shape[0];

  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(bias_tensor->buffer);
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  if (tensor_type_ == TensorType_INT8) {
    int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  } else if (tensor_type_ == TensorType_INT16) {
    int64_t* bias_values = reinterpret_cast<int64_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
  const auto original_weights_buffer =
      readonly_model_->buffers()->Get(weights_tensor->buffer);
  const int8_t* weight_values =
      reinterpret_cast<int8_t*>(weights_buffer->data.data());
  const float* weights_float_buffer =
      reinterpret_cast<const float*>(original_weights_buffer->data()->data());
  ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
            original_weights_buffer->data()->size());
  int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
  for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
    for (size_t j = 0; j < num_values_in_channel; j++) {
      size_t element_idx = channel_idx * out_channel_size + j;
      auto scale = weights_scales[channel_idx];
      auto zero_point = weights_zero_points[channel_idx];
      auto dequantized_value = weight_values[element_idx] * scale;
      EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx],
                  scale / 2);
      EXPECT_EQ(zero_point, 0);
    }
  }

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

class QuantizeSoftmaxTest : public QuantizeModelTest {
 protected:
  QuantizeSoftmaxTest() {
    input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
      TensorType_INT8, TensorType_INT32, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  // Model has a single softmax op.
  ASSERT_EQ(op->opcode_index, 0);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SOFTMAX);

  ASSERT_EQ(op->inputs.size(), 1);
  ASSERT_EQ(op->outputs.size(), 1);
  auto float_graph = readonly_model_->subgraphs()->Get(0);

  // Verify input.
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  // Verify output.
  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);

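  // Int8 softmax output is expected to have a fixed scale of 1/256 and a zero
  // point of -128, reflecting the op's [0, 1) output range.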
  ASSERT_EQ(output_quant_params->scale.size(), 1);
  ASSERT_EQ(output_quant_params->zero_point.size(), 1);
  ASSERT_EQ(1.0f / 256.0f, output_quant_params->scale[0]);
  ASSERT_EQ(-128, output_quant_params->zero_point[0]);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SOFTMAX);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeAvgPoolTest : public QuantizeModelTest {
 protected:
  QuantizeAvgPoolTest() {
    input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
      TensorType_INT8, TensorType_INT32, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  // Model has a single AveragePool op.
  ASSERT_EQ(op->opcode_index, 0);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_AVERAGE_POOL_2D);

  ASSERT_EQ(op->inputs.size(), 1);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->min.size(), 1);
  ASSERT_EQ(output_quant_params->max.size(), 1);

  // Make sure the input min/maxes are propagated to outputs.
  EXPECT_EQ(input_quant_params->min[0], output_quant_params->min[0]);
  EXPECT_EQ(input_quant_params->max[0], output_quant_params->max[0]);
  EXPECT_EQ(input_quant_params->scale[0], output_quant_params->scale[0]);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_AVERAGE_POOL_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
 protected:
  QuantizeMultiInputAddWithReshapeTest() {
    input_model_ = ReadModel(internal::kMultiInputAddWithReshape);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
      TensorType_INT8, TensorType_INT32, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify Reshape is quantized.
  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[1].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_RESHAPE);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->min.size(), 1);
  ASSERT_EQ(output_quant_params->max.size(), 1);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ADD);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
            BuiltinOperator_RESHAPE);
  EXPECT_EQ(model_.operator_codes[1]->version, 1);
}

TEST_F(QuantizeMultiInputAddWithReshapeTest,VerifyAddQuantization)1119 TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
1120   auto status = QuantizeModelAllOperators(
1121       &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
1122       TensorType_INT8, TensorType_INT32, &error_reporter_);
1123   ASSERT_EQ(kTfLiteOk, status);
1124 
1125   // Verify ADD is quantized.
1126   const auto& subgraph = model_.subgraphs[0];
1127   auto op = subgraph->operators[0].get();
1128   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1129             BuiltinOperator_ADD);
1130 
1131   ASSERT_EQ(op->inputs.size(), 2);
1132   ASSERT_EQ(op->outputs.size(), 1);
1133 
1134   auto float_graph = readonly_model_->subgraphs()->Get(0);
1135   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1136             TensorType_FLOAT32);
1137   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
1138             TensorType_FLOAT32);
1139   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1140             TensorType_FLOAT32);
1141 
1142   for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
1143     EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
1144               TensorType_INT8);
1145     auto float_input_quant_params =
1146         float_graph->tensors()->Get(op->inputs[input_idx])->quantization();
1147     auto input_quant_params =
1148         subgraph->tensors[op->inputs[input_idx]]->quantization.get();
1149     VerifyAsymmetricQuantizationScale(*float_input_quant_params,
1150                                       *input_quant_params);
1151   }
1152 
1153   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
1154   auto float_output_quant_params =
1155       float_graph->tensors()->Get(op->outputs[0])->quantization();
1156   auto output_quant_params =
1157       subgraph->tensors[op->outputs[0]]->quantization.get();
1158   ASSERT_EQ(float_output_quant_params->min()->size(), 1);
1159   ASSERT_EQ(float_output_quant_params->max()->size(), 1);
1160   ASSERT_EQ(output_quant_params->min.size(), 1);
1161   ASSERT_EQ(output_quant_params->max.size(), 1);
1162 
1163   // check op and versioning.
1164   EXPECT_EQ(model_.operator_codes.size(), 2);
1165   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1166             BuiltinOperator_ADD);
1167   EXPECT_EQ(model_.operator_codes[0]->version, 2);
1168   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
1169             BuiltinOperator_RESHAPE);
1170   EXPECT_EQ(model_.operator_codes[1]->version, 1);
1171 }
1172 
1173 class QuantizeConstInputTest : public QuantizeModelTest,
1174                                public testing::WithParamInterface<TensorType> {
1175  protected:
QuantizeConstInputTest()1176   QuantizeConstInputTest() {
1177     tensor_type_ = GetParam();
1178     bias_type_ = GetBiasTensorType(tensor_type_);
1179     input_model_ = ReadModel(internal::kConstInputAddModel);
1180     readonly_model_ = input_model_->GetModel();
1181     readonly_model_->UnPackTo(&model_);
1182   }
1183 
1184   TensorType tensor_type_;
1185   TensorType bias_type_;
1186 };
1187 INSTANTIATE_TEST_SUITE_P(QuantizeConstInputTestInst, QuantizeConstInputTest,
1188                          testing::ValuesIn({TensorType_INT8,
1189                                             TensorType_INT16}));
1190 
TEST_P(QuantizeConstInputTest,VerifyConstOpInput)1191 TEST_P(QuantizeConstInputTest, VerifyConstOpInput) {
1192   auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
1193                                           tensor_type_, false, tensor_type_,
1194                                           bias_type_, &error_reporter_);
1195   ASSERT_EQ(kTfLiteOk, status);
1196 
1197   // Verify ConstOp is quantized.
1198   const auto& subgraph = model_.subgraphs[0];
1199   auto op = subgraph->operators[0].get();
1200   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1201             BuiltinOperator_ADD);
1202 
1203   ASSERT_EQ(op->inputs.size(), 2);
1204   ASSERT_EQ(op->outputs.size(), 1);
1205 
1206   auto float_graph = readonly_model_->subgraphs()->Get(0);
1207   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1208             TensorType_FLOAT32);
1209   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1210             TensorType_FLOAT32);
1211 
1212   for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
1213     EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
1214               tensor_type_);
1215   }
1216 
1217   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, tensor_type_);
1218 
1219   // check op and versioning.
1220   EXPECT_EQ(model_.operator_codes.size(), 1);
1221   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1222             BuiltinOperator_ADD);
1223   EXPECT_EQ(model_.operator_codes[0]->version, 2);
1224 
1225   // Check that, for int16 activations, the pot_scale_int16 parameter is set
1226   // to false.
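  // Note: pot_scale_int16 == true would restrict the int16 ADD kernel to
  // power-of-two input/output scales; the general 16x8 quantization exercised
  // here produces arbitrary scales, which is why false is expected.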
1227   if (tensor_type_ == TensorType_INT16) {
1228     EXPECT_EQ(subgraph->operators[0]
1229                   .get()
1230                   ->builtin_options.AsAddOptions()
1231                   ->pot_scale_int16,
1232               false);
1233   }
1234 }
1235 class QuantizeArgMaxTest : public QuantizeModelTest {
1236  protected:
QuantizeArgMaxTest()1237   QuantizeArgMaxTest() {
1238     input_model_ = ReadModel(internal::kModelWithArgMaxOp);
1239     readonly_model_ = input_model_->GetModel();
1240     readonly_model_->UnPackTo(&model_);
1241   }
1242 };
1243 
TEST_F(QuantizeArgMaxTest,VerifyArgMax)1244 TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
1245   auto status = QuantizeModelAllOperators(
1246       &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
1247       TensorType_INT8, TensorType_INT32, &error_reporter_);
1248   ASSERT_EQ(kTfLiteOk, status);
1249 
1250   const auto& subgraph = model_.subgraphs[0];
1251   auto op = subgraph->operators[0].get();
1252   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1253             BuiltinOperator_ARG_MAX);
1254 
1255   ASSERT_EQ(op->inputs.size(), 2);
1256   ASSERT_EQ(op->outputs.size(), 1);
1257 
1258   auto float_graph = readonly_model_->subgraphs()->Get(0);
1259   // Verify ArgMax input is quantized.
1260   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1261             TensorType_FLOAT32);
1262   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
1263 
1264   // The ArgMax axis input should keep its original type.
1265   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
1266             subgraph->tensors[op->inputs[1]].get()->type);
1267 
1268   // The output of ArgMax should still be the same type.
1269   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1270             subgraph->tensors[op->outputs[0]].get()->type);
1271 
1272   // check op and versioning.
1273   EXPECT_EQ(model_.operator_codes.size(), 1);
1274   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1275             BuiltinOperator_ARG_MAX);
1276   EXPECT_EQ(model_.operator_codes[0]->version, 2);
1277 }
1278 
1279 class QuantizeLSTMTest : public QuantizeModelTest {
1280  protected:
QuantizeLSTMTest()1281   QuantizeLSTMTest() {
1282     input_model_ = ReadModel(internal::kLstmCalibrated);
1283     readonly_model_ = input_model_->GetModel();
1284     readonly_model_->UnPackTo(&model_);
1285   }
1286 };
1287 
TEST_F(QuantizeLSTMTest,VerifyLSTM)1288 TEST_F(QuantizeLSTMTest, VerifyLSTM) {
1289   // Quantize model.
1290   auto status = QuantizeModelAllOperators(
1291       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1292       TensorType_INT8, TensorType_INT32, &error_reporter_);
1293   ASSERT_EQ(kTfLiteOk, status);
1294 
1295   // Read expected model.
1296   auto expected_fb_model = ReadModel(internal::kLstmQuantized);
1297   auto expected_read_only_model = expected_fb_model->GetModel();
1298   ModelT expected_model;
1299   expected_read_only_model->UnPackTo(&expected_model);
1300 
1301   ExpectSameModels(model_, expected_model);
1302 }
1303 
1304 class QuantizeLSTM2Test : public QuantizeModelTest {
1305  protected:
QuantizeLSTM2Test()1306   QuantizeLSTM2Test() {
1307     input_model_ = ReadModel(internal::kLstmCalibrated2);
1308     readonly_model_ = input_model_->GetModel();
1309     readonly_model_->UnPackTo(&model_);
1310   }
1311 };
1312 
TEST_F(QuantizeLSTM2Test,VerifyLSTM)1313 TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
1314   // Quantize model.
1315   auto status = QuantizeModelAllOperators(
1316       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1317       TensorType_INT8, TensorType_INT32, &error_reporter_);
1318   ASSERT_EQ(kTfLiteOk, status);
1319 
1320   // Read expected model.
1321   auto expected_fb_model = ReadModel(internal::kLstmQuantized2);
1322   auto expected_read_only_model = expected_fb_model->GetModel();
1323   ModelT expected_model;
1324   expected_read_only_model->UnPackTo(&expected_model);
1325 
1326   ExpectSameModels(model_, expected_model);
1327 }
1328 
1329 class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
1330  protected:
QuantizeUnidirectionalSequenceLSTMTest()1331   QuantizeUnidirectionalSequenceLSTMTest() {
1332     input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated);
1333     readonly_model_ = input_model_->GetModel();
1334     readonly_model_->UnPackTo(&model_);
1335   }
1336 };
1337 
TEST_F(QuantizeUnidirectionalSequenceLSTMTest,VerifyUnidirectionalSequenceLSTM)1338 TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
1339        VerifyUnidirectionalSequenceLSTM) {
1340   // Quantize model.
1341   auto status = QuantizeModelAllOperators(
1342       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1343       TensorType_INT8, TensorType_INT32, &error_reporter_);
1344   ASSERT_EQ(kTfLiteOk, status);
1345 
1346   // Read expected model.
1347   auto expected_fb_model =
1348       ReadModel(internal::kUnidirectionalSequenceLstmQuantized);
1349   auto expected_read_only_model = expected_fb_model->GetModel();
1350   ModelT expected_model;
1351   expected_read_only_model->UnPackTo(&expected_model);
1352 
1353   ExpectSameModels(model_, expected_model);
1354 }
1355 
1356 class QuantizeSVDFTest : public QuantizeModelTest {
1357  protected:
QuantizeSVDFTest()1358   QuantizeSVDFTest() {
1359     input_model_ = ReadModel(internal::kSvdfCalibrated);
1360     readonly_model_ = input_model_->GetModel();
1361     readonly_model_->UnPackTo(&model_);
1362   }
1363 };
1364 
TEST_F(QuantizeSVDFTest,VerifySVDF)1365 TEST_F(QuantizeSVDFTest, VerifySVDF) {
1366   // Quantize model.
1367   auto status = QuantizeModelAllOperators(
1368       &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
1369       TensorType_INT8, TensorType_INT32, &error_reporter_);
1370   ASSERT_EQ(kTfLiteOk, status);
1371 
1372   // Read expected model.
1373   auto expected_fb_model = ReadModel(internal::kSvdfQuantized);
1374   auto expected_read_only_model = expected_fb_model->GetModel();
1375   ModelT expected_model;
1376   expected_read_only_model->UnPackTo(&expected_model);
1377 
1378   // Comparison.
1379   ASSERT_EQ(model_.subgraphs.size(), expected_model.subgraphs.size());
1380   for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
1381        subgraph_idx++) {
1382     const auto graph = model_.subgraphs[subgraph_idx].get();
1383     const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
1384     ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size());
1385     for (size_t i = 0; i < graph->tensors.size(); i++) {
1386       const auto tensor = graph->tensors[i].get();
1387       const auto expected_tensor = expected_graph->tensors[i].get();
1388       EXPECT_EQ(tensor->buffer, expected_tensor->buffer);
1389       EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
1390       EXPECT_EQ(tensor->shape, expected_tensor->shape);
1391       EXPECT_EQ(tensor->name, expected_tensor->name);
1392       EXPECT_EQ(tensor->type, expected_tensor->type);
1393       const auto quantization_params = tensor->quantization.get();
1394       const auto expected_quantization_params =
1395           expected_tensor->quantization.get();
1396       if (quantization_params != nullptr ||
1397           expected_quantization_params != nullptr) {
1398         EXPECT_NE(quantization_params, nullptr);
1399         EXPECT_NE(expected_quantization_params, nullptr);
1400         EXPECT_EQ(quantization_params->scale,
1401                   expected_quantization_params->scale);
1402         EXPECT_EQ(quantization_params->zero_point,
1403                   expected_quantization_params->zero_point);
1404       }
1405     }
1406   }
1407   ASSERT_EQ(model_.buffers.size(), expected_model.buffers.size());
1408   for (size_t buffer_idx = 0; buffer_idx < model_.buffers.size();
1409        ++buffer_idx) {
1410     const auto buffer = model_.buffers[buffer_idx].get()->data;
1411     const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data;
1412     EXPECT_EQ(buffer, expected_buffer);
1413   }
1414 }
1415 
1416 class QuantizeFCTest : public QuantizeModelTest {
1417  protected:
QuantizeFCTest()1418   QuantizeFCTest() {
1419     input_model_ = ReadModel(internal::kModelWithFCOp);
1420     readonly_model_ = input_model_->GetModel();
1421     readonly_model_->UnPackTo(&model_);
1422   }
1423 };
1424 
TEST_F(QuantizeFCTest,VerifyFC)1425 TEST_F(QuantizeFCTest, VerifyFC) {
1426   auto status = QuantizeModelAllOperators(
1427       &builder_, &model_, TensorType_INT8, TensorType_INT8, false,
1428       TensorType_INT8, TensorType_INT32, &error_reporter_);
1429   ASSERT_EQ(kTfLiteOk, status);
1430 
1431   const auto& subgraph = model_.subgraphs[0];
1432   auto op = subgraph->operators[0].get();
1433   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1434             BuiltinOperator_FULLY_CONNECTED);
1435 
1436   ASSERT_EQ(op->inputs.size(), 3);
1437   ASSERT_EQ(op->outputs.size(), 1);
1438 
1439   auto float_graph = readonly_model_->subgraphs()->Get(0);
1440   // Verify FC input and weight are quantized.
1441   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1442             TensorType_FLOAT32);
1443   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
1444   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
1445             TensorType_FLOAT32);
1446   EXPECT_EQ(subgraph->tensors[op->inputs[1]].get()->type, TensorType_INT8);
1447 
1448   // Verify the FC bias is quantized to int32.
1449   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[2])->type(),
1450             TensorType_FLOAT32);
1451   EXPECT_EQ(subgraph->tensors[op->inputs[2]].get()->type, TensorType_INT32);
1452 
1453   // The output of FC should be quantized.
1454   ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
1455             TensorType_FLOAT32);
1456   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
1457 
1458   // check op and versioning.
1459   EXPECT_EQ(model_.operator_codes.size(), 2);
1460   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1461             BuiltinOperator_FULLY_CONNECTED);
1462   EXPECT_EQ(model_.operator_codes[0]->version, 4);
1463   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
1464             BuiltinOperator_RESHAPE);
1465   EXPECT_EQ(model_.operator_codes[1]->version, 1);
1466 }
1467 
1468 class QuantizeCustomOpTest
1469     : public QuantizeModelTest,
1470       public ::testing::WithParamInterface<tflite::TensorType> {
1471  protected:
QuantizeCustomOpTest()1472   QuantizeCustomOpTest() {
1473     tensor_type_ = GetParam();
1474     bias_type_ = GetBiasTensorType(tensor_type_);
1475     input_model_ = ReadModel(internal::kModelMixed);
1476     readonly_model_ = input_model_->GetModel();
1477     readonly_model_->UnPackTo(&model_);
1478   }
1479 
1480   TensorType tensor_type_;
1481   TensorType bias_type_;
1482 };
1483 
TEST_P(QuantizeCustomOpTest,VerifyMixedQuantization)1484 TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
1485   auto status = QuantizeModelAllOperators(
1486       &builder_, &model_, tensor_type_, tensor_type_,
1487       /*allow_float=*/true, tensor_type_, bias_type_, &error_reporter_);
1488   ASSERT_EQ(kTfLiteOk, status);
1489   const auto& subgraph = model_.subgraphs[0];
1490   auto float_graph = readonly_model_->subgraphs()->Get(0);
1491   // The original model is reshape->custom->custom->squeeze.
1492   ASSERT_EQ(float_graph->operators()->size(), 4);
1493   // The resulting model should be:
1494   // reshape->dequantize->custom->custom->quantize->squeeze.
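  // Custom ops cannot be quantized; with allow_float=true the quantizer keeps
  // them in float and wraps the float region with DEQUANTIZE/QUANTIZE ops.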
1495   ASSERT_EQ(subgraph->operators.size(), 6);
1496   const std::vector<BuiltinOperator> op_codes = {
1497       BuiltinOperator_RESHAPE,  BuiltinOperator_DEQUANTIZE,
1498       BuiltinOperator_CUSTOM,   BuiltinOperator_CUSTOM,
1499       BuiltinOperator_QUANTIZE, BuiltinOperator_SQUEEZE};
1500   const std::vector<TensorType> op_input_types = {
1501       GetParam(),         GetParam(),         TensorType_FLOAT32,
1502       TensorType_FLOAT32, TensorType_FLOAT32, GetParam()};
1503   for (int i = 0; i < subgraph->operators.size(); ++i) {
1504     OperatorT* op = subgraph->operators[i].get();
1505     ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1506               op_codes[i]);
1507     ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]);
1508   }
1509 }
1510 
1511 INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest,
1512                          ::testing::Values(TensorType_INT8, TensorType_INT16));
1513 
1514 class QuantizeOp16x8Test : public QuantizeModelTest {
1515  protected:
QuantizeOp16x8Test()1516   QuantizeOp16x8Test() {
1517     input_model_ = ReadModel(internal::kModelMixed16x8);
1518     readonly_model_ = input_model_->GetModel();
1519     readonly_model_->UnPackTo(&model_);
1520   }
1521 };
1522 
TEST_F(QuantizeOp16x8Test,VerifyMixedQuantization16x8)1523 TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
1524   auto status = QuantizeModelAllOperators(
1525       &builder_, &model_, TensorType_INT16, TensorType_FLOAT32,
1526       /*allow_float=*/true, TensorType_INT16, TensorType_INT64,
1527       &error_reporter_);
1528   ASSERT_EQ(kTfLiteOk, status);
1529   const auto& subgraph = model_.subgraphs[0];
1530   auto float_graph = readonly_model_->subgraphs()->Get(0);
1531   // The original model is conv_2d->log_softmax.
1532   ASSERT_EQ(float_graph->operators()->size(), 2);
1533   // The resulting model should be:
1534   // conv_2d->dequantize->log_softmax
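  // (Presumably LOG_SOFTMAX is not supported in the 16x8 scheme here, so it is
  // left in float behind an inserted DEQUANTIZE while CONV_2D is quantized.)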
1535   ASSERT_EQ(subgraph->operators.size(), 3);
1536   const std::vector<BuiltinOperator> op_codes = {BuiltinOperator_CONV_2D,
1537                                                  BuiltinOperator_DEQUANTIZE,
1538                                                  BuiltinOperator_LOG_SOFTMAX};
1539   const std::vector<TensorType> op_input_types = {
1540       TensorType_INT16, TensorType_INT16, TensorType_FLOAT32};
1541   for (int i = 0; i < subgraph->operators.size(); ++i) {
1542     OperatorT* op = subgraph->operators[i].get();
1543     ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1544               op_codes[i]);
1545     ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]);
1546   }
1547 }
1548 
1549 class QuantizePackTest : public QuantizeModelTest {
1550  protected:
QuantizePackTest()1551   QuantizePackTest() {
1552     input_model_ = ReadModel(internal::kModelPack);
1553     readonly_model_ = input_model_->GetModel();
1554     readonly_model_->UnPackTo(&model_);
1555   }
1556 };
1557 
TEST_F(QuantizePackTest,VerifyPack)1558 TEST_F(QuantizePackTest, VerifyPack) {
1559   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1560 
1561   ASSERT_EQ(kTfLiteOk, status);
1562 
1563   const auto subgraph = model_.subgraphs[0].get();
1564 
1565   // The model should only have 3 inputs and 1 output.
1566   EXPECT_EQ(subgraph->inputs.size(), 3);
1567   EXPECT_EQ(subgraph->outputs.size(), 1);
1568 
1569   const auto& op1 = subgraph->operators[1].get();
1570   const auto& op2 = subgraph->operators[2].get();
1571   const auto& op3 = subgraph->operators[3].get();
1572   const auto& op4 = subgraph->operators[4].get();
1573 
1574   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op1->opcode_index].get()),
1575             BuiltinOperator_QUANTIZE);
1576   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op2->opcode_index].get()),
1577             BuiltinOperator_QUANTIZE);
1578   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op3->opcode_index].get()),
1579             BuiltinOperator_PACK);
1580   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op4->opcode_index].get()),
1581             BuiltinOperator_DEQUANTIZE);
1582 
1583   const auto& pack_input0 = subgraph->tensors[op3->inputs[0]].get();
1584   const auto& pack_input1 = subgraph->tensors[op3->inputs[1]].get();
1585   const auto& pack_input2 = subgraph->tensors[op3->inputs[2]].get();
1586 
1587   const auto& pack_output = subgraph->tensors[op3->outputs[0]].get();
1588 
1589   // Check quantization parameters for input and output.
1590   EXPECT_FLOAT_EQ(pack_input0->quantization->scale[0],
1591                   pack_input1->quantization->scale[0]);
1592   EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
1593                   pack_input2->quantization->scale[0]);
1594   EXPECT_FLOAT_EQ(pack_input0->quantization->zero_point[0],
1595                   pack_input1->quantization->zero_point[0]);
1596   EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
1597                   pack_input2->quantization->zero_point[0]);
1598 
1599   EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
1600                   pack_output->quantization->scale[0]);
1601   EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
1602                   pack_output->quantization->zero_point[0]);
1603 
1604   // Check type of input and output.
1605   EXPECT_EQ(pack_output->type, TensorType_INT8);
1606   EXPECT_EQ(pack_input0->type, TensorType_INT8);
1607   EXPECT_EQ(pack_input1->type, TensorType_INT8);
1608   EXPECT_EQ(pack_input2->type, TensorType_INT8);
1609 }
1610 
1611 class QuantizeMinimumMaximumTest
1612     : public QuantizeModelTest,
1613       public testing::WithParamInterface<const char*> {
1614  protected:
QuantizeMinimumMaximumTest()1615   QuantizeMinimumMaximumTest() {
1616     input_model_ = ReadModel(GetParam());
1617     readonly_model_ = input_model_->GetModel();
1618     readonly_model_->UnPackTo(&model_);
1619   }
1620 };
1621 
TEST_P(QuantizeMinimumMaximumTest,VerifyMinimumMaximum)1622 TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
1623   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1624   ASSERT_EQ(kTfLiteOk, status);
1625   const auto& subgraph = model_.subgraphs[0];
1626 
1627   // Check that the first op is Quantize and the last is Dequantize.
1628   const auto& quant_op = subgraph->operators[0];
1629   const auto& dequant_op = subgraph->operators[subgraph->operators.size() - 1];
1630   const int32_t quant_idx = quant_op->opcode_index;
1631   const int32_t dequant_idx = dequant_op->opcode_index;
1632   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()),
1633             BuiltinOperator_QUANTIZE);
1634   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()),
1635             BuiltinOperator_DEQUANTIZE);
1636   const auto& requant1 = subgraph->operators[1].get();
1637   // Check that the second op is a requantize (QUANTIZE) op.
1638   auto requant1_builtin_code =
1639       GetBuiltinCode(model_.operator_codes[requant1->opcode_index].get());
1640   ASSERT_TRUE(requant1_builtin_code == tflite::BuiltinOperator_QUANTIZE);
1641 
1642   // The constant input is quantized directly rather than adding a requant op.
1643   const auto& op = subgraph->operators[2].get();
1644 
1645   // Check that we have MINIMUM or MAXIMUM operator.
1646   auto op_builtin_code =
1647       GetBuiltinCode(model_.operator_codes[op->opcode_index].get());
1648   ASSERT_TRUE(op_builtin_code == tflite::BuiltinOperator_MINIMUM ||
1649               op_builtin_code == tflite::BuiltinOperator_MAXIMUM);
1650 
1651   // Check that we have two inputs and one output.
1652   ASSERT_EQ(op->inputs.size(), 2);
1653   ASSERT_EQ(op->outputs.size(), 1);
1654 
1655   // Check that all tensors are quantized.
1656   auto output = subgraph->tensors[op->outputs[0]].get();
1657   auto input1 = subgraph->tensors[op->inputs[0]].get();
1658   auto input2 = subgraph->tensors[op->inputs[1]].get();
1659 
1660   EXPECT_EQ(output->type, TensorType_INT8);
1661   EXPECT_EQ(input1->type, TensorType_INT8);
1662   EXPECT_EQ(input2->type, TensorType_INT8);
1663 
1664   // Check that the quantization params of the minimum/maximum inputs match
1665   // after requantization.
1666   EXPECT_EQ(input1->quantization->scale, input2->quantization->scale);
1667   EXPECT_EQ(input1->quantization->zero_point, input2->quantization->zero_point);
1668 
1669   // Check the input quantization params match the output ones.
1670   EXPECT_EQ(output->quantization->scale, input1->quantization->scale);
1671   EXPECT_EQ(output->quantization->zero_point, input1->quantization->zero_point);
1672   EXPECT_EQ(output->quantization->scale, input2->quantization->scale);
1673   EXPECT_EQ(output->quantization->zero_point, input2->quantization->zero_point);
1674 
1675   EXPECT_EQ(subgraph->tensors.size(), 6);
1676 
1677   EXPECT_EQ(subgraph->tensors[0]->name, "input_int8");
1678   EXPECT_EQ(subgraph->tensors[1]->name, "output_int8");
1679   EXPECT_EQ(subgraph->tensors[2]->name, "output/y");
1680   EXPECT_EQ(subgraph->tensors[3]->name, "input_requantized");
1681   EXPECT_EQ(subgraph->tensors[4]->name, "input");
1682   EXPECT_EQ(subgraph->tensors[5]->name, "output");
1683 }
1684 
1685 INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
1686                          testing::ValuesIn({internal::kModelWithMinimumOp,
1687                                             internal::kModelWithMaximumOp}));
1688 
1689 class QuantizeUnpackTest : public QuantizeModelTest {
1690  protected:
QuantizeUnpackTest()1691   QuantizeUnpackTest() {
1692     input_model_ = ReadModel(internal::kModelWithUnpack);
1693     readonly_model_ = input_model_->GetModel();
1694     readonly_model_->UnPackTo(&model_);
1695   }
1696 };
TEST_F(QuantizeUnpackTest,VerifyUnpack)1697 TEST_F(QuantizeUnpackTest, VerifyUnpack) {
1698   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1699 
1700   ASSERT_EQ(kTfLiteOk, status);
1701 
1702   const auto subgraph = model_.subgraphs[0].get();
1703   auto op = subgraph->operators[1].get();
1704 
1705   auto float_graph = readonly_model_->subgraphs()->Get(0);
1706 
1707   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1708             BuiltinOperator_UNPACK);
1709 
1710   // Get unpack input and output tensors
1711   auto unpack_input = subgraph->tensors[op->inputs[0]].get();
1712   auto unpack_output_0 = subgraph->tensors[op->outputs[0]].get();
1713   auto unpack_output_1 = subgraph->tensors[op->outputs[1]].get();
1714 
1715   // Verify Unpack input is quantized.
1716   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1717             TensorType_FLOAT32);
1718   EXPECT_EQ(unpack_input->type, TensorType_INT8);
1719 
1720   // The model should only have one input and two outputs.
1721   EXPECT_EQ(subgraph->inputs.size(), 1);
1722   EXPECT_EQ(subgraph->outputs.size(), 2);
1723 
1724   // Ensure quantization parameters before and after unpack
1725   // are preserved after quantization for all outputs of
1726   // unpack.
1727   EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1728                   unpack_output_0->quantization->scale[0]);
1729   EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1730                   unpack_output_1->quantization->scale[0]);
1731   EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1732                   unpack_output_0->quantization->zero_point[0]);
1733   EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1734                   unpack_output_1->quantization->zero_point[0]);
1735 }
1736 
1737 class QuantizeTransposeTest : public QuantizeModelTest {
1738  protected:
QuantizeTransposeTest()1739   QuantizeTransposeTest() {
1740     input_model_ = ReadModel(internal::kModelWithTranspose);
1741     readonly_model_ = input_model_->GetModel();
1742     readonly_model_->UnPackTo(&model_);
1743   }
1744 };
1745 
TEST_F(QuantizeTransposeTest,VerifyTranspose)1746 TEST_F(QuantizeTransposeTest, VerifyTranspose) {
1747   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1748 
1749   ASSERT_EQ(kTfLiteOk, status);
1750 
1751   const auto subgraph = model_.subgraphs[0].get();
1752   auto op = subgraph->operators[1].get();
1753 
1754   auto float_graph = readonly_model_->subgraphs()->Get(0);
1755 
1756   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1757             BuiltinOperator_TRANSPOSE);
1758 
1759   // The model should only have one input and one output.
1760   EXPECT_EQ(subgraph->inputs.size(), 1);
1761   EXPECT_EQ(subgraph->outputs.size(), 1);
1762 
1763   // Get transpose input and output tensors
1764   auto transpose_input = subgraph->tensors[op->inputs[0]].get();
1765   auto transpose_output = subgraph->tensors[op->outputs[0]].get();
1766 
1767   // Verify transpose input is quantized.
1768   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1769             TensorType_FLOAT32);
1770   EXPECT_EQ(transpose_input->type, TensorType_INT8);
1771 
1772   // Ensure the quantization parameters before and after transpose are
1773   // preserved after quantization, i.e. the transpose output keeps the
1774   // input's scale and zero point.
1775   EXPECT_FLOAT_EQ(transpose_input->quantization->scale[0],
1776                   transpose_output->quantization->scale[0]);
1777   EXPECT_EQ(transpose_input->quantization->zero_point[0],
1778             transpose_output->quantization->zero_point[0]);
1779 }
1780 
1781 class QuantizeQatTest : public QuantizeModelTest {
1782  protected:
QuantizeQatTest()1783   QuantizeQatTest() {
1784     input_model_ = ReadModel(internal::kQatModelWithFc);
1785     readonly_model_ = input_model_->GetModel();
1786     readonly_model_->UnPackTo(&model_);
1787   }
1788 };
1789 
TEST_F(QuantizeQatTest,VerifySingleQuantize)1790 TEST_F(QuantizeQatTest, VerifySingleQuantize) {
1791   auto status = QuantizeModelAllOperators(
1792       &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false,
1793       TensorType_INT8, TensorType_INT32, &error_reporter_);
1794   ASSERT_EQ(kTfLiteOk, status);
1795 
1796   const auto& subgraph = model_.subgraphs[0];
1797   auto op = subgraph->operators[0].get();
1798   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1799             BuiltinOperator_QUANTIZE);
1800   op = subgraph->operators[1].get();
1801   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1802             BuiltinOperator_RESHAPE);
1803   op = subgraph->operators[2].get();
1804   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1805             BuiltinOperator_FULLY_CONNECTED);
1806 
1807   ASSERT_EQ(op->inputs.size(), 3);
1808   ASSERT_EQ(op->outputs.size(), 1);
1809 
1810   auto qat_graph = readonly_model_->subgraphs()->Get(0);
1811   // Verify FC input and weight are quantized.
1812   ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[0])->type(), TensorType_INT8);
1813   EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
1814   ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[1])->type(), TensorType_INT8);
1815   EXPECT_EQ(subgraph->tensors[op->inputs[1]].get()->type, TensorType_INT8);
1816 
1817   // Verify the FC bias is quantized to int32.
1818   ASSERT_EQ(qat_graph->tensors()->Get(op->inputs[2])->type(), TensorType_INT32);
1819   EXPECT_EQ(subgraph->tensors[op->inputs[2]].get()->type, TensorType_INT32);
1820 
1821   // The output of FC should be quantized.
1822   ASSERT_EQ(qat_graph->tensors()->Get(op->outputs[0])->type(), TensorType_INT8);
1823   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
1824 
1825   // check op and versioning.
1826   EXPECT_EQ(model_.operator_codes.size(), 4);
1827   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
1828             BuiltinOperator_QUANTIZE);
1829   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
1830             BuiltinOperator_RESHAPE);
1831   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[2].get()),
1832             BuiltinOperator_FULLY_CONNECTED);
1833   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[3].get()),
1834             BuiltinOperator_DEQUANTIZE);
1835   EXPECT_EQ(model_.operator_codes[1]->version, 1);
1836   EXPECT_EQ(model_.operator_codes[2]->version, 4);
1837 }
1838 
1839 class QuantizeBroadcastToModelTest
1840     : public QuantizeModelTest,
1841       public testing::WithParamInterface<TensorType> {
1842  protected:
QuantizeBroadcastToModelTest()1843   QuantizeBroadcastToModelTest() {
1844     tensor_type_ = GetParam();
1845     bias_type_ = GetBiasTensorType(tensor_type_);
1846     input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
1847     readonly_model_ = input_model_->GetModel();
1848     readonly_model_->UnPackTo(&model_);
1849   }
1850   TensorType tensor_type_;
1851   TensorType bias_type_;
1852 };
1853 
1854 INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst,
1855                          QuantizeBroadcastToModelTest,
1856                          testing::ValuesIn({TensorType_INT8,
1857                                             TensorType_INT16}));
1858 
TEST_P(QuantizeBroadcastToModelTest,VerifyBroadcastToQuantization)1859 TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
1860   auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
1861                                           tensor_type_, false, tensor_type_,
1862                                           bias_type_, &error_reporter_);
1863   EXPECT_EQ(status, kTfLiteOk);
1864 
1865   // There is only one subgraph.
1866   const int32_t subgraph_idx = 0;
1867   const auto& subgraph = model_.subgraphs[subgraph_idx];
1868   const auto& readonly_subgraph =
1869       readonly_model_->subgraphs()->Get(subgraph_idx);
1870 
1871   // There should be a single broadcast_to op.
1872   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1873   EXPECT_EQ(subgraph->operators.size(), 1);
1874   const auto& broadcast_to = subgraph->operators[0];
1875   EXPECT_EQ(model_.operator_codes[broadcast_to->opcode_index]->builtin_code,
1876             BuiltinOperator_BROADCAST_TO);
1877 
1878   // There should be 3 tensors: input, output, and BroadcastTo/shape.
1879   EXPECT_EQ(subgraph->tensors.size(), 3);
1880 
1881   // Input Tensor
1882   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
1883   EXPECT_EQ(subgraph->tensors[0]->name, "input_1");
1884   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
1885   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
1886 
1887   // Output Tensor. The name given in the generated
1888   // .bin test file is 'Identity' and should be preserved.
1889   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
1890   EXPECT_EQ(subgraph->tensors[2]->name, "Identity");
1891   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
1892   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
1893 
1894   // The BroadcastTo shape is of type INT32 and should not be quantized.
1895   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
1896   EXPECT_EQ(subgraph->tensors[1]->name,
1897             "model/tf.broadcast_to/BroadcastTo/shape");
1898   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1899   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1900 
1901   // check op and versioning.
1902   EXPECT_EQ(model_.operator_codes.size(), 1);
1903   EXPECT_EQ(model_.operator_codes[0]->builtin_code,
1904             BuiltinOperator_BROADCAST_TO);
1905   EXPECT_EQ(model_.operator_codes[0]->version, 3);
1906 }
1907 
1908 class QuantizeGatherNDModelTest
1909     : public QuantizeModelTest,
1910       public testing::WithParamInterface<TensorType> {
1911  protected:
QuantizeGatherNDModelTest()1912   QuantizeGatherNDModelTest() {
1913     tensor_type_ = GetParam();
1914     bias_type_ = GetBiasTensorType(tensor_type_);
1915     input_model_ = ReadModel(internal::kModelWithGatherNDOp);
1916     readonly_model_ = input_model_->GetModel();
1917     readonly_model_->UnPackTo(&model_);
1918   }
1919 
1920   TensorType tensor_type_;
1921   TensorType bias_type_;
1922 };
1923 
1924 INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst,
1925                          QuantizeGatherNDModelTest,
1926                          testing::ValuesIn({TensorType_INT8,
1927                                             TensorType_INT16}));
1928 
TEST_P(QuantizeGatherNDModelTest,QuantizeGatherND)1929 TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
1930   auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
1931                                           tensor_type_, false, tensor_type_,
1932                                           bias_type_, &error_reporter_);
1933   EXPECT_EQ(status, kTfLiteOk);
1934 
1935   // There is only one subgraph.
1936   const int32_t subgraph_idx = 0;
1937   const auto& subgraph = model_.subgraphs[subgraph_idx];
1938   const auto& readonly_subgraph =
1939       readonly_model_->subgraphs()->Get(subgraph_idx);
1940 
1941   // There should be a single gather_nd op.
1942   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1943   EXPECT_EQ(subgraph->operators.size(), 1);
1944   const auto& gather_nd = subgraph->operators[0];
1945   EXPECT_EQ(model_.operator_codes[gather_nd->opcode_index]->builtin_code,
1946             BuiltinOperator_GATHER_ND);
1947 
1948   // There should be 3 tensors: input, output, and indices.
1949   EXPECT_EQ(subgraph->tensors.size(), 3);
1950 
1951   // Input Tensor
1952   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
1953   EXPECT_EQ(subgraph->tensors[0]->name, "input");
1954   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
1955   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
1956 
1957   // Output Tensor
1958   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
1959   EXPECT_EQ(subgraph->tensors[2]->name, "output");
1960   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
1961   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
1962 
1963   // The gather indices are of type INT32 and should not be quantized
1964   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
1965   EXPECT_EQ(subgraph->tensors[1]->name, "indices");
1966   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1967   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1968 
1969   // Check op and versioning.
1970   EXPECT_EQ(model_.operator_codes.size(), 1);
1971   EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_GATHER_ND);
1972   EXPECT_EQ(model_.operator_codes[0]->version, 3);
1973 }
1974 
1975 class QuantizeWhereModelTest : public QuantizeModelTest {
1976  protected:
QuantizeWhereModelTest()1977   QuantizeWhereModelTest() {
1978     input_model_ = ReadModel(internal::kModelWithWhereOp);
1979     readonly_model_ = input_model_->GetModel();
1980     readonly_model_->UnPackTo(&model_);
1981   }
1982 };
1983 
TEST_F(QuantizeWhereModelTest,QuantizeWhere)1984 TEST_F(QuantizeWhereModelTest, QuantizeWhere) {
1985   // The Where operator takes a BOOL tensor as input
1986   // and outputs INT64 indices; neither of these
1987   // should be quantized.
1988   auto status = QuantizeModel(&builder_, &model_, TensorType_BOOL,
1989                               TensorType_INT64, &error_reporter_);
1990   EXPECT_EQ(status, kTfLiteOk);
1991 
1992   // There is only one subgraph.
1993   const int32_t subgraph_idx = 0;
1994   const auto& subgraph = model_.subgraphs[subgraph_idx];
1995   const auto& readonly_subgraph =
1996       readonly_model_->subgraphs()->Get(subgraph_idx);
1997 
1998   // There should be a single where op.
1999   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
2000   EXPECT_EQ(subgraph->operators.size(), 1);
2001   const auto& where = subgraph->operators[0];
2002   EXPECT_EQ(model_.operator_codes[where->opcode_index]->builtin_code,
2003             BuiltinOperator_WHERE);
2004 
2005   // There should be 2 tensors: input and output.
2006   EXPECT_EQ(subgraph->tensors.size(), 2);
2007 
2008   // Testing input tensor type and ensuring it
2009   // was not quantized
2010   EXPECT_EQ(subgraph->tensors[0]->type, TensorType_BOOL);
2011   EXPECT_EQ(subgraph->tensors[0]->name, "input");
2012   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 0);
2013   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 0);
2014 
2015   // Testing output (indices) tensor type and ensuring it
2016   // was not quantized
2017   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT64);
2018   EXPECT_EQ(subgraph->tensors[1]->name, "indices");
2019   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
2020   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
2021 
2022   // check op and versioning.
2023   EXPECT_EQ(model_.operator_codes.size(), 1);
2024   EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_WHERE);
2025   EXPECT_EQ(model_.operator_codes[0]->version, 1);
2026 }
2027 
2028 enum struct ModifyRangeType {
2029   kNone = 0,
2030   kAll = 1,
2031   kReadOnly = 2,
2032   kAssignOnly = 3,
2033 };
2034 
2035 struct TestType {
2036   TensorType tensor_type;
2037   ModifyRangeType modify_range;
2038 };
2039 
2040 struct BiasTestType {
2041   TensorType tensor_type;
2042   TensorType bias_type;
2043   bool is_valid_bias_type;
2044 };
2045 
2046 class QuantizeResourcesModelTest
2047     : public QuantizeModelTest,
2048       public testing::WithParamInterface<TestType> {
2049  protected:
QuantizeResourcesModelTest()2050   QuantizeResourcesModelTest() {
2051     TestType obj = GetParam();
2052     tensor_type_ = obj.tensor_type;
2053     modify_range_ = obj.modify_range;
2054     bias_type_ = GetBiasTensorType(tensor_type_);
2055     input_model_ = ReadModel(internal::kModelWithResourceVarsCalibrated);
2056     readonly_model_ = input_model_->GetModel();
2057     readonly_model_->UnPackTo(&model_, nullptr);
2058     if (modify_range_ != ModifyRangeType::kNone) {
2059       ModifyRange(&model_);
2060     }
2061   }
ModifyRange(ModelT * model)2062   void ModifyRange(ModelT* model) {
2063     // Modify ranges to test when min/max of the primary subgraph variable
2064     // is smaller than the initializer subgraph.
2065     const bool do_read = (modify_range_ == ModifyRangeType::kAll ||
2066                           modify_range_ == ModifyRangeType::kReadOnly);
2067     const bool do_assign = (modify_range_ == ModifyRangeType::kAll ||
2068                             modify_range_ == ModifyRangeType::kAssignOnly);
2069     SubGraphT* subgraph = model->subgraphs[0].get();
2070     for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
2071       OperatorT* op = subgraph->operators[op_idx].get();
2072       const BuiltinOperator op_code =
2073           GetBuiltinCode(model_.operator_codes[op->opcode_index].get());
2074       TensorT* var_tensor;
2075       if (op_code == BuiltinOperator_ASSIGN_VARIABLE && do_assign) {
2076         var_tensor = subgraph->tensors[op->inputs[1]].get();
2077       } else if (op_code == BuiltinOperator_READ_VARIABLE && do_read) {
2078         var_tensor = subgraph->tensors[op->outputs[0]].get();
2079       } else {
2080         continue;
2081       }
2082       // This value is lower than the initial values, so should be replaced
2083       var_tensor->quantization->max[0] = 12.5;
2084     }
2085   }
2086   TensorType tensor_type_;
2087   TensorType bias_type_;
2088   ModifyRangeType modify_range_ = ModifyRangeType::kAll;
2089 };
2090 
2091 INSTANTIATE_TEST_SUITE_P(QuantizeResourcesModelTest, QuantizeResourcesModelTest,
2092                          testing::ValuesIn<TestType>(
2093                              {{TensorType_INT8, ModifyRangeType::kNone},
2094                               {TensorType_INT8, ModifyRangeType::kAll},
2095                               {TensorType_INT8, ModifyRangeType::kReadOnly},
2096                               {TensorType_INT8, ModifyRangeType::kAssignOnly},
2097                               {TensorType_INT16, ModifyRangeType::kNone},
2098                               {TensorType_INT16, ModifyRangeType::kAll},
2099                               {TensorType_INT16, ModifyRangeType::kReadOnly},
2100                               {TensorType_INT16,
2101                                ModifyRangeType::kAssignOnly}}));
2102 
TEST_P(QuantizeResourcesModelTest,GraphIsFullyQuantized)2103 TEST_P(QuantizeResourcesModelTest, GraphIsFullyQuantized) {
2104   auto status = QuantizeModelAllOperators(
2105       &builder_, &model_, tensor_type_, tensor_type_,
2106       /*allow_float*/ false, tensor_type_, bias_type_, &error_reporter_);
2107   EXPECT_EQ(status, kTfLiteOk);
2108   std::vector<QuantizationParametersT*> quant_params;
2109   const float quant_eps = tensor_type_ == TensorType_INT8 ? 1e-1 : 1e-2;
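  // int8 has a much coarser scale than int16, so allow a larger error when
  // comparing dequantized values against the original floats below.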
2110   for (const auto& subgraph : model_.subgraphs) {
2111     for (const auto& tensor : subgraph->tensors) {
2112       if (tensor_type_ == TensorType_INT8) {
2113         EXPECT_TRUE(
2114             tensor->type == TensorType_RESOURCE ||  // resource
2115             tensor->type == TensorType_INT32 ||     // bias and gather indices
2116             tensor->type == TensorType_INT8);       // weights and activations
2117       } else if (tensor_type_ == TensorType_INT16) {
2118         EXPECT_TRUE(tensor->type == TensorType_RESOURCE ||  // resource
2119                     tensor->type == TensorType_INT64 ||     // bias
2120                     tensor->type == TensorType_INT32 ||     // gather indices
2121                     tensor->type == TensorType_INT16 ||     // activations
2122                     tensor->type == TensorType_INT8);       // weights
2123       }
2124     }
2125     for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
2126       OperatorT* op = subgraph->operators[op_idx].get();
2127       const BuiltinOperator op_code =
2128           GetBuiltinCode(model_.operator_codes[op->opcode_index].get());
2129       if (op_code == BuiltinOperator_ASSIGN_VARIABLE) {
2130         TensorT* var_tensor = subgraph->tensors[op->inputs[1]].get();
2131         quant_params.push_back(var_tensor->quantization.get());
2132         if (model_.buffers[var_tensor->buffer] &&
2133             !model_.buffers[var_tensor->buffer]->data.empty()) {
2134           const BufferT* buffer = model_.buffers[var_tensor->buffer].get();
2135           const int num_elements = 25;
2136           const int expected_buffer_size = tensor_type_ == TensorType_INT8
2137                                                ? num_elements * sizeof(int8_t)
2138                                                : num_elements * sizeof(int16_t);
2139           EXPECT_EQ(buffer->data.size(), expected_buffer_size);
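          // Dequantize each stored element as scale * (q - zero_point); the
          // int16 path is symmetric, so its zero point is implicitly 0.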
2140           for (int i = 0; i < num_elements; ++i) {
2141             float dequantized = 0;
2142             if (tensor_type_ == TensorType_INT8) {
2143               auto data = reinterpret_cast<const int8_t*>(buffer->data.data());
2144               const int zero_point = var_tensor->quantization->zero_point[0];
2145               dequantized =
2146                   (data[i] - zero_point) * var_tensor->quantization->scale[0];
2147             } else if (tensor_type_ == TensorType_INT16) {
2148               auto data = reinterpret_cast<const int16_t*>(buffer->data.data());
2149               dequantized = data[i] * var_tensor->quantization->scale[0];
2150             }
2151             EXPECT_NEAR(dequantized, 25.0 - i, quant_eps);
2152           }
2153         }
2154       } else if (op_code == BuiltinOperator_READ_VARIABLE) {
2155         TensorT* var_tensor = subgraph->tensors[op->outputs[0]].get();
2156         quant_params.push_back(var_tensor->quantization.get());
2157       }
2158 
2159       // Test that the bias was duplicated.
2160       if (op_code == BuiltinOperator_FULLY_CONNECTED) {
2161         TensorT* bias = subgraph->tensors[op->inputs[2]].get();
2162         EXPECT_EQ(bias->name, "Const_duplicate_1");
2163         if (tensor_type_ == TensorType_INT8) {
2164           EXPECT_EQ(bias->type, TensorType_INT32);
2165         } else if (tensor_type_ == TensorType_INT16) {
2166           EXPECT_EQ(bias->type, TensorType_INT64);
2167         }
2168       }
2169     }
2170   }
2171   EXPECT_EQ(quant_params.size(), 4);
2172   QuantizationParametersT* expected_quant_param = quant_params[0];
2173   EXPECT_EQ(expected_quant_param->scale.size(), 1);
2174   float expected_scale =
2175       tensor_type_ == TensorType_INT8 ? 0.1960605f : 0.0015258f;
2176   if (modify_range_ == ModifyRangeType::kAll) {
2177     expected_scale = tensor_type_ == TensorType_INT8 ? 0.0980392f : 0.0007629f;
2178   }
2179   const float eps = 1e-7;
2180   EXPECT_NEAR(expected_quant_param->scale[0], expected_scale, eps);
2181   for (int i = 1; i < quant_params.size(); ++i) {
2182     QuantizationParametersT* test_param = quant_params[i];
2183     EXPECT_EQ(test_param->scale, expected_quant_param->scale);
2184     EXPECT_EQ(test_param->zero_point, expected_quant_param->zero_point);
2185     EXPECT_EQ(test_param->min, expected_quant_param->min);
2186     EXPECT_EQ(test_param->max, expected_quant_param->max);
2187   }
2188 }
2189 
2190 class QuantizeConcatConstModelTest
2191     : public QuantizeModelTest,
2192       public testing::WithParamInterface<TensorType> {
2193  protected:
QuantizeConcatConstModelTest()2194   QuantizeConcatConstModelTest() {
2195     input_model_ = ReadModel(internal::kFloatConcatMax5Max10Max10);
2196     readonly_model_ = input_model_->GetModel();
2197     readonly_model_->UnPackTo(&model_);
2198     // Make one of the inputs constant.
2199     MakeInputConstant(&model_);
2200   }
2201 
SetUp()2202   void SetUp() override {
2203     tensor_type_ = GetParam();
2204     bias_type_ = GetBiasTensorType(tensor_type_);
2205   }
2206 
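  // Promotes the last graph input to slot 0, drops the extra input slot, and
  // turns the tensor that used to be the first input into a constant
  // ("const_input0") backed by a new buffer holding {0.0f, 5.0f}.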
MakeInputConstant(tflite::ModelT * model)2207   void MakeInputConstant(tflite::ModelT* model) {
2208     auto& subgraph = model->subgraphs[0];
2209     const int tensor_id = subgraph->inputs.back();
2210     int replace_tensor_id = subgraph->inputs[0];
2211     subgraph->inputs[0] = tensor_id;
2212     subgraph->inputs.pop_back();
2213     auto& tensor = subgraph->tensors[replace_tensor_id];
2214     tensor->name = "const_input0";
2215     model->buffers.emplace_back(new tflite::BufferT());
2216     tensor->buffer = model->buffers.size() - 1;
2217     auto& buffer = model->buffers[tensor->buffer];
2218     std::vector<float> tensor_buffer = {0.0, 5.0};
2219     uint8_t* uint8_data = reinterpret_cast<uint8_t*>(tensor_buffer.data());
2220     buffer->data = std::vector<uint8_t>(
2221         uint8_data, uint8_data + (sizeof(float) * tensor_buffer.size()));
2222   }
2223 
2224   TensorType tensor_type_;
2225   TensorType bias_type_;
2226 };
2227 
TEST_P(QuantizeConcatConstModelTest,AddRequantBeforeConcat)2228 TEST_P(QuantizeConcatConstModelTest, AddRequantBeforeConcat) {
2229   auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
2230                                           tensor_type_, false, tensor_type_,
2231                                           bias_type_, &error_reporter_);
2232   EXPECT_EQ(status, kTfLiteOk);
2233 
2234   // There is only one subgraph.
2235   const int32_t subgraph_idx = 0;
2236   const auto& subgraph = model_.subgraphs[subgraph_idx];
2237   const auto& readonly_subgraph =
2238       readonly_model_->subgraphs()->Get(subgraph_idx);
2239 
2240   // There should be 1 op: concat.
2241   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
2242   EXPECT_EQ(subgraph->operators.size(), 1);
2243   const auto& concat = subgraph->operators[0];
2244   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[concat->opcode_index].get()),
2245             BuiltinOperator_CONCATENATION);
2246 
2247   auto zero_point_control = tensor_type_ == TensorType_INT8 ? -128 : 0;
2248 
2249   auto input0_scale_control =
2250       tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;
2251   auto input1_scale =
2252       tensor_type_ == TensorType_INT8 ? 0.039215688 : 0.00030518509;
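  // Assuming the calibrated range of the concat is [0, 10]: int8 (asymmetric)
  // gives scale = 10 / 255 ~= 0.0392 with zero point -128, while int16
  // (symmetric) gives scale = 10 / 32767 ~= 0.000305 with zero point 0.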
2253 
2254   // There should be 3 tensors: const_input0, input1, output.
2255   EXPECT_EQ(subgraph->tensors.size(), 3);
2256   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
2257   EXPECT_EQ(subgraph->tensors[0]->name, "const_input0");
2258   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
2259   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
2260   EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0],
2261                   input0_scale_control);
2262   EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0],
2263                   zero_point_control);
2264 
2265   EXPECT_EQ(subgraph->tensors[1]->type, tensor_type_);
2266   EXPECT_EQ(subgraph->tensors[1]->name, "input1");
2267   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
2268   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
2269   EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], input1_scale);
2270   EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0],
2271                   zero_point_control);
2272   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
2273   EXPECT_EQ(subgraph->tensors[2]->name, "output");
2274   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
2275   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
2276   EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], input1_scale);
2277   EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0],
2278                   zero_point_control);
2279 
2280   EXPECT_EQ(concat->inputs.size(), 2);
2281   EXPECT_EQ(concat->outputs.size(), 1);
2282   EXPECT_EQ(concat->inputs[0], 0);
2283   EXPECT_EQ(concat->inputs[1], 1);
2284   EXPECT_EQ(concat->outputs[0], 2);
2285 
2286   // check op and versioning.
2287   EXPECT_EQ(model_.operator_codes.size(), 1);
2288   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
2289             BuiltinOperator_CONCATENATION);
2290   EXPECT_EQ(model_.operator_codes[0]->version, 2);
2291 }
2292 
2293 INSTANTIATE_TEST_SUITE_P(QuantizeConcatConstModelTest,
2294                          QuantizeConcatConstModelTest,
2295                          testing::ValuesIn({TensorType_INT8,
2296                                             TensorType_INT16}));
2297 
2298 class BiasInputTest : public QuantizeModelTest,
2299                       public testing::WithParamInterface<BiasTestType> {
2300  protected:
BiasInputTest()2301   BiasInputTest() {
2302     BiasTestType obj = GetParam();
2303     tensor_type_ = obj.tensor_type;
2304     bias_type_ = obj.bias_type;
2305     is_valid_bias_type_ = obj.is_valid_bias_type;
2306     input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
2307     readonly_model_ = input_model_->GetModel();
2308     readonly_model_->UnPackTo(&model_);
2309   }
2310   TensorType tensor_type_;
2311   TensorType bias_type_;
2312   bool is_valid_bias_type_;
2313   tflite::TestErrorReporter test_error_reporter_;
2314 };
2315 
2316 INSTANTIATE_TEST_SUITE_P(BiasInputTestInst, BiasInputTest,
2317                          testing::ValuesIn<BiasTestType>(
2318                              {{TensorType_INT8, TensorType_INT32, true},
2319                               {TensorType_INT8, TensorType_FLOAT32, false},
2320                               {TensorType_INT16, TensorType_INT32, true},
2321                               {TensorType_INT16, TensorType_INT64, true},
2322                               {TensorType_INT16, TensorType_FLOAT32, false}}));
2323 
TEST_P(BiasInputTest,QuantizationSucceeds)2324 TEST_P(BiasInputTest, QuantizationSucceeds) {
2325   auto status = QuantizeModelAllOperators(&builder_, &model_, tensor_type_,
2326                                           tensor_type_, false, tensor_type_,
2327                                           bias_type_, &test_error_reporter_);
2328   if (is_valid_bias_type_) {
2329     EXPECT_EQ(status, kTfLiteOk);
2330     const uint8_t* buffer = builder_.GetBufferPointer();
2331     const Model* output_model = GetModel(buffer);
2332     ASSERT_TRUE(output_model);
2333   } else {
2334     EXPECT_EQ(status, kTfLiteError);
2335   }
2336 }
2337 
2338 }  // namespace
2339 }  // namespace optimize
2340 }  // namespace tflite
2341 
main(int argc,char ** argv)2342 int main(int argc, char** argv) {
2343   tensorflow::string model_file;
2344   const std::vector<tensorflow::Flag> flag_list = {
2345       tensorflow::Flag("test_model_file", &model_file,
2346                        "Path to test tflite model file."),
2347   };
2348 
2349   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
2350   if (!parse_result) {
2351     std::cerr << "Required test_model_file\n";
2352     std::abort();
2353   }
2354   g_test_model_dir =
2355       new tensorflow::string(tensorflow::io::Dirname(model_file));
2356   ::tensorflow::port::InitMain(argv[0], &argc, &argv);
2357   return RUN_ALL_TESTS();
2358 }
2359