/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/test_util.h"

// Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc

namespace {
tensorflow::string* g_test_model_dir = nullptr;
}  // namespace

namespace tflite {
namespace optimize {
namespace {

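// Wraps the MLIR-based quantizer: runs mlir::lite::QuantizeModel on the
// unpacked `model` and repacks the quantized flatbuffer from `builder` back
// into `model` so tests can inspect the result as a ModelT. Note that the
// `operator_names` argument is accepted for signature compatibility but is
// not forwarded to the MLIR quantizer.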
TfLiteStatus QuantizeModel(
    flatbuffers::FlatBufferBuilder* builder, ModelT* model,
    const TensorType& input_type, const TensorType& output_type,
    bool allow_float, const std::unordered_set<string>& operator_names,
    const TensorType& activations_type, ErrorReporter* error_reporter,
    bool disable_per_channel = false,
    const absl::flat_hash_set<std::string>& blocked_ops = {},
    const absl::flat_hash_set<std::string>& blocked_nodes = {}) {
  TensorType inference_tensor_type = activations_type;
  bool fully_quantize = !allow_float;
  auto status = mlir::lite::QuantizeModel(
      *model, input_type, output_type, inference_tensor_type,
      /*operator_names=*/{}, disable_per_channel, fully_quantize, builder,
      error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false,
      /*legacy_float_scale=*/true, blocked_ops, blocked_nodes);
  if (status != kTfLiteOk) {
    return status;
  }
  std::string buffer(
      reinterpret_cast<const char*>(builder->GetCurrentBufferPointer()),
      builder->GetSize());

  auto flatbuffer_model =
      FlatBufferModel::BuildFromBuffer(buffer.c_str(), buffer.size());
  flatbuffer_model->GetModel()->UnPackTo(model);
  return kTfLiteOk;
}

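// Convenience overloads: default to int8 activations, then additionally to
// full (no-float) quantization, then to a float32-in/float32-out model with
// float fallback allowed.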
TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                           ModelT* model, const TensorType& input_type,
                           const TensorType& output_type, bool allow_float,
                           ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type, allow_float,
                       /*operator_names=*/{}, TensorType_INT8, error_reporter);
}

TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                           ModelT* model, const TensorType& input_type,
                           const TensorType& output_type,
                           ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type,
                       /*allow_float=*/false, error_reporter);
}

TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                           ModelT* model, ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
                       /*allow_float=*/true, error_reporter);
}

TfLiteStatus QuantizeModelAllOperators(
    flatbuffers::FlatBufferBuilder* builder, ModelT* model,
    const TensorType& input_type, const TensorType& output_type,
    bool allow_float, const TensorType& activations_type,
    bool disable_per_channel, ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type, allow_float,
                       /*operator_names=*/{}, activations_type, error_reporter,
                       disable_per_channel);
}

TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
                                       ModelT* model,
                                       const TensorType& input_type,
                                       const TensorType& output_type,
                                       bool allow_float,
                                       const TensorType& activations_type,
                                       ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type, allow_float,
                       /*operator_names=*/{}, activations_type, error_reporter);
}

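// Loads a test model flatbuffer from `g_test_model_dir`, which is populated
// by the test driver before the tests run.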
std::unique_ptr<FlatBufferModel> ReadModel(const string& model_name) {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

template <typename T>
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
  return std::vector<T>(vec->begin(), vec->end());
}

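// Checks that the quantized scale matches the asymmetric scale implied by the
// float tensor's min/max, i.e. (max - min) / 255 with the range extended to
// include zero.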
void VerifyAsymmetricQuantizationScale(
    const QuantizationParameters& float_quant_params,
    const QuantizationParametersT& quantized_quant_params) {
  const float eps = 1e-7;
  ASSERT_EQ(float_quant_params.min()->size(), 1);
  ASSERT_EQ(float_quant_params.max()->size(), 1);
  float float_min = std::min(0.f, float_quant_params.min()->Get(0));
  float float_max = std::max(0.f, float_quant_params.max()->Get(0));

  ASSERT_EQ(quantized_quant_params.scale.size(), 1);
  ASSERT_EQ(quantized_quant_params.zero_point.size(), 1);
  float scale = (float_max - float_min) / 255;
  EXPECT_NEAR(scale, quantized_quant_params.scale[0], eps);
}

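// Base fixture: loads a flatbuffer test model and unpacks it into `model_`,
// which individual tests quantize in place and then inspect.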
class QuantizeModelTest : public testing::Test {
 protected:
  QuantizeModelTest() {
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  std::unique_ptr<FlatBufferModel> input_model_;
  const Model* readonly_model_;
  tflite::ModelT model_;
  flatbuffers::FlatBufferBuilder builder_;
  internal::FailOnErrorReporter error_reporter_;
};

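// Compares a quantized tensor against its expected counterpart: variability,
// shape, type, and (when both are present) quantization scales and zero
// points. Very small scales are not compared exactly.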
void ExpectEqualTensor(TensorT* tensor, TensorT* expected_tensor) {
  const float eps = 1e-7;
  EXPECT_NE(expected_tensor, nullptr);
  EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
  EXPECT_EQ(tensor->shape, expected_tensor->shape);
  EXPECT_EQ(tensor->type, expected_tensor->type);
  const auto quantization_params = tensor->quantization.get();
  const auto expected_quantization_params = expected_tensor->quantization.get();
  if (quantization_params != nullptr &&
      expected_quantization_params != nullptr) {
    for (int i = 0; i < quantization_params->scale.size(); ++i) {
      if (quantization_params->scale[i] > 3e-5) {
        EXPECT_NEAR(quantization_params->scale[i],
                    expected_quantization_params->scale[i], eps);
      }
    }
    EXPECT_EQ(quantization_params->zero_point,
              expected_quantization_params->zero_point);
  }
}

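// Finds, in `expected_graph`, the input tensor at position `idx` of the first
// operator whose builtin code matches that of `quant_op`.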
TensorT* FindMatchingExpectedTensor(const SubGraphT& expected_graph,
                                    const ModelT& expected_model,
                                    const ModelT& quant_model,
                                    const OperatorT& quant_op, int idx) {
  const auto& builtin_code =
      GetBuiltinCode(quant_model.operator_codes[quant_op.opcode_index].get());
  for (const auto& expected_op : expected_graph.operators) {
    const auto& op_code =
        expected_model.operator_codes[expected_op->opcode_index].get();
    const auto& expected_code = GetBuiltinCode(op_code);
    if (expected_code == builtin_code) {
      return expected_graph.tensors[expected_op->inputs[idx]].get();
    }
  }
  return nullptr;
}

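// Walks every operator input in `model` and checks that it matches the
// corresponding tensor in `expected_model`, including buffer contents for
// constant tensors.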
void ExpectSameModels(const ModelT& model, const ModelT& expected_model) {
  ASSERT_EQ(model.subgraphs.size(), expected_model.subgraphs.size());
  for (size_t subgraph_idx = 0; subgraph_idx < model.subgraphs.size();
       subgraph_idx++) {
    const auto graph = model.subgraphs[subgraph_idx].get();
    const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
    for (auto& op : graph->operators) {
      for (int idx = 0; idx < op->inputs.size(); idx++) {
        if (op->inputs[idx] < 0) {
          continue;
        }
        const auto& tensor = graph->tensors[op->inputs[idx]];
        auto* expected_tensor = FindMatchingExpectedTensor(
            *expected_graph, expected_model, model, *op, idx);
        if (!expected_tensor) {
          continue;
        }
        ExpectEqualTensor(tensor.get(), expected_tensor);
        if (expected_tensor->buffer > 0) {
          const int buffer_idx = tensor->buffer;
          const int expected_buffer_idx = expected_tensor->buffer;
          const auto buffer = model.buffers[buffer_idx].get()->data;
          const auto expected_buffer =
              expected_model.buffers[expected_buffer_idx].get()->data;
          EXPECT_EQ(buffer, expected_buffer);
        }
      }
    }
  }
}

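// Conv model fixture parameterized on the activation tensor type. The
// constructor injects dummy calibration min/max on the graph input and output
// since the flatbuffer itself carries no calibration data.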
class QuantizeConvModelTest : public QuantizeModelTest,
                              public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModelTest() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
    // Flatbuffer is missing calibration data -- add dummy params.
    auto& subgraph = model_.subgraphs[0];
    auto* input = subgraph->tensors[subgraph->inputs[0]].get();
    auto* output = subgraph->tensors[subgraph->outputs[0]].get();
    input->quantization = std::make_unique<QuantizationParametersT>();
    output->quantization = std::make_unique<QuantizationParametersT>();
    input->quantization->min.push_back(0.0);
    output->quantization->min.push_back(0.0);
    input->quantization->max.push_back(6.0);
    output->quantization->max.push_back(6.0);
  }
  TensorType tensor_type_;
};

INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
                         testing::ValuesIn({TensorType_INT8}));

TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
  auto status = QuantizeModel(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32,
      &error_reporter_, /*disable_per_channel=*/false, {"CONV_2D"});
  EXPECT_EQ(status, kTfLiteOk);

  ModelT expected_model;
  readonly_model_->UnPackTo(&expected_model);
  // The resulting model should be the same.
  ExpectSameModels(model_, expected_model);
}

TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) {
  auto status = QuantizeModel(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32,
      &error_reporter_, /*disable_per_channel=*/false, /*blocked_ops=*/{},
      {"output"});
  EXPECT_EQ(status, kTfLiteOk);

  ModelT expected_model;
  readonly_model_->UnPackTo(&expected_model);
  // The resulting model should be the same.
  ExpectSameModels(model_, expected_model);
}

TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_,
      /*allow_float=*/false, tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  for (const auto& subgraph : model_.subgraphs) {
    for (const auto& tensor : subgraph->tensors) {
      EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                  tensor->type == TensorType_INT8);
    }
  }
}

class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
 protected:
  QuantizeConvNoBiasModelTest() {
    input_model_ = ReadModel(internal::kConvModelWithNoBias);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

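// Fixture for a model containing a split followed by an add, used to check
// that the split outputs end up sharing the input's quantization parameters.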
class QuantizeSplitModelTest : public QuantizeModelTest {
 protected:
  QuantizeSplitModelTest() {
    input_model_ = ReadModel(internal::kModelSplit);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

// There are two outputs for split with different scales; the resulting model
// should have their scales hardcoded to the input scale value.
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
  const int32_t subgraph_idx = 0;
  const auto& subgraph = model_.subgraphs[subgraph_idx];
  const auto& readonly_subgraph =
      readonly_model_->subgraphs()->Get(subgraph_idx);

  // There should be two ops: the split and add in the original model.
  EXPECT_EQ(readonly_subgraph->operators()->size(), 2);
  EXPECT_EQ(subgraph->operators.size(), 2);
  const auto& split = subgraph->operators[0];
  const auto& add = subgraph->operators[1];
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[split->opcode_index].get()),
            BuiltinOperator_SPLIT);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[add->opcode_index].get()),
            BuiltinOperator_ADD);

  // There should be 5 tensors: input, output, split, split/split_dim, split:1.
  // Tensor indices could be different between original and quantized.
  EXPECT_EQ(subgraph->tensors.size(), 5);
  const int input_idx = 0;
  EXPECT_EQ(subgraph->tensors[input_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[input_idx]->name, "input");
  EXPECT_EQ(subgraph->tensors[input_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[input_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[input_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[input_idx]->quantization->zero_point[0],
                  -128);
  const int output_idx = 4;
  EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
  EXPECT_EQ(subgraph->tensors[output_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[output_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[output_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[output_idx]->quantization->zero_point[0],
                  -128);
  const int split0_idx = 2;
  EXPECT_EQ(subgraph->tensors[split0_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[split0_idx]->name, "split;split:1");
  EXPECT_EQ(subgraph->tensors[split0_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[split0_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[split0_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[split0_idx]->quantization->zero_point[0],
                  -128);
  const int split1_idx = 3;
  EXPECT_EQ(subgraph->tensors[split1_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[split1_idx]->name, "split;split:11");
  EXPECT_EQ(subgraph->tensors[split1_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[split1_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[split1_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[split1_idx]->quantization->zero_point[0],
                  -128);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SPLIT);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeConvModel2Test : public QuantizeModelTest,
                               public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModel2Test() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
    auto& subgraph = model_.subgraphs[0];
    auto* input = subgraph->tensors[subgraph->inputs[0]].get();
    auto* output = subgraph->tensors[subgraph->outputs[0]].get();
    input->quantization = std::make_unique<QuantizationParametersT>();
    output->quantization = std::make_unique<QuantizationParametersT>();
    input->quantization->min.push_back(0.0);
    output->quantization->min.push_back(0.0);
    input->quantization->max.push_back(6.0);
    output->quantization->max.push_back(6.0);
  }

  TensorType tensor_type_;
};

INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
                         testing::ValuesIn({TensorType_INT8}));

TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto conv_op = subgraph->operators[0].get();
  const int input_tensor_idx = 0;
  const int weights_tensor_idx = 1;
  const int bias_tensor_index = 2;
  const int output_tensor_idx = 0;
  const auto bias_tensor =
      subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
  const auto input_tensor =
      subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
  const auto weights_tensor =
      subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
                                   ? TensorType_INT32
                                   : TensorType_INT64);
  EXPECT_EQ(input_tensor->type, tensor_type_);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
  ASSERT_TRUE(bias_tensor->quantization);
  ASSERT_TRUE(weights_tensor->quantization);
  const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
  const std::vector<float>& weights_scales =
      weights_tensor->quantization->scale;
  const std::vector<int64_t>& weights_zero_points =
      weights_tensor->quantization->zero_point;
  const int out_channel_size = weights_tensor->shape[0];
  ASSERT_EQ(bias_scales.size(), out_channel_size);
  ASSERT_EQ(weights_scales.size(), out_channel_size);
  ASSERT_EQ(weights_zero_points.size(), out_channel_size);
  ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
  ASSERT_EQ(output_tensor->quantization->scale.size(), 1);

  const float eps = 1e-7;

  // Bias scale should be input * per_channel_weight_scale.
  for (size_t i = 0; i < out_channel_size; i++) {
    EXPECT_NEAR(bias_scales[i],
                input_tensor->quantization->scale[0] * weights_scales[i], eps);
  }

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  auto control_size = tensor_type_ == TensorType_INT8
                          ? sizeof(int32_t) * bias_tensor->shape[0]
                          : sizeof(int64_t) * bias_tensor->shape[0];

  const auto float_op =
      readonly_model_->subgraphs()->Get(0)->operators()->Get(0);
  const auto original_bias_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(2));
  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(original_bias_tensor->buffer());
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  if (tensor_type_ == TensorType_INT8) {
    int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
  const auto original_weights_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(1));
  const auto original_weights_buffer =
      readonly_model_->buffers()->Get(original_weights_tensor->buffer());
  const int8_t* weight_values =
      reinterpret_cast<int8_t*>(weights_buffer->data.data());
  const float* weights_float_buffer =
      reinterpret_cast<const float*>(original_weights_buffer->data()->data());
  ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
            original_weights_buffer->data()->size());
  int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
  for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
    for (size_t j = 0; j < num_values_in_channel; j++) {
      size_t element_idx = channel_idx * out_channel_size + j;
      auto scale = weights_scales[channel_idx];
      auto zero_point = weights_zero_points[channel_idx];
      auto dequantized_value = weight_values[element_idx] * scale;
      EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx],
                  scale / 2);
      EXPECT_EQ(zero_point, 0);
    }
  }

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, /*disable_per_channel=*/true, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto conv_op = subgraph->operators[0].get();
  const int input_tensor_idx = 0;
  const int weights_tensor_idx = 1;
  const int bias_tensor_index = 2;
  const int output_tensor_idx = 0;
  const auto bias_tensor =
      subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
  const auto input_tensor =
      subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
  const auto weights_tensor =
      subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
                                   ? TensorType_INT32
                                   : TensorType_INT64);
  EXPECT_EQ(input_tensor->type, tensor_type_);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
  ASSERT_TRUE(bias_tensor->quantization);
  ASSERT_TRUE(weights_tensor->quantization);
  const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
  const std::vector<float>& weights_scales =
      weights_tensor->quantization->scale;
  const std::vector<int64_t>& weights_zero_points =
      weights_tensor->quantization->zero_point;

  const int out_channel_size = 1;
  ASSERT_EQ(bias_scales.size(), out_channel_size);
  ASSERT_EQ(weights_scales.size(), out_channel_size);
  ASSERT_EQ(weights_zero_points.size(), out_channel_size);
  ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
  ASSERT_EQ(output_tensor->quantization->scale.size(), 1);

  const float eps = 1e-7;

  // Bias scale should be input * per_channel_weight_scale.
  for (size_t i = 0; i < out_channel_size; i++) {
    EXPECT_NEAR(bias_scales[i],
                input_tensor->quantization->scale[0] * weights_scales[i], eps);
  }

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  auto control_size = tensor_type_ == TensorType_INT8
                          ? sizeof(int32_t) * bias_tensor->shape[0]
                          : sizeof(int64_t) * bias_tensor->shape[0];

  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto float_op =
      readonly_model_->subgraphs()->Get(0)->operators()->Get(0);
  const auto original_bias_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(2));
  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(original_bias_tensor->buffer());
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  if (tensor_type_ == TensorType_INT8) {
    int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
  const auto original_weights_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(1));
  const auto original_weights_buffer =
      readonly_model_->buffers()->Get(original_weights_tensor->buffer());
  const int8_t* weight_values =
      reinterpret_cast<int8_t*>(weights_buffer->data.data());
  const float* weights_float_buffer =
      reinterpret_cast<const float*>(original_weights_buffer->data()->data());
  ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
            original_weights_buffer->data()->size());
  int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
  for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
    for (size_t j = 0; j < num_values_in_channel; j++) {
      size_t element_idx = channel_idx * out_channel_size + j;
      auto scale = weights_scales[channel_idx];
      auto zero_point = weights_zero_points[channel_idx];
      auto dequantized_value = weight_values[element_idx] * scale;
      EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx],
                  scale / 2);
      EXPECT_EQ(zero_point, 0);
    }
  }

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

class QuantizeSoftmaxTest : public QuantizeModelTest {
 protected:
  QuantizeSoftmaxTest() {
    input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  // Model has a single softmax op.
  ASSERT_EQ(op->opcode_index, 0);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SOFTMAX);

  ASSERT_EQ(op->inputs.size(), 1);
  ASSERT_EQ(op->outputs.size(), 1);
  auto float_graph = readonly_model_->subgraphs()->Get(0);

  // Verify input.
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  // Verify output.
  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);

  ASSERT_EQ(output_quant_params->scale.size(), 1);
  ASSERT_EQ(output_quant_params->zero_point.size(), 1);
  ASSERT_EQ(1.0f / 256.0f, output_quant_params->scale[0]);
  ASSERT_EQ(-128, output_quant_params->zero_point[0]);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SOFTMAX);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeAvgPoolTest : public QuantizeModelTest {
 protected:
  QuantizeAvgPoolTest() {
    input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  // Model has a single AveragePool op.
  ASSERT_EQ(op->opcode_index, 0);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_AVERAGE_POOL_2D);

  ASSERT_EQ(op->inputs.size(), 1);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->scale.size(), 1);

  // Make sure the input min/maxes are propagated to outputs.
  EXPECT_EQ(input_quant_params->scale[0], output_quant_params->scale[0]);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_AVERAGE_POOL_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
 protected:
  QuantizeMultiInputAddWithReshapeTest() {
    input_model_ = ReadModel(internal::kMultiInputAddWithReshape);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);

  ASSERT_EQ(kTfLiteOk, status);

  // Verify Reshape is quantized.
  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[1].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_RESHAPE);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  auto float_op = float_graph->operators()->Get(1);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->outputs()->Get(0))->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  auto float_output_quant_params =
      float_graph->tensors()->Get(float_op->outputs()->Get(0))->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->scale.size(), 1);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ADD);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
            BuiltinOperator_RESHAPE);
  EXPECT_EQ(model_.operator_codes[1]->version, 1);
}

TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify ADD is quantized.
  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_ADD);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  auto float_op = float_graph->operators()->Get(0);
  const int float_input0_idx = float_op->inputs()->Get(0);
  const int float_input1_idx = float_op->inputs()->Get(1);
  const int float_output_idx = float_op->outputs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(float_input0_idx)->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(float_input1_idx)->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(float_output_idx)->type(),
            TensorType_FLOAT32);

  for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
    EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
              TensorType_INT8);
    auto float_input_quant_params =
        float_graph->tensors()
            ->Get(float_op->inputs()->Get(input_idx))
            ->quantization();
    auto input_quant_params =
        subgraph->tensors[op->inputs[input_idx]]->quantization.get();
    VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                      *input_quant_params);
  }

  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->scale.size(), 1);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ADD);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
            BuiltinOperator_RESHAPE);
  EXPECT_EQ(model_.operator_codes[1]->version, 1);
}

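// Fixture for an add model with one constant input, used to check that
// constant operands are quantized along with the rest of the graph.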
class QuantizeConstInputTest : public QuantizeModelTest,
                               public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConstInputTest() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConstInputAddModel);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  TensorType tensor_type_;
};
INSTANTIATE_TEST_SUITE_P(QuantizeConstInputTestInst, QuantizeConstInputTest,
                         testing::ValuesIn({TensorType_INT8}));

TEST_P(QuantizeConstInputTest, VerifyConstOpInput) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify ConstOp is quantized.
  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_ADD);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
    EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
              tensor_type_);
  }

  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, tensor_type_);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ADD);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeArgMaxTest : public QuantizeModelTest {
 protected:
  QuantizeArgMaxTest() {
    input_model_ = ReadModel(internal::kModelWithArgMaxOp);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_ARG_MAX);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  auto float_op = float_graph->operators()->Get(0);
  // Verify ArgMax input is quantized.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);

  // Verify that the ArgMax input axis is still the same type.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(1))->type(),
            subgraph->tensors[op->inputs[1]].get()->type);

  // The output of ArgMax should still be the same type.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->outputs()->Get(0))->type(),
            subgraph->tensors[op->outputs[0]].get()->type);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ARG_MAX);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeLSTMTest : public QuantizeModelTest {
 protected:
  QuantizeLSTMTest() {
    input_model_ = ReadModel(internal::kLstmCalibrated);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeLSTMTest, VerifyLSTM) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/true, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model = ReadModel(internal::kLstmQuantized);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeLSTM2Test : public QuantizeModelTest {
 protected:
  QuantizeLSTM2Test() {
    input_model_ = ReadModel(internal::kLstmCalibrated2);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model = ReadModel(internal::kLstmQuantized2);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
 protected:
  QuantizeUnidirectionalSequenceLSTMTest() {
    input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
       VerifyUnidirectionalSequenceLSTM) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model =
      ReadModel(internal::kUnidirectionalSequenceLstmQuantized);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeSVDFTest : public QuantizeModelTest {
 protected:
  QuantizeSVDFTest() {
    input_model_ = ReadModel(internal::kSvdfCalibrated);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeSVDFTest, VerifySVDF) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model = ReadModel(internal::kSvdfQuantized);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeFCTest : public QuantizeModelTest {
 protected:
  QuantizeFCTest() {
    input_model_ = ReadModel(internal::kModelWithFCOp);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeFCTest, VerifyFC) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_FULLY_CONNECTED);

  ASSERT_EQ(op->inputs.size(), 3);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  // Verify that the FC input and weights are quantized.
  auto float_op = float_graph->operators()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(1))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[1]].get()->type, TensorType_INT8);

  // Verify that the FC bias is quantized to int32.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(2))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[2]].get()->type, TensorType_INT32);

  // The output of FC should be quantized.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->outputs()->Get(0))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_FULLY_CONNECTED);
  EXPECT_EQ(model_.operator_codes[0]->version, 5);
}

class QuantizeCustomOpTest
    : public QuantizeModelTest,
      public ::testing::WithParamInterface<tflite::TensorType> {
 protected:
  QuantizeCustomOpTest() {
    input_model_ = ReadModel(internal::kModelMixed);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, GetParam(), GetParam(),
      /*allow_float=*/true, GetParam(), &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto float_graph = readonly_model_->subgraphs()->Get(0);
  // The original model is reshape->custom->custom->squeeze.
1140   ASSERT_EQ(float_graph->operators()->size(), 4);
1141   // The resulting model should be:
1142   // reshape->dequantize->custom->custom->quantize->squeeze.
1143   ASSERT_EQ(subgraph->operators.size(), 6);
1144   const std::vector<BuiltinOperator> op_codes = {
1145       BuiltinOperator_RESHAPE,  BuiltinOperator_DEQUANTIZE,
1146       BuiltinOperator_CUSTOM,   BuiltinOperator_CUSTOM,
1147       BuiltinOperator_QUANTIZE, BuiltinOperator_SQUEEZE};
1148   const std::vector<TensorType> op_input_types = {
1149       GetParam(),         GetParam(),         TensorType_FLOAT32,
1150       TensorType_FLOAT32, TensorType_FLOAT32, GetParam()};
1151   for (int i = 0; i < subgraph->operators.size(); ++i) {
1152     OperatorT* op = subgraph->operators[i].get();
1153     ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1154               op_codes[i]);
1155     ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]);
1156   }
1157 }
1158 
1159 INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest,
1160                          ::testing::Values(TensorType_INT8));
1161 
1162 class QuantizePackTest : public QuantizeModelTest {
1163  protected:
QuantizePackTest()1164   QuantizePackTest() {
1165     input_model_ = ReadModel(internal::kModelPack);
1166     readonly_model_ = input_model_->GetModel();
1167     readonly_model_->UnPackTo(&model_);
1168   }
1169 };
1170 
TEST_F(QuantizePackTest,VerifyPack)1171 TEST_F(QuantizePackTest, VerifyPack) {
1172   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1173 
1174   ASSERT_EQ(kTfLiteOk, status);
1175 
1176   const auto subgraph = model_.subgraphs[0].get();
1177 
1178   // The model should only have 3 inputs and 1 output.
1179   EXPECT_EQ(subgraph->inputs.size(), 3);
1180   EXPECT_EQ(subgraph->outputs.size(), 1);
1181 
1182   const auto& op1 = subgraph->operators[1].get();
1183   const auto& op2 = subgraph->operators[2].get();
1184   const auto& op3 = subgraph->operators[3].get();
1185   const auto& op4 = subgraph->operators[4].get();
1186 
1187   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op1->opcode_index].get()),
1188             BuiltinOperator_QUANTIZE);
1189   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op2->opcode_index].get()),
1190             BuiltinOperator_QUANTIZE);
1191   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op3->opcode_index].get()),
1192             BuiltinOperator_PACK);
1193   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op4->opcode_index].get()),
1194             BuiltinOperator_DEQUANTIZE);
1195 
1196   const auto& pack_input0 = subgraph->tensors[op3->inputs[0]].get();
1197   const auto& pack_input1 = subgraph->tensors[op3->inputs[1]].get();
1198   const auto& pack_input2 = subgraph->tensors[op3->inputs[2]].get();
1199 
1200   const auto& pack_output = subgraph->tensors[op3->outputs[0]].get();
1201 
1202   // Check quantization parameters for input and output.
1203   EXPECT_FLOAT_EQ(pack_input0->quantization->scale[0],
1204                   pack_input1->quantization->scale[0]);
1205   EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
1206                   pack_input2->quantization->scale[0]);
1207   EXPECT_FLOAT_EQ(pack_input0->quantization->zero_point[0],
1208                   pack_input1->quantization->zero_point[0]);
1209   EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
1210                   pack_input2->quantization->zero_point[0]);
1211 
1212   EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
1213                   pack_output->quantization->scale[0]);
1214   EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
1215                   pack_output->quantization->zero_point[0]);
1216 
1217   // Check type of input and output.
1218   EXPECT_EQ(pack_output->type, TensorType_INT8);
1219   EXPECT_EQ(pack_input0->type, TensorType_INT8);
1220   EXPECT_EQ(pack_input1->type, TensorType_INT8);
1221   EXPECT_EQ(pack_input2->type, TensorType_INT8);
1222 }
1223 
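// Parameterized over the Minimum and Maximum test models; the test checks
// that the quantized graph starts with a Quantize op, ends with a
// Dequantize op, and that the elementwise op's inputs and output share
// quantization parameters.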
1224 class QuantizeMinimumMaximumTest
1225     : public QuantizeModelTest,
1226       public testing::WithParamInterface<const char*> {
1227  protected:
1228   QuantizeMinimumMaximumTest() {
1229     input_model_ = ReadModel(GetParam());
1230     readonly_model_ = input_model_->GetModel();
1231     readonly_model_->UnPackTo(&model_);
1232   }
1233 };
1234 
1235 TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
1236   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1237   ASSERT_EQ(kTfLiteOk, status);
1238   const auto& subgraph = model_.subgraphs[0];
1239   // Check that the first op is Quantize and the last is Dequantize.
1240   const auto& quant_op = subgraph->operators[0];
1241   const auto& dequant_op = subgraph->operators[subgraph->operators.size() - 1];
1242   const int32_t quant_idx = quant_op->opcode_index;
1243   const int32_t dequant_idx = dequant_op->opcode_index;
1244   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()),
1245             BuiltinOperator_QUANTIZE);
1246   EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()),
1247             BuiltinOperator_DEQUANTIZE);
1248 
1249   const auto& op = subgraph->operators[1].get();
1250 
1251   // Check that we have a MINIMUM or MAXIMUM operator.
1252   auto op_builtin_code =
1253       GetBuiltinCode(model_.operator_codes[op->opcode_index].get());
1254   ASSERT_TRUE(op_builtin_code == tflite::BuiltinOperator_MINIMUM ||
1255               op_builtin_code == tflite::BuiltinOperator_MAXIMUM);
1256 
1257   // Check that we have two inputs and one output.
1258   ASSERT_EQ(op->inputs.size(), 2);
1259   ASSERT_EQ(op->outputs.size(), 1);
1260 
1261   // Check that all tensors are quantized.
1262   auto output = subgraph->tensors[op->outputs[0]].get();
1263   auto input1 = subgraph->tensors[op->inputs[0]].get();
1264   auto input2 = subgraph->tensors[op->inputs[1]].get();
1265 
1266   EXPECT_EQ(output->type, TensorType_INT8);
1267   EXPECT_EQ(input1->type, TensorType_INT8);
1268   EXPECT_EQ(input2->type, TensorType_INT8);
1269 
1270   // Check that the quantization params of the minimum/maximum inputs match
1271   // after requantization.
1272   EXPECT_EQ(input1->quantization->scale, input2->quantization->scale);
1273   EXPECT_EQ(input1->quantization->zero_point, input2->quantization->zero_point);
1274 
1275   // Check the input quantization params match the output ones.
1276   EXPECT_EQ(output->quantization->scale, input1->quantization->scale);
1277   EXPECT_EQ(output->quantization->zero_point, input1->quantization->zero_point);
1278   EXPECT_EQ(output->quantization->scale, input2->quantization->scale);
1279   EXPECT_EQ(output->quantization->zero_point, input2->quantization->zero_point);
1280 }
1281 
1282 INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
1283                          testing::ValuesIn({internal::kModelWithMinimumOp,
1284                                             internal::kModelWithMaximumOp}));
1285 
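// Fixture for the Unpack test model (kModelWithUnpack); VerifyUnpack below
// checks that both Unpack outputs inherit the quantization parameters of
// the Unpack input.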
1286 class QuantizeUnpackTest : public QuantizeModelTest {
1287  protected:
1288   QuantizeUnpackTest() {
1289     input_model_ = ReadModel(internal::kModelWithUnpack);
1290     readonly_model_ = input_model_->GetModel();
1291     readonly_model_->UnPackTo(&model_);
1292   }
1293 };
1294 
1295 TEST_F(QuantizeUnpackTest, VerifyUnpack) {
1296   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
1297 
1298   ASSERT_EQ(kTfLiteOk, status);
1299 
1300   const auto subgraph = model_.subgraphs[0].get();
1301   auto op = subgraph->operators[1].get();
1302 
1303   auto float_graph = readonly_model_->subgraphs()->Get(0);
1304 
1305   ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
1306             BuiltinOperator_UNPACK);
1307 
1308   // Get the unpack input and output tensors.
1309   auto unpack_input = subgraph->tensors[op->inputs[0]].get();
1310   auto unpack_output_0 = subgraph->tensors[op->outputs[0]].get();
1311   auto unpack_output_1 = subgraph->tensors[op->outputs[1]].get();
1312 
1313   // Verify Unpack input is quantized.
1314   ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
1315             TensorType_FLOAT32);
1316   EXPECT_EQ(unpack_input->type, TensorType_INT8);
1317 
1318   // The model should only have 1 input and 2 outputs.
1319   EXPECT_EQ(subgraph->inputs.size(), 1);
1320   EXPECT_EQ(subgraph->outputs.size(), 2);
1321 
1322   // Ensure that the quantization parameters of the unpack input
1323   // are preserved in all of the unpack outputs after
1324   // quantization.
1325   EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1326                   unpack_output_0->quantization->scale[0]);
1327   EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1328                   unpack_output_1->quantization->scale[0]);
1329   EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1330                   unpack_output_0->quantization->zero_point[0]);
1331   EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1332                   unpack_output_1->quantization->zero_point[0]);
1333 }
1334 
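// Parameterized fixture for the BroadcastTo test model; the test expects
// the input and output tensors to be quantized while the INT32 shape
// operand is left unquantized.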
1335 class QuantizeBroadcastToModelTest
1336     : public QuantizeModelTest,
1337       public testing::WithParamInterface<TensorType> {
1338  protected:
1339   QuantizeBroadcastToModelTest() {
1340     tensor_type_ = GetParam();
1341     input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
1342     readonly_model_ = input_model_->GetModel();
1343     readonly_model_->UnPackTo(&model_);
1344   }
1345   TensorType tensor_type_;
1346 };
1347 
1348 INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst,
1349                          QuantizeBroadcastToModelTest,
1350                          testing::ValuesIn({TensorType_INT8}));
1351 
1352 TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
1353   auto status = QuantizeModelAllOperators(
1354       &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
1355       tensor_type_, &error_reporter_);
1356   EXPECT_EQ(status, kTfLiteOk);
1357 
1358   // There is only one subgraph.
1359   const int32_t subgraph_idx = 0;
1360   const auto& subgraph = model_.subgraphs[subgraph_idx];
1361   const auto& readonly_subgraph =
1362       readonly_model_->subgraphs()->Get(subgraph_idx);
1363 
1364   // There should be a single broadcast_to op.
1365   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1366   EXPECT_EQ(subgraph->operators.size(), 1);
1367   const auto& broadcast_to = subgraph->operators[0];
1368   EXPECT_EQ(model_.operator_codes[broadcast_to->opcode_index]->builtin_code,
1369             BuiltinOperator_BROADCAST_TO);
1370 
1371   // There should be 3 tensors: input, output, and BroadcastTo/shape.
1372   EXPECT_EQ(subgraph->tensors.size(), 3);
1373 
1374   // Input Tensor
1375   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
1376   EXPECT_EQ(subgraph->tensors[0]->name, "input_1");
1377   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
1378   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
1379 
1380   // Output Tensor. The name given in the generated
1381   // .bin test file is 'Identity' and should be preserved.
1382   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
1383   EXPECT_EQ(subgraph->tensors[2]->name, "Identity");
1384   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
1385   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
1386 
1387   // The BroadcastTo shape is of type INT32 and should not be quantized.
1388   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
1389   EXPECT_EQ(subgraph->tensors[1]->name,
1390             "model/tf.broadcast_to/BroadcastTo/shape");
1391   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1392   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1393 
1394   // Check op and versioning.
1395   EXPECT_EQ(model_.operator_codes.size(), 1);
1396   EXPECT_EQ(model_.operator_codes[0]->builtin_code,
1397             BuiltinOperator_BROADCAST_TO);
1398   EXPECT_EQ(model_.operator_codes[0]->version, 3);
1399 }
1400 
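// Parameterized fixture for the GatherND test model; the test expects the
// params and output tensors to be quantized while the INT32 indices are
// left unquantized.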
1401 class QuantizeGatherNDModelTest
1402     : public QuantizeModelTest,
1403       public testing::WithParamInterface<TensorType> {
1404  protected:
1405   QuantizeGatherNDModelTest() {
1406     tensor_type_ = GetParam();
1407     input_model_ = ReadModel(internal::kModelWithGatherNDOp);
1408     readonly_model_ = input_model_->GetModel();
1409     readonly_model_->UnPackTo(&model_);
1410   }
1411 
1412   TensorType tensor_type_;
1413 };
1414 
1415 INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst,
1416                          QuantizeGatherNDModelTest,
1417                          testing::ValuesIn({TensorType_INT8}));
1418 
1419 TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
1420   auto status = QuantizeModelAllOperators(
1421       &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
1422       tensor_type_, &error_reporter_);
1423   EXPECT_EQ(status, kTfLiteOk);
1424 
1425   // There is only one subgraph.
1426   const int32_t subgraph_idx = 0;
1427   const auto& subgraph = model_.subgraphs[subgraph_idx];
1428   const auto& readonly_subgraph =
1429       readonly_model_->subgraphs()->Get(subgraph_idx);
1430 
1431   // There should be a single gather_nd op.
1432   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1433   EXPECT_EQ(subgraph->operators.size(), 1);
1434   const auto& gather_nd = subgraph->operators[0];
1435   EXPECT_EQ(model_.operator_codes[gather_nd->opcode_index]->builtin_code,
1436             BuiltinOperator_GATHER_ND);
1437 
1438   // There should be 3 tensors: input, output, and indices.
1439   EXPECT_EQ(subgraph->tensors.size(), 3);
1440 
1441   // Input Tensor
1442   EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
1443   EXPECT_EQ(subgraph->tensors[0]->name, "input");
1444   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
1445   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
1446 
1447   // Output Tensor
1448   EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
1449   EXPECT_EQ(subgraph->tensors[2]->name, "output");
1450   EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
1451   EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
1452 
1453   // The gather indices are of type INT32 and should not be quantized.
1454   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
1455   EXPECT_EQ(subgraph->tensors[1]->name, "indices");
1456   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1457   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1458 
1459   // Check op and versioning.
1460   EXPECT_EQ(model_.operator_codes.size(), 1);
1461   EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_GATHER_ND);
1462   EXPECT_EQ(model_.operator_codes[0]->version, 1);
1463 }
1464 
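// Fixture for the Where test model; Where consumes a BOOL tensor and
// produces INT64 indices, so the test expects no tensors to be quantized.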
1465 class QuantizeWhereModelTest : public QuantizeModelTest {
1466  protected:
1467   QuantizeWhereModelTest() {
1468     input_model_ = ReadModel(internal::kModelWithWhereOp);
1469     readonly_model_ = input_model_->GetModel();
1470     readonly_model_->UnPackTo(&model_);
1471   }
1472 };
1473 
1474 TEST_F(QuantizeWhereModelTest, QuantizeWhere) {
1475   // The Where operator takes a BOOL tensor as input
1476   // and outputs INT64 indices; neither tensor
1477   // should be quantized.
1478   auto status = QuantizeModel(&builder_, &model_, TensorType_BOOL,
1479                               TensorType_INT64, &error_reporter_);
1480   EXPECT_EQ(status, kTfLiteOk);
1481 
1482   // There is only one subgraph.
1483   const int32_t subgraph_idx = 0;
1484   const auto& subgraph = model_.subgraphs[subgraph_idx];
1485   const auto& readonly_subgraph =
1486       readonly_model_->subgraphs()->Get(subgraph_idx);
1487 
1488   // There should be a single where op.
1489   EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
1490   EXPECT_EQ(subgraph->operators.size(), 1);
1491   const auto& where = subgraph->operators[0];
1492   EXPECT_EQ(model_.operator_codes[where->opcode_index]->builtin_code,
1493             BuiltinOperator_WHERE);
1494 
1495   // There should be 2 tensors: input and output.
1496   EXPECT_EQ(subgraph->tensors.size(), 2);
1497 
1498   // Check the input tensor type and ensure it
1499   // was not quantized.
1500   EXPECT_EQ(subgraph->tensors[0]->type, TensorType_BOOL);
1501   EXPECT_EQ(subgraph->tensors[0]->name, "input");
1502   EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 0);
1503   EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 0);
1504 
1505   // Check the output (indices) tensor type and ensure it
1506   // was not quantized.
1507   EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT64);
1508   EXPECT_EQ(subgraph->tensors[1]->name, "indices");
1509   EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
1510   EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
1511 
1512   // Check op and versioning.
1513   EXPECT_EQ(model_.operator_codes.size(), 1);
1514   EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_WHERE);
1515   EXPECT_EQ(model_.operator_codes[0]->version, 1);
1516 }
1517 
1518 }  // namespace
1519 }  // namespace optimize
1520 }  // namespace tflite
1521 
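// Test entry point: parses the required --test_model_file flag, stores the
// directory containing the test models in g_test_model_dir, and runs all
// registered tests.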
1522 int main(int argc, char** argv) {
1523   tensorflow::string model_file;
1524   const std::vector<tensorflow::Flag> flag_list = {
1525       tensorflow::Flag("test_model_file", &model_file,
1526                        "Path to test tflite model file."),
1527   };
1528 
1529   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
1530   if (!parse_result) {
1531     std::cerr << "Required flag test_model_file was not specified.\n";
1532     std::abort();
1533   }
1534   g_test_model_dir =
1535       new tensorflow::string(tensorflow::io::Dirname(model_file));
1536   ::tensorflow::port::InitMain(argv[0], &argc, &argv);
1537   return RUN_ALL_TESTS();
1538 }
1539