/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/test_util.h"

// Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc

namespace {
tensorflow::string* g_test_model_dir = nullptr;
}  // namespace

namespace tflite {
namespace optimize {
namespace {

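// Wraps the MLIR-based quantizer behind the legacy tools/optimize
// QuantizeModel signature used by the tests below. On success, the quantized
// flatbuffer is unpacked back into `model` so that tests can inspect the
// result directly.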
TfLiteStatus QuantizeModel(
    flatbuffers::FlatBufferBuilder* builder, ModelT* model,
    const TensorType& input_type, const TensorType& output_type,
    bool allow_float, const std::unordered_set<string>& operator_names,
    const TensorType& activations_type, ErrorReporter* error_reporter,
    bool disable_per_channel = false,
    const absl::flat_hash_set<std::string>& blocked_ops = {},
    const absl::flat_hash_set<std::string>& blocked_nodes = {}) {
  TensorType inference_tensor_type = activations_type;
  bool fully_quantize = !allow_float;
  auto status = mlir::lite::QuantizeModel(
      *model, input_type, output_type, inference_tensor_type,
      /*operator_names=*/{}, disable_per_channel, fully_quantize, builder,
      error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false,
      /*legacy_float_scale=*/true, blocked_ops, blocked_nodes);
  if (status != kTfLiteOk) {
    return status;
  }
  std::string buffer(
      reinterpret_cast<const char*>(builder->GetCurrentBufferPointer()),
      builder->GetSize());

  auto flatbuffer_model =
      FlatBufferModel::BuildFromBuffer(buffer.c_str(), buffer.size());
  flatbuffer_model->GetModel()->UnPackTo(model);
  return kTfLiteOk;
}

TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                           ModelT* model, const TensorType& input_type,
                           const TensorType& output_type, bool allow_float,
                           ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type, allow_float,
                       /*operator_names=*/{}, TensorType_INT8, error_reporter);
}

TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                           ModelT* model, const TensorType& input_type,
                           const TensorType& output_type,
                           ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type,
                       /*allow_float=*/false, error_reporter);
}

TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                           ModelT* model, ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32,
                       /*allow_float=*/true, error_reporter);
}

TfLiteStatus QuantizeModelAllOperators(
    flatbuffers::FlatBufferBuilder* builder, ModelT* model,
    const TensorType& input_type, const TensorType& output_type,
    bool allow_float, const TensorType& activations_type,
    bool disable_per_channel, ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type, allow_float,
                       /*operator_names=*/{}, activations_type, error_reporter,
                       disable_per_channel);
}

TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder,
                                       ModelT* model,
                                       const TensorType& input_type,
                                       const TensorType& output_type,
                                       bool allow_float,
                                       const TensorType& activations_type,
                                       ErrorReporter* error_reporter) {
  return QuantizeModel(builder, model, input_type, output_type, allow_float,
                       /*operator_names=*/{}, activations_type,
                       error_reporter);
}

std::unique_ptr<FlatBufferModel> ReadModel(const string& model_name) {
  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name);
  return FlatBufferModel::BuildFromFile(model_path.c_str());
}

template <typename T>
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
  return std::vector<T>(vec->begin(), vec->end());
}

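// Checks that the quantized scale matches the asymmetric scale derived from
// the float min/max: (max - min) / 255, after nudging the range to include
// zero.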
void VerifyAsymmetricQuantizationScale(
    const QuantizationParameters& float_quant_params,
    const QuantizationParametersT& quantized_quant_params) {
  const float eps = 1e-7;
  ASSERT_EQ(float_quant_params.min()->size(), 1);
  ASSERT_EQ(float_quant_params.max()->size(), 1);
  float float_min = std::min(0.f, float_quant_params.min()->Get(0));
  float float_max = std::max(0.f, float_quant_params.max()->Get(0));

  ASSERT_EQ(quantized_quant_params.scale.size(), 1);
  ASSERT_EQ(quantized_quant_params.zero_point.size(), 1);
  float scale = (float_max - float_min) / 255;
  EXPECT_NEAR(scale, quantized_quant_params.scale[0], eps);
}

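// Base fixture: loads a flatbuffer test model, keeps a read-only view of the
// original in readonly_model_, and unpacks a mutable copy into model_ for the
// quantizer to operate on.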
class QuantizeModelTest : public testing::Test {
 protected:
  QuantizeModelTest() {
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  std::unique_ptr<FlatBufferModel> input_model_;
  const Model* readonly_model_;
  tflite::ModelT model_;
  flatbuffers::FlatBufferBuilder builder_;
  internal::FailOnErrorReporter error_reporter_;
};

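// Compares shape, type, and quantization parameters of two tensors. Scales
// smaller than 3e-5 are not compared against the expected values.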
void ExpectEqualTensor(TensorT* tensor, TensorT* expected_tensor) {
  const float eps = 1e-7;
  EXPECT_NE(expected_tensor, nullptr);
  EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
  EXPECT_EQ(tensor->shape, expected_tensor->shape);
  EXPECT_EQ(tensor->type, expected_tensor->type);
  const auto quantization_params = tensor->quantization.get();
  const auto expected_quantization_params = expected_tensor->quantization.get();
  if (quantization_params != nullptr &&
      expected_quantization_params != nullptr) {
    for (int i = 0; i < quantization_params->scale.size(); ++i) {
      if (quantization_params->scale[i] > 3e-5) {
        EXPECT_NEAR(quantization_params->scale[i],
                    expected_quantization_params->scale[i], eps);
      }
    }
    EXPECT_EQ(quantization_params->zero_point,
              expected_quantization_params->zero_point);
  }
}

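// Returns the input tensor at position `idx` of the first operator in
// `expected_graph` whose builtin code matches that of `quant_op`, or nullptr
// if there is no such operator.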
TensorT* FindMatchingExpectedTensor(const SubGraphT& expected_graph,
                                    const ModelT& expected_model,
                                    const ModelT& quant_model,
                                    const OperatorT& quant_op, int idx) {
  const auto& builtin_code =
      GetBuiltinCode(quant_model.operator_codes[quant_op.opcode_index].get());
  for (const auto& expected_op : expected_graph.operators) {
    const auto& op_code =
        expected_model.operator_codes[expected_op->opcode_index].get();
    const auto& expected_code = GetBuiltinCode(op_code);
    if (expected_code == builtin_code) {
      return expected_graph.tensors[expected_op->inputs[idx]].get();
    }
  }
  return nullptr;
}

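// Walks every operator input in `model`, pairs it with the matching tensor in
// `expected_model`, and checks the tensor metadata and, when a buffer is
// present, the buffer contents for equality.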
void ExpectSameModels(const ModelT& model, const ModelT& expected_model) {
  ASSERT_EQ(model.subgraphs.size(), expected_model.subgraphs.size());
  for (size_t subgraph_idx = 0; subgraph_idx < model.subgraphs.size();
       subgraph_idx++) {
    const auto graph = model.subgraphs[subgraph_idx].get();
    const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
    for (auto& op : graph->operators) {
      for (int idx = 0; idx < op->inputs.size(); idx++) {
        if (op->inputs[idx] < 0) {
          continue;
        }
        const auto& tensor = graph->tensors[op->inputs[idx]];
        auto* expected_tensor = FindMatchingExpectedTensor(
            *expected_graph, expected_model, model, *op, idx);
        if (!expected_tensor) {
          continue;
        }
        ExpectEqualTensor(tensor.get(), expected_tensor);
        if (expected_tensor->buffer > 0) {
          const int buffer_idx = tensor->buffer;
          const int expected_buffer_idx = expected_tensor->buffer;
          const auto buffer = model.buffers[buffer_idx].get()->data;
          const auto expected_buffer =
              expected_model.buffers[expected_buffer_idx].get()->data;
          EXPECT_EQ(buffer, expected_buffer);
        }
      }
    }
  }
}

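// Parameterized conv model test. The test flatbuffer has no calibration data,
// so dummy min/max quantization parameters are attached to the graph input
// and output in the constructor.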
class QuantizeConvModelTest : public QuantizeModelTest,
                              public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModelTest() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
    // Flatbuffer is missing calibration data -- add dummy params.
    auto& subgraph = model_.subgraphs[0];
    auto* input = subgraph->tensors[subgraph->inputs[0]].get();
    auto* output = subgraph->tensors[subgraph->outputs[0]].get();
    input->quantization = std::make_unique<QuantizationParametersT>();
    output->quantization = std::make_unique<QuantizationParametersT>();
    input->quantization->min.push_back(0.0);
    output->quantization->min.push_back(0.0);
    input->quantization->max.push_back(6.0);
    output->quantization->max.push_back(6.0);
  }
  TensorType tensor_type_;
};

INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest,
                         testing::ValuesIn({TensorType_INT8}));

TEST_P(QuantizeConvModelTest, QuantizationSucceeds) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) {
  auto status = QuantizeModel(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32,
      &error_reporter_, /*disable_per_channel=*/false, {"CONV_2D"});
  EXPECT_EQ(status, kTfLiteOk);

  ModelT expected_model;
  readonly_model_->UnPackTo(&expected_model);
  // The resulting model should be the same.
  ExpectSameModels(model_, expected_model);
}

TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) {
  auto status = QuantizeModel(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32,
      &error_reporter_, /*disable_per_channel=*/false, /*blocked_ops=*/{},
      {"output"});
  EXPECT_EQ(status, kTfLiteOk);

  ModelT expected_model;
  readonly_model_->UnPackTo(&expected_model);
  // The resulting model should be the same.
  ExpectSameModels(model_, expected_model);
}

TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_,
      /*allow_float=*/false, tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  for (const auto& subgraph : model_.subgraphs) {
    for (const auto& tensor : subgraph->tensors) {
      EXPECT_TRUE(tensor->type == TensorType_INT32 ||
                  tensor->type == TensorType_INT8);
    }
  }
}

class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
 protected:
  QuantizeConvNoBiasModelTest() {
    input_model_ = ReadModel(internal::kConvModelWithNoBias);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);
  const uint8_t* buffer = builder_.GetBufferPointer();
  const Model* output_model = GetModel(buffer);
  ASSERT_TRUE(output_model);
}

class QuantizeSplitModelTest : public QuantizeModelTest {
 protected:
  QuantizeSplitModelTest() {
    input_model_ = ReadModel(internal::kModelSplit);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

// There are two outputs for split with different scales; the resulting model
// should have the scales hardcoded to the input scale value.
TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
  const int32_t subgraph_idx = 0;
  const auto& subgraph = model_.subgraphs[subgraph_idx];
  const auto& readonly_subgraph =
      readonly_model_->subgraphs()->Get(subgraph_idx);

  // There should be two ops: the split and add in the original model.
  EXPECT_EQ(readonly_subgraph->operators()->size(), 2);
  EXPECT_EQ(subgraph->operators.size(), 2);
  const auto& split = subgraph->operators[0];
  const auto& add = subgraph->operators[1];
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[split->opcode_index].get()),
            BuiltinOperator_SPLIT);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[add->opcode_index].get()),
            BuiltinOperator_ADD);

  // There should be 5 tensors: input, output, split, split/split_dim, split:1.
  // Tensor indices could be different between original and quantized.
  EXPECT_EQ(subgraph->tensors.size(), 5);
  const int input_idx = 0;
  EXPECT_EQ(subgraph->tensors[input_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[input_idx]->name, "input");
  EXPECT_EQ(subgraph->tensors[input_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[input_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[input_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[input_idx]->quantization->zero_point[0],
                  -128);
  const int output_idx = 4;
  EXPECT_EQ(subgraph->tensors[output_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[output_idx]->name, "output");
  EXPECT_EQ(subgraph->tensors[output_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[output_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[output_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[output_idx]->quantization->zero_point[0],
                  -128);
  const int split0_idx = 2;
  EXPECT_EQ(subgraph->tensors[split0_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[split0_idx]->name, "split;split:1");
  EXPECT_EQ(subgraph->tensors[split0_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[split0_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[split0_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[split0_idx]->quantization->zero_point[0],
                  -128);
  const int split1_idx = 3;
  EXPECT_EQ(subgraph->tensors[split1_idx]->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[split1_idx]->name, "split;split:11");
  EXPECT_EQ(subgraph->tensors[split1_idx]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[split1_idx]->quantization->zero_point.size(), 1);
  EXPECT_FLOAT_EQ(subgraph->tensors[split1_idx]->quantization->scale[0], 1.0);
  EXPECT_FLOAT_EQ(subgraph->tensors[split1_idx]->quantization->zero_point[0],
                  -128);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SPLIT);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeConvModel2Test : public QuantizeModelTest,
                               public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConvModel2Test() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
    auto& subgraph = model_.subgraphs[0];
    auto* input = subgraph->tensors[subgraph->inputs[0]].get();
    auto* output = subgraph->tensors[subgraph->outputs[0]].get();
    input->quantization = std::make_unique<QuantizationParametersT>();
    output->quantization = std::make_unique<QuantizationParametersT>();
    input->quantization->min.push_back(0.0);
    output->quantization->min.push_back(0.0);
    input->quantization->max.push_back(6.0);
    output->quantization->max.push_back(6.0);
  }

  TensorType tensor_type_;
};

INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test,
                         testing::ValuesIn({TensorType_INT8}));

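// Verifies per-channel quantization of the Conv2D weights and checks that the
// bias scale equals input_scale * per_channel_weight_scale for every output
// channel, and that the dequantized weights and biases stay within half a
// scale step of the original float values.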
TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto conv_op = subgraph->operators[0].get();
  const int input_tensor_idx = 0;
  const int weights_tensor_idx = 1;
  const int bias_tensor_index = 2;
  const int output_tensor_idx = 0;
  const auto bias_tensor =
      subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
  const auto input_tensor =
      subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
  const auto weights_tensor =
      subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
                                   ? TensorType_INT32
                                   : TensorType_INT64);
  EXPECT_EQ(input_tensor->type, tensor_type_);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
  ASSERT_TRUE(bias_tensor->quantization);
  ASSERT_TRUE(weights_tensor->quantization);
  const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
  const std::vector<float>& weights_scales =
      weights_tensor->quantization->scale;
  const std::vector<int64_t>& weights_zero_points =
      weights_tensor->quantization->zero_point;
  const int out_channel_size = weights_tensor->shape[0];
  ASSERT_EQ(bias_scales.size(), out_channel_size);
  ASSERT_EQ(weights_scales.size(), out_channel_size);
  ASSERT_EQ(weights_zero_points.size(), out_channel_size);
  ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
  ASSERT_EQ(output_tensor->quantization->scale.size(), 1);

  const float eps = 1e-7;

  // Bias scale should be input * per_channel_weight_scale.
  for (size_t i = 0; i < out_channel_size; i++) {
    EXPECT_NEAR(bias_scales[i],
                input_tensor->quantization->scale[0] * weights_scales[i], eps);
  }

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  auto control_size = tensor_type_ == TensorType_INT8
                          ? sizeof(int32_t) * bias_tensor->shape[0]
                          : sizeof(int64_t) * bias_tensor->shape[0];

  const auto float_op =
      readonly_model_->subgraphs()->Get(0)->operators()->Get(0);
  const auto original_bias_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(2));
  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(original_bias_tensor->buffer());
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  if (tensor_type_ == TensorType_INT8) {
    int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
  const auto original_weights_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(1));
  const auto original_weights_buffer =
      readonly_model_->buffers()->Get(original_weights_tensor->buffer());
  const int8_t* weight_values =
      reinterpret_cast<int8_t*>(weights_buffer->data.data());
  const float* weights_float_buffer =
      reinterpret_cast<const float*>(original_weights_buffer->data()->data());
  ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
            original_weights_buffer->data()->size());
  int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
  for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
    for (size_t j = 0; j < num_values_in_channel; j++) {
      size_t element_idx = channel_idx * out_channel_size + j;
      auto scale = weights_scales[channel_idx];
      auto zero_point = weights_zero_points[channel_idx];
      auto dequantized_value = weight_values[element_idx] * scale;
      EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx],
                  scale / 2);
      EXPECT_EQ(zero_point, 0);
    }
  }

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

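// Same checks as above, but with per-channel quantization disabled: the
// weights and bias are expected to carry a single per-tensor scale and zero
// point instead of one per output channel.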
TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, /*disable_per_channel=*/true, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto conv_op = subgraph->operators[0].get();
  const int input_tensor_idx = 0;
  const int weights_tensor_idx = 1;
  const int bias_tensor_index = 2;
  const int output_tensor_idx = 0;
  const auto bias_tensor =
      subgraph->tensors[conv_op->inputs[bias_tensor_index]].get();
  const auto input_tensor =
      subgraph->tensors[conv_op->inputs[input_tensor_idx]].get();
  const auto weights_tensor =
      subgraph->tensors[conv_op->inputs[weights_tensor_idx]].get();
  const auto output_tensor =
      subgraph->tensors[conv_op->outputs[output_tensor_idx]].get();

  EXPECT_EQ(bias_tensor->type, tensor_type_ == TensorType_INT8
                                   ? TensorType_INT32
                                   : TensorType_INT64);
  EXPECT_EQ(input_tensor->type, tensor_type_);
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);

  ASSERT_TRUE(weights_tensor->quantization);
  ASSERT_TRUE(bias_tensor->quantization);
  ASSERT_TRUE(weights_tensor->quantization);
  const std::vector<float>& bias_scales = bias_tensor->quantization->scale;
  const std::vector<float>& weights_scales =
      weights_tensor->quantization->scale;
  const std::vector<int64_t>& weights_zero_points =
      weights_tensor->quantization->zero_point;

  const int out_channel_size = 1;
  ASSERT_EQ(bias_scales.size(), out_channel_size);
  ASSERT_EQ(weights_scales.size(), out_channel_size);
  ASSERT_EQ(weights_zero_points.size(), out_channel_size);
  ASSERT_EQ(input_tensor->quantization->scale.size(), 1);
  ASSERT_EQ(output_tensor->quantization->scale.size(), 1);

  const float eps = 1e-7;

  // Bias scale should be input * per_channel_weight_scale.
  for (size_t i = 0; i < out_channel_size; i++) {
    EXPECT_NEAR(bias_scales[i],
                input_tensor->quantization->scale[0] * weights_scales[i], eps);
  }

  const auto bias_buffer = model_.buffers[bias_tensor->buffer].get();
  auto control_size = tensor_type_ == TensorType_INT8
                          ? sizeof(int32_t) * bias_tensor->shape[0]
                          : sizeof(int64_t) * bias_tensor->shape[0];

  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto float_op =
      readonly_model_->subgraphs()->Get(0)->operators()->Get(0);
  const auto original_bias_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(2));
  ASSERT_EQ(bias_buffer->data.size(), control_size);
  const auto original_bias_buffer =
      readonly_model_->buffers()->Get(original_bias_tensor->buffer());
  const float* bias_float_buffer =
      reinterpret_cast<const float*>(original_bias_buffer->data()->data());

  if (tensor_type_ == TensorType_INT8) {
    int32_t* bias_values = reinterpret_cast<int32_t*>(bias_buffer->data.data());
    for (size_t i = 0; i < out_channel_size; i++) {
      auto dequantized_value = bias_values[i] * bias_scales[i];
      EXPECT_NEAR(dequantized_value, bias_float_buffer[i], bias_scales[i] / 2);
    }
  }

  const auto weights_buffer = model_.buffers[weights_tensor->buffer].get();
  const auto original_weights_tensor =
      readonly_model_->subgraphs()->Get(0)->tensors()->Get(
          float_op->inputs()->Get(1));
  const auto original_weights_buffer =
      readonly_model_->buffers()->Get(original_weights_tensor->buffer());
  const int8_t* weight_values =
      reinterpret_cast<int8_t*>(weights_buffer->data.data());
  const float* weights_float_buffer =
      reinterpret_cast<const float*>(original_weights_buffer->data()->data());
  ASSERT_EQ(sizeof(float) * weights_buffer->data.size(),
            original_weights_buffer->data()->size());
  int num_values_in_channel = weights_buffer->data.size() / out_channel_size;
  for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) {
    for (size_t j = 0; j < num_values_in_channel; j++) {
      size_t element_idx = channel_idx * out_channel_size + j;
      auto scale = weights_scales[channel_idx];
      auto zero_point = weights_zero_points[channel_idx];
      auto dequantized_value = weight_values[element_idx] * scale;
      EXPECT_NEAR(dequantized_value, weights_float_buffer[element_idx],
                  scale / 2);
      EXPECT_EQ(zero_point, 0);
    }
  }

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_CONV_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}


class QuantizeSoftmaxTest : public QuantizeModelTest {
 protected:
  QuantizeSoftmaxTest() {
    input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  // Model has a single softmax op.
  ASSERT_EQ(op->opcode_index, 0);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SOFTMAX);

  ASSERT_EQ(op->inputs.size(), 1);
  ASSERT_EQ(op->outputs.size(), 1);
  auto float_graph = readonly_model_->subgraphs()->Get(0);

  // Verify input.
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  // Verify output.
  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);

  ASSERT_EQ(output_quant_params->scale.size(), 1);
  ASSERT_EQ(output_quant_params->zero_point.size(), 1);
  ASSERT_EQ(1.0f / 256.0f, output_quant_params->scale[0]);
  ASSERT_EQ(-128, output_quant_params->zero_point[0]);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_SOFTMAX);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeAvgPoolTest : public QuantizeModelTest {
 protected:
  QuantizeAvgPoolTest() {
    input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  // Model has a single AveragePool op.
  ASSERT_EQ(op->opcode_index, 0);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_AVERAGE_POOL_2D);

  ASSERT_EQ(op->inputs.size(), 1);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->scale.size(), 1);

  // Make sure the input min/maxes are propagated to outputs.
  EXPECT_EQ(input_quant_params->scale[0], output_quant_params->scale[0]);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_AVERAGE_POOL_2D);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
 protected:
  QuantizeMultiInputAddWithReshapeTest() {
    input_model_ = ReadModel(internal::kMultiInputAddWithReshape);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);

  ASSERT_EQ(kTfLiteOk, status);

  // Verify Reshape is quantized.
  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[1].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_RESHAPE);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  auto float_op = float_graph->operators()->Get(1);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->outputs()->Get(0))->type(),
            TensorType_FLOAT32);

  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
  auto float_input_quant_params =
      float_graph->tensors()->Get(op->inputs[0])->quantization();
  auto input_quant_params =
      subgraph->tensors[op->inputs[0]]->quantization.get();
  VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                    *input_quant_params);

  auto float_output_quant_params =
      float_graph->tensors()->Get(float_op->outputs()->Get(0))->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->scale.size(), 1);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ADD);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
            BuiltinOperator_RESHAPE);
  EXPECT_EQ(model_.operator_codes[1]->version, 1);
}

TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify ADD is quantized.
  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_ADD);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  auto float_op = float_graph->operators()->Get(0);
  const int float_input0_idx = float_op->inputs()->Get(0);
  const int float_input1_idx = float_op->inputs()->Get(1);
  const int float_output_idx = float_op->outputs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(float_input0_idx)->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(float_input1_idx)->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(float_output_idx)->type(),
            TensorType_FLOAT32);

  for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
    EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
              TensorType_INT8);
    auto float_input_quant_params =
        float_graph->tensors()
            ->Get(float_op->inputs()->Get(input_idx))
            ->quantization();
    auto input_quant_params =
        subgraph->tensors[op->inputs[input_idx]]->quantization.get();
    VerifyAsymmetricQuantizationScale(*float_input_quant_params,
                                      *input_quant_params);
  }

  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
  auto float_output_quant_params =
      float_graph->tensors()->Get(op->outputs[0])->quantization();
  auto output_quant_params =
      subgraph->tensors[op->outputs[0]]->quantization.get();
  ASSERT_EQ(float_output_quant_params->min()->size(), 1);
  ASSERT_EQ(float_output_quant_params->max()->size(), 1);
  ASSERT_EQ(output_quant_params->scale.size(), 1);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ADD);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()),
            BuiltinOperator_RESHAPE);
  EXPECT_EQ(model_.operator_codes[1]->version, 1);
}

class QuantizeConstInputTest : public QuantizeModelTest,
                               public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeConstInputTest() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kConstInputAddModel);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  TensorType tensor_type_;
};
INSTANTIATE_TEST_SUITE_P(QuantizeConstInputTestInst, QuantizeConstInputTest,
                         testing::ValuesIn({TensorType_INT8}));

TEST_P(QuantizeConstInputTest, VerifyConstOpInput) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Verify ConstOp is quantized.
  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_ADD);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
            TensorType_FLOAT32);

  for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
    EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
              tensor_type_);
  }

  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, tensor_type_);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ADD);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

class QuantizeArgMaxTest : public QuantizeModelTest {
 protected:
  QuantizeArgMaxTest() {
    input_model_ = ReadModel(internal::kModelWithArgMaxOp);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_ARG_MAX);

  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  auto float_op = float_graph->operators()->Get(0);
  // Verify ArgMax input is quantized.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);

  // Verify that the ArgMax axis input keeps its original type.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(1))->type(),
            subgraph->tensors[op->inputs[1]].get()->type);

  // The output of ArgMax should still be the same type.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->outputs()->Get(0))->type(),
            subgraph->tensors[op->outputs[0]].get()->type);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_ARG_MAX);
  EXPECT_EQ(model_.operator_codes[0]->version, 2);
}

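// The LSTM, SVDF, and unidirectional-sequence-LSTM tests below compare the
// quantized model against a pre-quantized golden flatbuffer via
// ExpectSameModels rather than checking individual tensors.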
class QuantizeLSTMTest : public QuantizeModelTest {
 protected:
  QuantizeLSTMTest() {
    input_model_ = ReadModel(internal::kLstmCalibrated);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeLSTMTest, VerifyLSTM) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/true, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model = ReadModel(internal::kLstmQuantized);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeLSTM2Test : public QuantizeModelTest {
 protected:
  QuantizeLSTM2Test() {
    input_model_ = ReadModel(internal::kLstmCalibrated2);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model = ReadModel(internal::kLstmQuantized2);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
 protected:
  QuantizeUnidirectionalSequenceLSTMTest() {
    input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
       VerifyUnidirectionalSequenceLSTM) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model =
      ReadModel(internal::kUnidirectionalSequenceLstmQuantized);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeSVDFTest : public QuantizeModelTest {
 protected:
  QuantizeSVDFTest() {
    input_model_ = ReadModel(internal::kSvdfCalibrated);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeSVDFTest, VerifySVDF) {
  // Quantize model.
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  // Read expected model.
  auto expected_fb_model = ReadModel(internal::kSvdfQuantized);
  auto expected_read_only_model = expected_fb_model->GetModel();
  ModelT expected_model;
  expected_read_only_model->UnPackTo(&expected_model);

  ExpectSameModels(model_, expected_model);
}

class QuantizeFCTest : public QuantizeModelTest {
 protected:
  QuantizeFCTest() {
    input_model_ = ReadModel(internal::kModelWithFCOp);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeFCTest, VerifyFC) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, TensorType_INT8, TensorType_INT8,
      /*allow_float=*/false, TensorType_INT8, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);

  const auto& subgraph = model_.subgraphs[0];
  auto op = subgraph->operators[0].get();
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_FULLY_CONNECTED);

  ASSERT_EQ(op->inputs.size(), 3);
  ASSERT_EQ(op->outputs.size(), 1);

  auto float_graph = readonly_model_->subgraphs()->Get(0);
  // Verify FC input and weight is quantized.
  auto float_op = float_graph->operators()->Get(0);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(1))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[1]].get()->type, TensorType_INT8);

  // Verify FC bias should be int32 quantized.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->inputs()->Get(2))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->inputs[2]].get()->type, TensorType_INT32);

  // The output of FC should be quantized.
  ASSERT_EQ(float_graph->tensors()->Get(float_op->outputs()->Get(0))->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);

  // check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()),
            BuiltinOperator_FULLY_CONNECTED);
  EXPECT_EQ(model_.operator_codes[0]->version, 5);
}

class QuantizeCustomOpTest
    : public QuantizeModelTest,
      public ::testing::WithParamInterface<tflite::TensorType> {
 protected:
  QuantizeCustomOpTest() {
    input_model_ = ReadModel(internal::kModelMixed);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, GetParam(), GetParam(),
      /*allow_float=*/true, GetParam(), &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  auto float_graph = readonly_model_->subgraphs()->Get(0);
  // The original model is reshape->custom->custom->squeeze.
  ASSERT_EQ(float_graph->operators()->size(), 4);
  // The resulting model should be:
  // reshape->dequantize->custom->custom->quantize->squeeze.
  ASSERT_EQ(subgraph->operators.size(), 6);
  const std::vector<BuiltinOperator> op_codes = {
      BuiltinOperator_RESHAPE, BuiltinOperator_DEQUANTIZE,
      BuiltinOperator_CUSTOM, BuiltinOperator_CUSTOM,
      BuiltinOperator_QUANTIZE, BuiltinOperator_SQUEEZE};
  const std::vector<TensorType> op_input_types = {
      GetParam(), GetParam(), TensorType_FLOAT32,
      TensorType_FLOAT32, TensorType_FLOAT32, GetParam()};
  for (int i = 0; i < subgraph->operators.size(); ++i) {
    OperatorT* op = subgraph->operators[i].get();
    ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
              op_codes[i]);
    ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]);
  }
}

INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest,
                         ::testing::Values(TensorType_INT8));

class QuantizePackTest : public QuantizeModelTest {
 protected:
  QuantizePackTest() {
    input_model_ = ReadModel(internal::kModelPack);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizePackTest, VerifyPack) {
  auto status = QuantizeModel(&builder_, &model_, &error_reporter_);

  ASSERT_EQ(kTfLiteOk, status);

  const auto subgraph = model_.subgraphs[0].get();

  // The model should only have 3 inputs and 1 output.
  EXPECT_EQ(subgraph->inputs.size(), 3);
  EXPECT_EQ(subgraph->outputs.size(), 1);

  const auto& op1 = subgraph->operators[1].get();
  const auto& op2 = subgraph->operators[2].get();
  const auto& op3 = subgraph->operators[3].get();
  const auto& op4 = subgraph->operators[4].get();

  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op1->opcode_index].get()),
            BuiltinOperator_QUANTIZE);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op2->opcode_index].get()),
            BuiltinOperator_QUANTIZE);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op3->opcode_index].get()),
            BuiltinOperator_PACK);
  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op4->opcode_index].get()),
            BuiltinOperator_DEQUANTIZE);

  const auto& pack_input0 = subgraph->tensors[op3->inputs[0]].get();
  const auto& pack_input1 = subgraph->tensors[op3->inputs[1]].get();
  const auto& pack_input2 = subgraph->tensors[op3->inputs[2]].get();

  const auto& pack_output = subgraph->tensors[op3->outputs[0]].get();

  // Check quantization parameters for input and output.
  EXPECT_FLOAT_EQ(pack_input0->quantization->scale[0],
                  pack_input1->quantization->scale[0]);
  EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
                  pack_input2->quantization->scale[0]);
  EXPECT_FLOAT_EQ(pack_input0->quantization->zero_point[0],
                  pack_input1->quantization->zero_point[0]);
  EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
                  pack_input2->quantization->zero_point[0]);

  EXPECT_FLOAT_EQ(pack_input1->quantization->scale[0],
                  pack_output->quantization->scale[0]);
  EXPECT_FLOAT_EQ(pack_input1->quantization->zero_point[0],
                  pack_output->quantization->zero_point[0]);

  // Check type of input and output.
  EXPECT_EQ(pack_output->type, TensorType_INT8);
  EXPECT_EQ(pack_input0->type, TensorType_INT8);
  EXPECT_EQ(pack_input1->type, TensorType_INT8);
  EXPECT_EQ(pack_input2->type, TensorType_INT8);
}

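// Parameterized over models containing a MINIMUM or MAXIMUM op. Verifies that
// the op is wrapped in QUANTIZE/DEQUANTIZE and that both inputs and the
// output share the same quantization parameters.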
class QuantizeMinimumMaximumTest
    : public QuantizeModelTest,
      public testing::WithParamInterface<const char*> {
 protected:
  QuantizeMinimumMaximumTest() {
    input_model_ = ReadModel(GetParam());
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
  auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
  ASSERT_EQ(kTfLiteOk, status);
  const auto& subgraph = model_.subgraphs[0];
  // Check that the first op is Quantize and the last is Dequant.
  const auto& quant_op = subgraph->operators[0];
  const auto& dequant_op = subgraph->operators[subgraph->operators.size() - 1];
  const int32_t quant_idx = quant_op->opcode_index;
  const int32_t dequant_idx = dequant_op->opcode_index;
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()),
            BuiltinOperator_QUANTIZE);
  EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()),
            BuiltinOperator_DEQUANTIZE);

  const auto& op = subgraph->operators[1].get();

  // Check that we have MINIMUM or MAXIMUM operator.
  auto op_builtin_code =
      GetBuiltinCode(model_.operator_codes[op->opcode_index].get());
  ASSERT_TRUE(op_builtin_code == tflite::BuiltinOperator_MINIMUM ||
              op_builtin_code == tflite::BuiltinOperator_MAXIMUM);

  // Check that we have two inputs and one output.
  ASSERT_EQ(op->inputs.size(), 2);
  ASSERT_EQ(op->outputs.size(), 1);

  // Check that all is quantized.
  auto output = subgraph->tensors[op->outputs[0]].get();
  auto input1 = subgraph->tensors[op->inputs[0]].get();
  auto input2 = subgraph->tensors[op->inputs[1]].get();

  EXPECT_EQ(output->type, TensorType_INT8);
  EXPECT_EQ(input1->type, TensorType_INT8);
  EXPECT_EQ(input2->type, TensorType_INT8);

  // Check that the quantization params of the minimum/maximum inputs match
  // after requantization.
  EXPECT_EQ(input1->quantization->scale, input2->quantization->scale);
  EXPECT_EQ(input1->quantization->zero_point, input2->quantization->zero_point);

  // Check the input quantization params match the output ones.
  EXPECT_EQ(output->quantization->scale, input1->quantization->scale);
  EXPECT_EQ(output->quantization->zero_point, input1->quantization->zero_point);
  EXPECT_EQ(output->quantization->scale, input2->quantization->scale);
  EXPECT_EQ(output->quantization->zero_point, input2->quantization->zero_point);
}

INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
                         testing::ValuesIn({internal::kModelWithMinimumOp,
                                            internal::kModelWithMaximumOp}));

class QuantizeUnpackTest : public QuantizeModelTest {
 protected:
  QuantizeUnpackTest() {
    input_model_ = ReadModel(internal::kModelWithUnpack);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeUnpackTest, VerifyUnpack) {
  auto status = QuantizeModel(&builder_, &model_, &error_reporter_);

  ASSERT_EQ(kTfLiteOk, status);

  const auto subgraph = model_.subgraphs[0].get();
  auto op = subgraph->operators[1].get();

  auto float_graph = readonly_model_->subgraphs()->Get(0);

  ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()),
            BuiltinOperator_UNPACK);

  // Get unpack input and output tensors
  auto unpack_input = subgraph->tensors[op->inputs[0]].get();
  auto unpack_output_0 = subgraph->tensors[op->outputs[0]].get();
  auto unpack_output_1 = subgraph->tensors[op->outputs[1]].get();

  // Verify Unpack input is quantized.
  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
            TensorType_FLOAT32);
  EXPECT_EQ(unpack_input->type, TensorType_INT8);

  // The model should only have one input and 2 outputs.
  EXPECT_EQ(subgraph->inputs.size(), 1);
  EXPECT_EQ(subgraph->outputs.size(), 2);

1322 // Ensure quantization parameters before and after unpack
1323 // are preserved after quantization for all outputs of
1324 // unpack.
1325 EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1326 unpack_output_0->quantization->scale[0]);
1327 EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0],
1328 unpack_output_1->quantization->scale[0]);
1329 EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1330 unpack_output_0->quantization->zero_point[0]);
1331 EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0],
1332 unpack_output_1->quantization->zero_point[0]);
1333 }
1334
class QuantizeBroadcastToModelTest
    : public QuantizeModelTest,
      public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeBroadcastToModelTest() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
  TensorType tensor_type_;
};

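// Only int8 activations are currently exercised for BroadcastTo.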
INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst,
                         QuantizeBroadcastToModelTest,
                         testing::ValuesIn({TensorType_INT8}));

TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
  const int32_t subgraph_idx = 0;
  const auto& subgraph = model_.subgraphs[subgraph_idx];
  const auto& readonly_subgraph =
      readonly_model_->subgraphs()->Get(subgraph_idx);

  // There should be a single broadcast_to op.
  EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
  EXPECT_EQ(subgraph->operators.size(), 1);
  const auto& broadcast_to = subgraph->operators[0];
  EXPECT_EQ(model_.operator_codes[broadcast_to->opcode_index]->builtin_code,
            BuiltinOperator_BROADCAST_TO);

  // There should be 3 tensors: input, output, and BroadcastTo/shape.
  EXPECT_EQ(subgraph->tensors.size(), 3);

  // Input tensor.
  EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[0]->name, "input_1");
  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);

  // Output tensor. The name given in the generated .bin test file is
  // 'Identity' and should be preserved.
  EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[2]->name, "Identity");
  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);

  // The BroadcastTo shape is of type INT32 and should not be quantized.
  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
  EXPECT_EQ(subgraph->tensors[1]->name,
            "model/tf.broadcast_to/BroadcastTo/shape");
  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);

  // Check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(model_.operator_codes[0]->builtin_code,
            BuiltinOperator_BROADCAST_TO);
  EXPECT_EQ(model_.operator_codes[0]->version, 3);
}

class QuantizeGatherNDModelTest
    : public QuantizeModelTest,
      public testing::WithParamInterface<TensorType> {
 protected:
  QuantizeGatherNDModelTest() {
    tensor_type_ = GetParam();
    input_model_ = ReadModel(internal::kModelWithGatherNDOp);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }

  TensorType tensor_type_;
};

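// As with BroadcastTo, only int8 activations are exercised for GatherND.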
INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst,
                         QuantizeGatherNDModelTest,
                         testing::ValuesIn({TensorType_INT8}));

TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
  auto status = QuantizeModelAllOperators(
      &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false,
      tensor_type_, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
  const int32_t subgraph_idx = 0;
  const auto& subgraph = model_.subgraphs[subgraph_idx];
  const auto& readonly_subgraph =
      readonly_model_->subgraphs()->Get(subgraph_idx);

  // There should be a single gather_nd op.
  EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
  EXPECT_EQ(subgraph->operators.size(), 1);
  const auto& gather_nd = subgraph->operators[0];
  EXPECT_EQ(model_.operator_codes[gather_nd->opcode_index]->builtin_code,
            BuiltinOperator_GATHER_ND);

  // There should be 3 tensors: input, output, and indices.
  EXPECT_EQ(subgraph->tensors.size(), 3);

  // Input tensor.
  EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[0]->name, "input");
  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);

  // Output tensor.
  EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
  EXPECT_EQ(subgraph->tensors[2]->name, "output");
  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);

  // The gather indices are of type INT32 and should not be quantized.
  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
  EXPECT_EQ(subgraph->tensors[1]->name, "indices");
  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);

  // Check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_GATHER_ND);
  EXPECT_EQ(model_.operator_codes[0]->version, 1);
}

class QuantizeWhereModelTest : public QuantizeModelTest {
 protected:
  QuantizeWhereModelTest() {
    input_model_ = ReadModel(internal::kModelWithWhereOp);
    readonly_model_ = input_model_->GetModel();
    readonly_model_->UnPackTo(&model_);
  }
};

TEST_F(QuantizeWhereModelTest, QuantizeWhere) {
  // The Where operator takes a BOOL tensor as input and outputs INT64
  // indices; neither should be quantized.
  auto status = QuantizeModel(&builder_, &model_, TensorType_BOOL,
                              TensorType_INT64, &error_reporter_);
  EXPECT_EQ(status, kTfLiteOk);

  // There is only one subgraph.
  const int32_t subgraph_idx = 0;
  const auto& subgraph = model_.subgraphs[subgraph_idx];
  const auto& readonly_subgraph =
      readonly_model_->subgraphs()->Get(subgraph_idx);

  // There should be a single where op.
  EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
  EXPECT_EQ(subgraph->operators.size(), 1);
  const auto& where = subgraph->operators[0];
  EXPECT_EQ(model_.operator_codes[where->opcode_index]->builtin_code,
            BuiltinOperator_WHERE);

  // There should be 2 tensors: input and output.
  EXPECT_EQ(subgraph->tensors.size(), 2);

  // Check the input tensor type and ensure it was not quantized.
  EXPECT_EQ(subgraph->tensors[0]->type, TensorType_BOOL);
  EXPECT_EQ(subgraph->tensors[0]->name, "input");
  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 0);
  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 0);

  // Check the output (indices) tensor type and ensure it was not quantized.
  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT64);
  EXPECT_EQ(subgraph->tensors[1]->name, "indices");
  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);

  // Check op and versioning.
  EXPECT_EQ(model_.operator_codes.size(), 1);
  EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_WHERE);
  EXPECT_EQ(model_.operator_codes[0]->version, 1);
}

}  // namespace
}  // namespace optimize
}  // namespace tflite

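// Test entry point. The required --test_model_file flag should point at one
// of the .bin test models; its directory is stored in g_test_model_dir so the
// fixtures above can locate the models they load.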
int main(int argc, char** argv) {
  tensorflow::string model_file;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("test_model_file", &model_file,
                       "Path to test tflite model file."),
  };

  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    std::cerr << "Required test_model_file\n";
    std::abort();
  }
  g_test_model_dir =
      new tensorflow::string(tensorflow::io::Dirname(model_file));
  ::tensorflow::port::InitMain(argv[0], &argc, &argv);
  return RUN_ALL_TESTS();
}