# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py."""

import io
import logging
import os
import tempfile

from absl.testing import parameterized
import numpy as np
from tensorflow import keras

from tensorflow.lite.python import conversion_metadata_schema_py_generated as metadata_fb
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import util
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.util import get_conversion_metadata
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph


class LiteTest(test_util.TensorFlowTestCase):
  """Base class of all the tests in this module."""


class TestModels(LiteTest):
  """Base class providing shared assertion helpers for converter tests."""

  def assertValidDebugInfo(self, debug_info):
    """Verify the DebugInfo is valid.

    Checks that the debug info's file table references this test file (and
    not the v2 test file), independent of how the graph nodes were created.
    """
    file_names = set()
    for file_path in debug_info.files:
      file_names.add(os.path.basename(file_path))
    # To make the test independent on how the nodes are created, we only assert
    # the name of this test file.
    self.assertIn('lite_test.py', file_names)
    self.assertNotIn('lite_v2_test.py', file_names)


class FromConstructor(TestModels):
  """Tests constructing TFLiteConverter directly from its constructor."""

  # Tests invalid constructors using a dummy value for the GraphDef.
  def testInvalidConstructor(self):
    message = (
        'If input_tensors and output_tensors are None, both '
        'input_arrays_with_shape and output_arrays|control_output_arrays must '
        'be defined.')

    # `output_arrays` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(
          None, None, [], input_arrays_with_shape=[('input', [3,
                                                              9])]).convert()
    self.assertEqual(message, str(error.exception))

    # `input_arrays_with_shape` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(None, [], None, output_arrays=['output']).convert()
    self.assertEqual(message, str(error.exception))

  # Tests valid constructors using a dummy value for the GraphDef.
  def testValidConstructor(self):
    converter = lite.TFLiteConverter(
        None,
        None,
        None,
        input_arrays_with_shape=[('input', [3, 9])],
        output_arrays=['output'])
    self.assertFalse(converter._has_valid_tensors())
    self.assertEqual(converter.get_input_arrays(), ['input'])

    # Batch size cannot be overridden when shapes come from
    # input_arrays_with_shape rather than real tensors.
    with self.assertRaises(ValueError) as error:
      converter._set_batch_size(1)
    self.assertEqual(
        'The batch size cannot be set for this model. Please use '
        'input_shapes parameter.', str(error.exception))

    converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
    self.assertTrue(converter._has_valid_tensors())

  def testRedundantArgumentsWarning(self):
    """Test if the warning message when there are redundant arguments."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      out_tensor = math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Capture root-logger output so the redundancy warnings can be asserted on.
    log = io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    # Both tensor arguments and array arguments are given; the array ones
    # are redundant and should trigger warnings.
    converter = lite.TFLiteConverter(frozen_graph_def, [in_tensor],
                                     [out_tensor],
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])

    input_warning_message = 'input_arrays_with_shape will be ignored'
    output_warning_message = 'output_arrays will be ignored'

    # Convert model and ensure model is not None.
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    self.assertIn(input_warning_message, log.getvalue())
    self.assertIn(output_warning_message, log.getvalue())
    logging.root.removeHandler(handler)

  def testShapeOverriding(self):
    """Test a shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('in_tensor', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testPartialShapeOverriding(self):
    """Test a partial shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_a')
      in_tensor_b = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_b')
      math_ops.add(in_tensor_a, in_tensor_b, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Only in_tensor_a is given a shape; in_tensor_b remains unhandled.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor_a', [2, 16, 16, 3])], ['add'])
    # There is an unhandled Placeholder op.
    with self.assertRaises(ConverterError):
      converter.convert()

  def testInvalidShapeOverriding(self):
    """Test an invalid shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # 'wrong_tensor' does not exist in the graph, so conversion must fail.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('wrong_tensor', [2, 16, 16, 3])],
                                     ['add'])
    with self.assertRaises(ConverterError):
      converter.convert()


class FromSessionTest(TestModels, parameterized.TestCase):
  """Tests TFLiteConverter.from_session with TF1 graphs and sessions."""

  def testFloatModel(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatModelQuantizedInput(self):
    """Tests a float model whose input is quantized to uint8."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_input_type = dtypes.uint8
    converter.inference_type = dtypes.float32
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])  # float

  def testForgottenCallToAllocateTensors(self):
    """Tests that set_tensor raises if allocate_tensors was never called."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_index = interpreter.get_input_details()[0]['index']
    dummy_tensor = np.ones(shape=[1, 16, 16, 3], dtype=np.float32)
    with self.assertRaises(ValueError):
      interpreter.set_tensor(input_index, dummy_tensor)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  def testIntegerQuantizationWithUnsupportedOps(self,
                                                is_int_only,
                                                is_int16_quantize,
                                                inference_input_output_type,
                                                enable_mlir_quantizer=False):
    """Tests integer quantization when the graph contains non-quantizable ops."""
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      # ceil kernel does not support int8 nor int16 types neither.
      left = math_ops.ceil(in_tensor_a)
      out_tensor_b = math_ops.tanh(in_tensor_b)
      add = math_ops.add(left, out_tensor_b)
      # ceil kernel does not support int8 nor int16 types neither.
      out_tensor_a = math_ops.ceil(add)
      sess = session.Session()

    def calibration_gen():
      # Five random samples per input tensor for post-training calibration.
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
        ]

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_a, in_tensor_b], [out_tensor_a, out_tensor_b])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype
    # Allow float32 for fallback on non-quantizable op.
    expected_ceil_dtype = (
        expected_dtype if enable_mlir_quantizer else dtypes.float32)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(input_details[1]['dtype'], expected_dtype)
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(output_details[1]['dtype'], expected_dtype)

  @parameterized.named_parameters(
      ('_PerChannelQuant', False, False),
      ('_PerChannelMlirQuant', False, True),
      ('_PerTensorQuant', True, False),
      ('_PerTensorMlirQuant', True, True),
      ('_PerChannelMlirDynamicRangeQuant', False, False, False),
      ('_PerTensorMlirDynamicRangeQuant', True, False, False))
  def testDisablePerChannelQuantization(self,
                                        disable_per_channel=False,
                                        enable_mlir_quantizer=False,
                                        representative_dataset=True):
    """Tests per-channel vs. per-tensor quantization of conv filters."""
    k_conv_name = 'Conv2D1'
    # Dynamic range quant requires total num elements of filters > 1024.
    k_num_filters = 38
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel(
          k_num_filters)
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    if representative_dataset:
      quantized_converter.representative_dataset = calibration_gen
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    detail = next((d for d in interpreter.get_tensor_details()
                   if d['name'] == k_conv_name))
    quant_params = detail['quantization_parameters']
    # Per-tensor quantization yields one scale/zero-point; per-channel yields
    # one per output filter.
    expected_num_params = 1 if disable_per_channel else k_num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

  def testString(self):
    """Tests conversion of a model with string tensors."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.string_, input_details[0]['dtype'])
    self.assertAllEqual([4], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('Reshape', output_details[0]['name'])
    self.assertEqual(np.string_, output_details[0]['dtype'])
    self.assertAllEqual([2, 2], output_details[0]['shape'])
    # TODO(b/122659643): Test setting/getting string data via the python
    # interpreter API after support has been added.

  def testIntermediateInputArray(self):
    """Convert a model from an intermediate input array."""
    with ops.Graph().as_default():
      in_tensor_init = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      in_tensor_final = in_tensor_init + in_tensor_init
      out_tensor = in_tensor_final + in_tensor_final
      sess = session.Session()

    # Convert model using the intermediate tensor (not the placeholder) as
    # input, and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('add', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add_1', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testSizeNoneInvalid(self):
    """Tests that a fully-unknown input shape is rejected by the old converter."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
    # order to invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
                     str(error.exception))

  def testScalarValid(self):
    # Construct a graph using a scalar (empty shape) input.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test conversion with the scalar input shape.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEmpty(input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertEmpty(input_details[0]['shape'])

    # Validate inference using the scalar inputs/outputs.
    test_input = np.array(4.0, dtype=np.float32)
    expected_output = np.array(8.0, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_output, output_data)

  def testSizeInvalid(self):
    """Tests that None in a non-batch dimension is rejected by the old converter."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, None, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
    # invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'None is only supported in the 1st dimension. Tensor '
        '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
        str(error.exception))

  def testSizeNone(self):
    """Tests conversion and resizing of a dynamic (None) dimension."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, None, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test None after 1st dimension.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model. The unknown dimension appears as 1 in
    # 'shape' and as -1 in 'shape_signature'.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 1, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    # Resize tensor with strict checking.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    self.assertIn(
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.', str(error.exception))

    # Resize tensor and invoke.
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

    test_input = np.full([1, 16, 16, 3], 1.0, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])

    output_details = interpreter.get_output_details()
    self.assertAllEqual([1, -1, 16, 3], output_details[0]['shape_signature'])

  def testResizeTensorInputStrict(self):
    # Ensures that resize_tensor_input(strict=True) works as expected.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)

    # Resize incorrect value.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    self.assertIn(
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.', str(error.exception))

    # Resize correct value.
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

  def testBatchSizeValid(self):
    """Tests that a None batch dimension converts to batch size 1."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testBatchSizeNonZero(self):
    """Tests inputs whose leading dimension is not a batch dimension."""
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[None, 4], dtype=dtypes.float32, name='input1')
      in_tensor_2 = array_ops.placeholder(
          shape=[4, 10], dtype=dtypes.float32, name='input2')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('input1', input_details[0]['name'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual('input2', input_details[1]['name'])
    self.assertAllEqual([4, 10], input_details[1]['shape'])

  def testFreezeGraph(self):
    """Tests that variables are frozen into constants during conversion."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      var = variable_scope.get_variable(
          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
      # Get the second output to ensure freezing properly processes tensor names
      # like 'X:1'.
      out_tensor = nn_ops.top_k(in_tensor + var, name='top_k')[1]
      sess = session.Session()
      sess.run(_global_variables_initializer())

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('top_k:1', output_details[0]['name'])
    self.assertEqual(np.int32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 1], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testGraphviz(self):
    """Tests the GraphViz DOT output format."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.output_format = lite_constants.GRAPHVIZ_DOT
    graphviz_output = converter.convert()
    self.assertIsNotNone(graphviz_output)

  def testDumpGraphviz(self):
    """Tests dumping GraphViz .dot files alongside a normal conversion."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    graphviz_dir = self.get_temp_dir()
    converter.dump_graphviz_dir = graphviz_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure interpreter is able to allocate and check graphviz data.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    num_items_graphviz = len(os.listdir(graphviz_dir))
    self.assertIsNotNone(num_items_graphviz)
    self.assertIsNotNone(
        os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
    self.assertIsNotNone(
        os.path.exists(
            os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))

  def testDumpConversionSummary(self):
    """Tests that the new converter writes a conversion summary."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    self.assertNotEmpty(os.listdir(log_dir))

  def testDumpConversionSummaryWithOldConverter(self):
    """Tests that the old converter writes no conversion summary."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check nothing is generated under the conversion summary path.
    num_items_conversion_summary = len(os.listdir(log_dir))
    self.assertEqual(num_items_conversion_summary, 0)

  def testQuantizeDynamicRange(self):
    np.random.seed(0)
    with ops.Graph().as_default():
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
817 in_tensor_1 = array_ops.placeholder( 818 shape=[33, 33], dtype=dtypes.float32, name='inputA') 819 in_tensor_2 = constant_op.constant( 820 np.random.uniform(low=-10., high=10., size=(33, 33)), 821 shape=[33, 33], 822 dtype=dtypes.float32, 823 name='inputB') 824 out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') 825 sess = session.Session() 826 827 # Convert float model. 828 float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1], 829 [out_tensor]) 830 float_tflite_model = float_converter.convert() 831 self.assertIsNotNone(float_tflite_model) 832 833 # Convert quantized weights model. 834 quantized_converter = lite.TFLiteConverter.from_session( 835 sess, [in_tensor_1], [out_tensor]) 836 837 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 838 quantized_tflite_model = quantized_converter.convert() 839 self.assertIsNotNone(quantized_tflite_model) 840 841 # Ensure that the quantized weights tflite model is smaller. 842 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 843 844 def testQuantizeDynamicRangeDeprecatedPostTrainingQuantizeAttribute( 845 self): 846 with ops.Graph().as_default(): 847 in_tensor_1 = array_ops.placeholder( 848 shape=[33, 33], dtype=dtypes.float32, name='inputA') 849 in_tensor_2 = constant_op.constant( 850 np.random.uniform(low=-10., high=10., size=(33, 33)), 851 shape=[33, 33], 852 dtype=dtypes.float32, 853 name='inputB') 854 out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') 855 sess = session.Session() 856 857 quantized_converter = lite.TFLiteConverter.from_session( 858 sess, [in_tensor_1], [out_tensor]) 859 self.assertFalse(quantized_converter.post_training_quantize) 860 861 quantized_converter.post_training_quantize = True 862 self.assertTrue(quantized_converter.post_training_quantize) 863 self.assertEqual(quantized_converter.optimizations, [lite.Optimize.DEFAULT]) 864 865 quantized_tflite_model = quantized_converter.convert() 866 
self.assertIsNotNone(quantized_tflite_model) 867 868 def _getIntegerQuantizeModel(self, num_filters=16): 869 np.random.seed(0) 870 inp = array_ops.placeholder( 871 dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input') 872 conv = nn_ops.conv2d( 873 inp, 874 filter=array_ops.ones([3, 3, 3, num_filters]), 875 strides=[1, 1, 1, 1], 876 padding='SAME') 877 output = nn_ops.relu(conv, name='output') 878 879 def calibration_gen(): 880 for _ in range(5): 881 yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)] 882 883 return (inp, output, calibration_gen) 884 885 def testQuantizeInt8AllowFloat(self): 886 with ops.Graph().as_default(): 887 inp, output, calibration_gen = self._getIntegerQuantizeModel() 888 sess = session.Session() 889 890 # Convert float model. 891 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 892 float_tflite_model = float_converter.convert() 893 self.assertIsNotNone(float_tflite_model) 894 # Check the conversion metadata. 895 metadata = get_conversion_metadata(float_tflite_model) 896 self.assertIsNotNone(metadata) 897 self.assertEqual( 898 metadata.environment.tensorflowVersion.decode('utf-8'), 899 versions.__version__) 900 self.assertEqual(metadata.environment.apiVersion, 1) 901 self.assertEqual(metadata.environment.modelType, 902 metadata_fb.ModelType.TF_SESSION) 903 self.assertEqual(metadata.options.allowCustomOps, False) 904 self.assertEqual(metadata.options.enableSelectTfOps, False) 905 self.assertEqual(metadata.options.forceSelectTfOps, False) 906 self.assertAllEqual([], metadata.options.modelOptimizationModes) 907 908 # Convert quantized model. 909 quantized_converter = lite.TFLiteConverter.from_session( 910 sess, [inp], [output]) 911 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 912 quantized_converter.representative_dataset = calibration_gen 913 quantized_tflite_model = quantized_converter.convert() 914 self.assertIsNotNone(quantized_tflite_model) 915 # Check the conversion metadata. 
916 metadata = get_conversion_metadata(quantized_tflite_model) 917 self.assertIsNotNone(metadata) 918 self.assertAllEqual([metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER], 919 metadata.options.modelOptimizationModes) 920 921 # The default input and output types should be float. 922 interpreter = Interpreter(model_content=quantized_tflite_model) 923 interpreter.allocate_tensors() 924 input_details = interpreter.get_input_details() 925 self.assertLen(input_details, 1) 926 self.assertEqual(np.float32, input_details[0]['dtype']) 927 output_details = interpreter.get_output_details() 928 self.assertLen(output_details, 1) 929 self.assertEqual(np.float32, output_details[0]['dtype']) 930 931 # Ensure that the quantized weights tflite model is smaller. 932 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 933 934 @parameterized.named_parameters( 935 # Quantize model to Int8 936 ('UseTfliteBuiltinsInt', [lite.OpsSet.TFLITE_BUILTINS_INT8], 937 [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]), 938 ('UseTfliteBuiltinsInt16', [ 939 lite.OpsSet 940 .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 941 ], [metadata_fb.ModelOptimizationMode.PTQ_INT16])) 942 def testQuantizeInt8And16x8(self, supported_ops, expected_opt_modes): 943 with ops.Graph().as_default(): 944 inp, output, calibration_gen = self._getIntegerQuantizeModel() 945 sess = session.Session() 946 947 # Convert float model. 948 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 949 float_tflite_model = float_converter.convert() 950 self.assertIsNotNone(float_tflite_model) 951 952 # Convert model by specifying target spec (instead of optimizations), since 953 # when targeting an integer only backend, quantization is mandatory. 
954 quantized_converter = lite.TFLiteConverter.from_session( 955 sess, [inp], [output]) 956 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 957 quantized_converter.target_spec.supported_ops = supported_ops 958 quantized_converter.representative_dataset = calibration_gen 959 quantized_tflite_model = quantized_converter.convert() 960 self.assertIsNotNone(quantized_tflite_model) 961 # Check the conversion metadata. 962 metadata = get_conversion_metadata(quantized_tflite_model) 963 self.assertIsNotNone(metadata) 964 self.assertEqual( 965 metadata.environment.tensorflowVersion.decode('utf-8'), 966 versions.__version__) 967 self.assertEqual(metadata.environment.apiVersion, 1) 968 self.assertEqual(metadata.environment.modelType, 969 metadata_fb.ModelType.TF_SESSION) 970 self.assertEqual(metadata.options.allowCustomOps, False) 971 self.assertEqual(metadata.options.enableSelectTfOps, False) 972 self.assertEqual(metadata.options.forceSelectTfOps, False) 973 self.assertAllEqual(expected_opt_modes, 974 metadata.options.modelOptimizationModes) 975 976 # The default input and output types should be float. 977 interpreter = Interpreter(model_content=quantized_tflite_model) 978 interpreter.allocate_tensors() 979 input_details = interpreter.get_input_details() 980 self.assertLen(input_details, 1) 981 self.assertEqual(np.float32, input_details[0]['dtype']) 982 output_details = interpreter.get_output_details() 983 self.assertLen(output_details, 1) 984 self.assertEqual(np.float32, output_details[0]['dtype']) 985 986 # Ensure that the quantized weights tflite model is smaller. 987 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 988 989 def testQuantizeInt8InputOutput(self): 990 with ops.Graph().as_default(): 991 inp, output, calibration_gen = self._getIntegerQuantizeModel() 992 sess = session.Session() 993 994 # Convert float model. 
995 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 996 float_tflite_model = float_converter.convert() 997 self.assertIsNotNone(float_tflite_model) 998 999 # Convert quantized weights model. 1000 quantized_converter = lite.TFLiteConverter.from_session( 1001 sess, [inp], [output]) 1002 quantized_converter.inference_input_type = dtypes.int8 1003 quantized_converter.inference_output_type = dtypes.int8 1004 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 1005 quantized_converter.representative_dataset = calibration_gen 1006 quantized_tflite_model = quantized_converter.convert() 1007 self.assertIsNotNone(quantized_tflite_model) 1008 1009 # The input and output types should be int8. 1010 interpreter = Interpreter(model_content=quantized_tflite_model) 1011 interpreter.allocate_tensors() 1012 input_details = interpreter.get_input_details() 1013 self.assertLen(input_details, 1) 1014 self.assertEqual(np.int8, input_details[0]['dtype']) 1015 output_details = interpreter.get_output_details() 1016 self.assertLen(output_details, 1) 1017 self.assertEqual(np.int8, output_details[0]['dtype']) 1018 1019 # Ensure that the quantized weights tflite model is smaller. 1020 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 1021 1022 def testInvalidQuantizeInt8(self): 1023 np.random.seed(0) 1024 with ops.Graph().as_default(): 1025 # We need the tensor to have more than 1024 elements for quantize_weights 1026 # to kick in. Thus, the [33, 33] shape. 1027 in_tensor_1 = array_ops.placeholder( 1028 shape=[33, 33], dtype=dtypes.float32, name='inputA') 1029 in_tensor_2 = constant_op.constant( 1030 np.random.uniform(low=-10., high=10., size=(33, 33)), 1031 shape=[33, 33], 1032 dtype=dtypes.float32, 1033 name='inputB') 1034 out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') 1035 sess = session.Session() 1036 1037 # Attempt to convert to quantized weights model. 
1038 quantized_converter = lite.TFLiteConverter.from_session( 1039 sess, [in_tensor_1], [out_tensor]) 1040 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 1041 # Restricting to int8 type only 1042 quantized_converter.target_spec.supported_types = [dtypes.int8] 1043 # A representative dataset is required for full fixed point quantization. 1044 with self.assertRaises(ValueError) as error: 1045 quantized_converter.convert() 1046 self.assertEqual( 1047 'For full integer quantization, a `representative_dataset` ' 1048 'must be specified.', str(error.exception)) 1049 1050 def testQuantizeUInt8(self): 1051 with ops.Graph().as_default(): 1052 in_tensor_1 = array_ops.placeholder( 1053 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 1054 in_tensor_2 = array_ops.placeholder( 1055 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 1056 out_tensor = array_ops.fake_quant_with_min_max_args( 1057 in_tensor_1 + in_tensor_2, min=0., max=1., name='output') 1058 sess = session.Session() 1059 1060 # Convert model and ensure model is not None. 1061 converter = lite.TFLiteConverter.from_session(sess, 1062 [in_tensor_1, in_tensor_2], 1063 [out_tensor]) 1064 converter.inference_type = dtypes.uint8 1065 converter.quantized_input_stats = { 1066 'inputA': (0., 1.), 1067 'inputB': (0., 1.) 1068 } # mean, std_dev 1069 tflite_model = converter.convert() 1070 self.assertIsNotNone(tflite_model) 1071 1072 # Check values from converted model. 
1073 interpreter = Interpreter(model_content=tflite_model) 1074 interpreter.allocate_tensors() 1075 1076 input_details = interpreter.get_input_details() 1077 self.assertLen(input_details, 2) 1078 self.assertEqual('inputA', input_details[0]['name']) 1079 self.assertEqual(np.uint8, input_details[0]['dtype']) 1080 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1081 self.assertEqual((1., 0.), input_details[0]['quantization']) 1082 1083 self.assertEqual('inputB', input_details[1]['name']) 1084 self.assertEqual(np.uint8, input_details[1]['dtype']) 1085 self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape']) 1086 self.assertEqual((1., 0.), input_details[1]['quantization']) 1087 1088 output_details = interpreter.get_output_details() 1089 self.assertLen(output_details, 1) 1090 self.assertEqual(np.uint8, output_details[0]['dtype']) 1091 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1092 self.assertGreater(output_details[0]['quantization'][0], 0) # scale 1093 1094 def testQuantizeUInt8UsingDefaultRangeStats(self): 1095 with ops.Graph().as_default(): 1096 in_tensor = array_ops.placeholder( 1097 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1098 out_tensor = in_tensor + in_tensor 1099 sess = session.Session() 1100 1101 # Convert model and ensure model is not None. 1102 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1103 [out_tensor]) 1104 converter.inference_type = dtypes.uint8 1105 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev 1106 converter.default_ranges_stats = (0, 6) # min, max 1107 tflite_model = converter.convert() 1108 self.assertIsNotNone(tflite_model) 1109 1110 # Check values from converted model. 
1111 interpreter = Interpreter(model_content=tflite_model) 1112 interpreter.allocate_tensors() 1113 1114 input_details = interpreter.get_input_details() 1115 self.assertLen(input_details, 1) 1116 self.assertEqual('Placeholder', input_details[0]['name']) 1117 self.assertEqual(np.uint8, input_details[0]['dtype']) 1118 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1119 self.assertEqual((1., 0.), input_details[0]['quantization']) 1120 1121 output_details = interpreter.get_output_details() 1122 self.assertLen(output_details, 1) 1123 self.assertEqual('add', output_details[0]['name']) 1124 self.assertEqual(np.uint8, output_details[0]['dtype']) 1125 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1126 self.assertGreater(output_details[0]['quantization'][0], 0) # scale 1127 1128 @parameterized.named_parameters( 1129 # Quantize to Float16 even if rep data provided. 1130 ('UseRepresentativeData', True, False, True, False, False, False, 1131 [metadata_fb.ModelOptimizationMode.PTQ_FLOAT16]), 1132 # Quantize to Float16 if no rep data provided. 1133 ('NoRepresentativeData', False, False, True, False, False, False, 1134 [metadata_fb.ModelOptimizationMode.PTQ_FLOAT16]), 1135 # Post training quantization if both rep data and int8 included. 1136 ('SampleDataIncludeInt8', True, True, False, False, True, False, 1137 [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]), 1138 # Same as above, but using MLIR quantizer 1139 ('SampleDataIncludeInt8Quant', True, True, False, False, True, True, 1140 [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER])) 1141 def testQuantizeFloat16(self, use_rep_data, include_int8, 1142 is_float16_quantized, is_float16_accumulation, 1143 is_post_training_quantized, enable_mlir_quantizer, 1144 expected_opt_modes): 1145 with ops.Graph().as_default(): 1146 inp, output, calibration_gen = self._getIntegerQuantizeModel() 1147 sess = session.Session() 1148 1149 bias_idx = 1 1150 bias_name = 'Conv2D' 1151 1152 # Convert float model. 
1153 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 1154 float_tflite_model = float_converter.convert() 1155 self.assertIsNotNone(float_tflite_model) 1156 interpreter = Interpreter(model_content=float_tflite_model) 1157 interpreter.allocate_tensors() 1158 self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'], 1159 bias_name) 1160 self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'], 1161 dtypes.float32) 1162 1163 # Convert model to quantized version 1164 quantized_converter = lite.TFLiteConverter.from_session( 1165 sess, [inp], [output]) 1166 quantized_converter.experimental_new_quantizer = enable_mlir_quantizer 1167 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 1168 quantized_converter.target_spec.supported_types = [dtypes.float16] 1169 if include_int8: 1170 quantized_converter.target_spec.supported_types.append(dtypes.int8) 1171 if use_rep_data: 1172 quantized_converter.representative_dataset = calibration_gen 1173 if is_float16_accumulation: 1174 quantized_converter.target_spec.experimental_supported_accumulation_type = dtypes.float16 # pylint: disable=line-too-long 1175 1176 else: 1177 quantized_tflite_model = quantized_converter.convert() 1178 self.assertIsNotNone(quantized_tflite_model) 1179 metadata = get_conversion_metadata(quantized_tflite_model) 1180 self.assertIsNotNone(metadata) 1181 self.assertAllEqual(expected_opt_modes, 1182 metadata.options.modelOptimizationModes) 1183 interpreter = Interpreter(model_content=quantized_tflite_model) 1184 interpreter.allocate_tensors() 1185 1186 # MLIR quantizer has different bias index. 1187 bias_tensor = [ 1188 tensor for tensor in interpreter.get_tensor_details() 1189 if tensor['name'] == bias_name 1190 ] 1191 self.assertLen(bias_tensor, 1) 1192 1193 if is_float16_quantized: 1194 # Verify that bias constant is float16 type. 
1195 self.assertEqual(bias_tensor[0]['dtype'], dtypes.float16) 1196 elif is_post_training_quantized: 1197 # Verify that bias constants is int32 type. 1198 self.assertEqual(bias_tensor[0]['dtype'], dtypes.int32) 1199 else: 1200 raise ValueError('Invalid test options.') 1201 1202 def testInvalidQuantizeFloat16(self): 1203 with ops.Graph().as_default(): 1204 inp, output, _ = self._getIntegerQuantizeModel() 1205 sess = session.Session() 1206 1207 # Specify float16 quantization 1208 quantized_converter = lite.TFLiteConverter.from_session( 1209 sess, [inp], [output]) 1210 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 1211 quantized_converter.target_spec.supported_types = [dtypes.float16] 1212 # Specify only int8 builtin ops 1213 quantized_converter.target_spec.supported_ops = [ 1214 lite.OpsSet.TFLITE_BUILTINS_INT8 1215 ] 1216 with self.assertRaises(ValueError) as error: 1217 quantized_converter.convert() 1218 self.assertEqual( 1219 'As full integer quantization has been enabled by setting ' 1220 '`target_spec.supported_ops`={tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, ' 1221 'thus `target_spec.supported_types` should be left uninitizalized ' 1222 'or set to {tf.int8}.', str(error.exception)) 1223 1224 @parameterized.named_parameters(('InferenceType_INT8', dtypes.int8), 1225 ('InferenceType_UINT8', dtypes.uint8)) 1226 def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type): 1227 with ops.Graph().as_default(): 1228 in_tensor = array_ops.placeholder( 1229 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1230 out_tensor = array_ops.fake_quant_with_min_max_args( 1231 in_tensor + in_tensor, min=0., max=1.) 
1232 sess = session.Session() 1233 1234 quantized_converter = lite.TFLiteConverter.from_session( 1235 sess, [in_tensor], [out_tensor]) 1236 1237 with self.assertRaises(ValueError) as error: 1238 quantized_converter.inference_type = quantized_type 1239 quantized_converter.convert() 1240 self.assertEqual( 1241 'The `quantized_input_stats` flag must be defined when either ' 1242 '`inference_type` flag or `inference_input_type` flag is set to ' 1243 'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and ' 1244 '`inference_input_type=None`.'.format(quantized_type.name), 1245 str(error.exception)) 1246 1247 with self.assertRaises(ValueError) as error: 1248 quantized_converter.inference_type = dtypes.float32 1249 quantized_converter.inference_input_type = quantized_type 1250 quantized_converter.convert() 1251 self.assertEqual( 1252 'The `quantized_input_stats` flag must be defined when either ' 1253 '`inference_type` flag or `inference_input_type` flag is set to ' 1254 'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and ' 1255 '`inference_input_type=tf.{}`.'.format(quantized_type.name), 1256 str(error.exception)) 1257 1258 quantized_converter.inference_type = quantized_type 1259 quantized_converter.inference_input_type = quantized_type 1260 1261 input_arrays = quantized_converter.get_input_arrays() 1262 quantized_converter.quantized_input_stats = {input_arrays[0]: (0., 1.)} 1263 quantized_converter.convert() 1264 1265 def testInvalidQuantizeQATModelMissingInputStats(self): 1266 with ops.Graph().as_default(): 1267 in_tensor_1 = array_ops.placeholder( 1268 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 1269 in_tensor_2 = array_ops.placeholder( 1270 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 1271 out_tensor = array_ops.fake_quant_with_min_max_args( 1272 in_tensor_1 + in_tensor_2, min=0., max=1., name='output') 1273 sess = session.Session() 1274 1275 # Convert model and ensure model is not None. 
1276 converter = lite.TFLiteConverter.from_session(sess, 1277 [in_tensor_1, in_tensor_2], 1278 [out_tensor]) 1279 converter.inference_type = dtypes.uint8 1280 converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev 1281 with self.assertRaises(ValueError) as error: 1282 converter.convert() 1283 self.assertEqual( 1284 'Quantization input stats are not available for input tensors ' 1285 '\'inputB\'.', str(error.exception)) 1286 1287 def testTrainingTimeAndPostTrainingCalibrateAndQuantize(self): 1288 with ops.Graph().as_default(): 1289 inp, output, calibration_gen = self._getIntegerQuantizeModel() 1290 sess = session.Session() 1291 1292 # Convert float model. 1293 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 1294 float_tflite_model = float_converter.convert() 1295 self.assertIsNotNone(float_tflite_model) 1296 1297 converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 1298 1299 # extra flags to trigger training time quantization conversion 1300 converter.inference_type = dtypes.int8 1301 converter.inference_input_type = dtypes.float32 1302 converter.inference_output_type = dtypes.float32 1303 input_arrays = converter.get_input_arrays() 1304 converter.quantized_input_stats = {input_arrays[0]: (0., 1.)} 1305 # trigger post-training quantization 1306 converter.optimizations = [lite.Optimize.DEFAULT] 1307 converter.representative_dataset = calibration_gen 1308 converter.experimental_new_quantizer = True 1309 quantized_tflite_model = converter.convert() 1310 self.assertIsNotNone(quantized_tflite_model) 1311 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 1312 1313 # calibration only api 1314 converter._experimental_calibrate_only = True 1315 calibrated_tflite = converter.convert() 1316 quantized_tflite_model = mlir_quantize( 1317 calibrated_tflite, fully_quantize=True) 1318 interpreter = Interpreter(model_content=quantized_tflite_model) 1319 interpreter.allocate_tensors() 1320 input_details = 
interpreter.get_input_details() 1321 self.assertEqual(np.int8, input_details[0]['dtype']) 1322 self.assertEqual((1., 0.), input_details[0]['quantization']) 1323 1324 output_details = interpreter.get_output_details() 1325 self.assertEqual(np.int8, output_details[0]['dtype']) 1326 1327 def testFloatTocoConverter(self): 1328 """Tests deprecated test TocoConverter.""" 1329 with ops.Graph().as_default(): 1330 in_tensor = array_ops.placeholder( 1331 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1332 out_tensor = in_tensor + in_tensor 1333 sess = session.Session() 1334 1335 # Convert model and ensure model is not None. 1336 converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) 1337 tflite_model = converter.convert() 1338 self.assertIsNotNone(tflite_model) 1339 1340 # Ensure the interpreter is able to load. 1341 interpreter = Interpreter(model_content=tflite_model) 1342 interpreter.allocate_tensors() 1343 1344 def testMultipleOutputNodeNames(self): 1345 """Tests converting a graph with an op that have multiple outputs.""" 1346 with ops.Graph().as_default(): 1347 input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) 1348 out0, out1, out2, out3 = array_ops.split( 1349 input_tensor, [1, 1, 1, 1], axis=0) 1350 sess = session.Session() 1351 1352 # Convert model and ensure model is not None. 1353 converter = lite.TFLiteConverter.from_session(sess, [input_tensor], 1354 [out0, out1, out2, out3]) 1355 tflite_model = converter.convert() 1356 self.assertIsNotNone(tflite_model) 1357 1358 # Check values from converted model. 
1359 interpreter = Interpreter(model_content=tflite_model) 1360 interpreter.allocate_tensors() 1361 1362 input_details = interpreter.get_input_details() 1363 self.assertLen(input_details, 1) 1364 interpreter.set_tensor(input_details[0]['index'], 1365 np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) 1366 interpreter.invoke() 1367 1368 output_details = interpreter.get_output_details() 1369 self.assertLen(output_details, 4) 1370 self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index'])) 1371 self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index'])) 1372 self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index'])) 1373 self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index'])) 1374 1375 @test_util.run_in_graph_and_eager_modes 1376 def testFunctions(self): 1377 """Tests tf.function in 1.X.""" 1378 1379 @def_function.function 1380 def plus_placeholder(x, placeholder): 1381 return x + placeholder 1382 1383 with ops.Graph().as_default(): 1384 placeholder = array_ops.placeholder( 1385 dtype=dtypes.float32, shape=[1], name='input') 1386 variable_node = variables.Variable(1.0, name='variable_node') 1387 defun_node = plus_placeholder(variable_node, placeholder) 1388 output_node = math_ops.multiply(defun_node, 2.0, name='output_node') 1389 1390 # Initialize variables in the model. 1391 sess = session.Session() 1392 sess.run(variables.variables_initializer([variable_node])) 1393 1394 # Convert model and ensure model is not None. 1395 converter = lite.TFLiteConverter.from_session(sess, [placeholder], 1396 [output_node]) 1397 tflite_model = converter.convert() 1398 self.assertIsNotNone(tflite_model) 1399 1400 # Check values from converted model. 
1401 interpreter = Interpreter(model_content=tflite_model) 1402 interpreter.allocate_tensors() 1403 1404 input_details = interpreter.get_input_details() 1405 self.assertLen(input_details, 1) 1406 self.assertEqual('input', input_details[0]['name']) 1407 self.assertEqual(np.float32, input_details[0]['dtype']) 1408 self.assertAllEqual([1], input_details[0]['shape']) 1409 self.assertEqual((0., 0.), input_details[0]['quantization']) 1410 1411 output_details = interpreter.get_output_details() 1412 self.assertLen(output_details, 1) 1413 self.assertEqual('output_node', output_details[0]['name']) 1414 self.assertEqual(np.float32, output_details[0]['dtype']) 1415 self.assertAllEqual([1], output_details[0]['shape']) 1416 self.assertEqual((0., 0.), output_details[0]['quantization']) 1417 1418 def testInferenceInputOutputTypeFloatDefault(self): 1419 with ops.Graph().as_default(): 1420 in_tensor = array_ops.placeholder( 1421 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1422 out_tensor = in_tensor + in_tensor 1423 sess = session.Session() 1424 1425 # Convert model and ensure model is not None. 1426 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1427 [out_tensor]) 1428 tflite_model = converter.convert() 1429 self.assertIsNotNone(tflite_model) 1430 1431 # Check values from converted model. 
1432 interpreter = Interpreter(model_content=tflite_model) 1433 interpreter.allocate_tensors() 1434 1435 input_details = interpreter.get_input_details() 1436 self.assertLen(input_details, 1) 1437 self.assertEqual('Placeholder', input_details[0]['name']) 1438 self.assertEqual(np.float32, input_details[0]['dtype']) 1439 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1440 1441 output_details = interpreter.get_output_details() 1442 self.assertLen(output_details, 1) 1443 self.assertEqual('add', output_details[0]['name']) 1444 self.assertEqual(np.float32, output_details[0]['dtype']) 1445 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1446 1447 def testInferenceInputOutputTypeQuantizedUint8Default(self): 1448 with ops.Graph().as_default(): 1449 in_tensor = array_ops.placeholder( 1450 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1451 out_tensor = array_ops.fake_quant_with_min_max_args( 1452 in_tensor + in_tensor, min=0., max=1., name='output') 1453 sess = session.Session() 1454 1455 # Convert model and ensure model is not None. 1456 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1457 [out_tensor]) 1458 converter.inference_type = dtypes.uint8 1459 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev 1460 tflite_model = converter.convert() 1461 self.assertIsNotNone(tflite_model) 1462 1463 # Check values from converted model. 
1464 interpreter = Interpreter(model_content=tflite_model) 1465 interpreter.allocate_tensors() 1466 1467 input_details = interpreter.get_input_details() 1468 self.assertLen(input_details, 1) 1469 self.assertEqual('Placeholder', input_details[0]['name']) 1470 self.assertEqual(np.uint8, input_details[0]['dtype']) 1471 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1472 1473 output_details = interpreter.get_output_details() 1474 self.assertLen(output_details, 1) 1475 self.assertEqual('output', output_details[0]['name']) 1476 self.assertEqual(np.uint8, output_details[0]['dtype']) 1477 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1478 1479 def testReusingConverterWithDifferentPostTrainingQuantization(self): 1480 with ops.Graph().as_default(): 1481 in_tensor = array_ops.placeholder( 1482 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1483 out_tensor = array_ops.fake_quant_with_min_max_args( 1484 in_tensor + in_tensor, min=0., max=1., name='output') 1485 sess = session.Session() 1486 1487 # Convert model and ensure model is not None. 1488 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1489 [out_tensor]) 1490 1491 converter.post_training_quantize = True 1492 tflite_model = converter.convert() 1493 self.assertIsNotNone(tflite_model) 1494 1495 converter.post_training_quantize = False 1496 tflite_model = converter.convert() 1497 self.assertIsNotNone(tflite_model) 1498 1499 def testResizeWithShape(self): 1500 with ops.Graph().as_default(): 1501 # Construct a graph with a dynamically shapped input and an internal node 1502 # that relies on the output of that input's shape. 
1503 in_tensor = array_ops.placeholder( 1504 shape=[None, None], dtype=dtypes.float32) 1505 in_tensor2 = [[1, 2], [3, 4]] 1506 out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor)) 1507 sess = session.Session() 1508 1509 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1510 [out_tensor]) 1511 tflite_model = converter.convert() 1512 1513 # Check values from converted model. 1514 interpreter = Interpreter(model_content=tflite_model) 1515 input_details = interpreter.get_input_details() 1516 self.assertLen(input_details, 1) 1517 self.assertAllEqual([1, 1], input_details[0]['shape']) 1518 self.assertAllEqual([-1, -1], input_details[0]['shape_signature']) 1519 1520 # Resize tensor and invoke. 1521 interpreter.resize_tensor_input(0, [4]) 1522 interpreter.allocate_tensors() 1523 interpreter.invoke() 1524 1525 # The output should be reshaped properly according to the resized input. 1526 output_details = interpreter.get_output_details() 1527 self.assertLen(output_details, 1) 1528 self.assertEqual(np.int32, output_details[0]['dtype']) 1529 self.assertAllEqual([4], output_details[0]['shape']) 1530 output_data = interpreter.get_tensor(output_details[0]['index']) 1531 self.assertAllEqual([1, 2, 3, 4], output_data) 1532 1533 def testResizingIntermediateDynamicTensor(self): 1534 # This is a regression test for the case where shape of dynamic output 1535 # tensors changes between invocations. 1536 # See also https://github.com/tensorflow/tensorflow/issues/26549 1537 with ops.Graph().as_default(): 1538 input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32) 1539 input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32) 1540 1541 # The bug is triggered only when dynamic tensor is intermediate. Putting 1542 # some other ops around it. 
1543 neg = math_ops.negative(input2_tensor) 1544 padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32) 1545 output_tensor = array_ops.pad(input_tensor, padding) + neg 1546 1547 sess = session.Session() 1548 1549 converter = lite.TFLiteConverter.from_session( 1550 sess, [input_tensor, padding, input2_tensor], [output_tensor]) 1551 tflite_model = converter.convert() 1552 1553 interpreter = Interpreter(model_content=tflite_model) 1554 interpreter.allocate_tensors() 1555 1556 input_details = interpreter.get_input_details() 1557 interpreter.set_tensor(input_details[1]['index'], 1558 np.array([[1, 1], [1, 1]], dtype=np.int32)) 1559 interpreter.invoke() 1560 1561 # Without the fix, invocation will fail when changing the shape of 1562 # intermediate dynamic tensors. 1563 interpreter.set_tensor(input_details[1]['index'], 1564 np.array([[2, 2], [2, 2]], dtype=np.int32)) 1565 interpreter.invoke() 1566 1567 def testGraphDebugInfo(self): 1568 """Test a session has debug info captured.""" 1569 1570 @def_function.function 1571 def plus_placeholder(x, placeholder): 1572 return x + placeholder 1573 1574 with ops.Graph().as_default(): 1575 placeholder = array_ops.placeholder( 1576 dtype=dtypes.float32, shape=[1], name='input') 1577 variable_node = variables.Variable(1.0, name='variable_node') 1578 defun_node = plus_placeholder(variable_node, placeholder) 1579 output_node = math_ops.multiply(defun_node, 2.0, name='output_node') 1580 1581 # Initialize variables in the model. 1582 sess = session.Session() 1583 sess.run(variables.variables_initializer([variable_node])) 1584 1585 converter = lite.TFLiteConverter.from_session(sess, [placeholder], 1586 [output_node]) 1587 converter.convert() 1588 self.assertValidDebugInfo(converter._debug_info) 1589 1590 # Check the add node in the inlined function is included. 
    func = sess.graph.as_graph_def().library.function[0].signature.name
    self.assertIn(('add@' + func), converter._debug_info.traces)

  def testOutputOnlyModel(self):
    """Tests a model that has no input tensors, only outputs."""
    with ops.Graph().as_default():
      out_tensor = random_ops.random_normal(shape=[3])
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [], [out_tensor])
    converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS,
        lite.OpsSet.SELECT_TF_OPS,
    ]

    # Empty input array is a valid input.
    self.assertTrue(converter._has_valid_tensors())

    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)


class FromFrozenGraphFile(LiteTest):
  """Tests for TFLiteConverter.from_frozen_graph with GraphDef files."""

  def testFloat(self):
    """Tests a float add model written to a binary GraphDef file."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
                                                       ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatWithShapesArray(self):
    """Test a shape overriding case."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['Placeholder'], ['add'],
        input_shapes={'Placeholder': [2, 16, 16, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model. The unknown batch dim was overridden
    # to 2 via input_shapes.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])

  def testInvalidShapesArray(self):
    """Test an invalid shape overriding case, which has a wrong input name."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_frozen_graph(
          graph_def_file, ['Placeholder'], ['add'],
          input_shapes={'wrong_input': [2, 16, 16, 3]})

  def testPartialShapesArray(self):
    """Test a shape overriding case, with the only one input among two."""
    with ops.Graph().as_default():
      a = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='a')
      b = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='b')
      _ = math_ops.add(a, b, name='add')
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None. Only 'a' gets an override.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['a', 'b'], ['add'], input_shapes={'a': [2, 16, 16, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testInvalidFileNotFound(self):
    """Tests that a missing GraphDef file raises IOError."""
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
                                             ['add'])
    self.assertEqual('File \'invalid_file\' does not exist.',
                     str(error.exception))

  def testInvalidFileBadData(self):
    """Tests that an unparseable GraphDef file raises IOError."""
    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
    with gfile.Open(graph_def_file, 'wb') as temp_file:
      temp_file.write('bad data')
      temp_file.flush()

    # Attempts to convert the invalid model.
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                             ['add'])
    self.assertEqual(
        'Unable to parse input file \'{}\'.'.format(graph_def_file),
        str(error.exception))

  def testFloatTocoConverter(self):
    """Tests deprecated TocoConverter.from_frozen_graph still converts."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                     ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  def testGraphDebugInfo(self):
    """Test a frozen graph doesn't have debug info captured."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                     ['Placeholder'], ['add'])
    converter.convert()
    # GraphDebugInfo should be none for frozen graph.
    self.assertFalse(converter._debug_info)

  def testExcludeConversionMetadata(self):
    """Tests that exclude_conversion_metadata strips metadata from output."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
                                                       ['Placeholder'], ['add'])
    converter.exclude_conversion_metadata = True
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNone(metadata)


class FromFrozenGraphObjectDetection(LiteTest):
  """Tests conversion of an SSD object detection frozen graph."""

  def _initObjectDetectionArgs(self):
    # Initializes the arguments required for the object detection model.
    # Looks for the model file which is saved in a different location internally
    # and externally.
    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
    if not os.path.exists(filename):
      filename = os.path.join(
          resource_loader.get_root_dir_with_all_resources(),
          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
      if not os.path.exists(filename):
        raise IOError("File '{0}' does not exist.".format(filename))

    self._graph_def_file = filename
    self._input_arrays = ['normalized_input_image_tensor']
    self._output_arrays = [
        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
    ]
    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}

  def testTFLiteGraphDef(self):
    # Tests the object detection model that cannot be loaded in TensorFlow.
    self._initObjectDetectionArgs()

    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
                                                       self._input_arrays,
                                                       self._output_arrays,
                                                       self._input_shapes)
    # The postprocess op is a custom op, so custom ops must be allowed.
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 300, 300, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 4)
    self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 10, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual('TFLite_Detection_PostProcess:1',
                     output_details[1]['name'])
    self.assertAllEqual([1, 10], output_details[1]['shape'])
    self.assertEqual('TFLite_Detection_PostProcess:2',
                     output_details[2]['name'])
    self.assertAllEqual([1, 10], output_details[2]['shape'])
    self.assertEqual('TFLite_Detection_PostProcess:3',
                     output_details[3]['name'])
    self.assertAllEqual([1], output_details[3]['shape'])

  def testTFLiteGraphDefWithControlOutput(self):
    """Tests converting a graph whose only output is a control output."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[5, 5], dtype=dtypes.float32, name='input')
      out_tensor = in_tensor + in_tensor
      logging_ops.print_v2(out_tensor)
      sess = session.Session()

    converter = lite.TFLiteConverter(
        sess.graph_def,
        input_tensors=None,
        output_tensors=None,
        input_arrays_with_shape=[('input', [5, 5])],
        output_arrays=None,
        experimental_debug_info_func=None)
    converter._control_output_arrays = ['PrintV2']
    converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS,
        lite.OpsSet.SELECT_TF_OPS,
    ]

    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # PrintV2 (and its StringFormat input) are lowered as Flex custom ops.
    model = util._convert_model_from_bytearray_to_object(tflite_model)
    self.assertEqual(model.operatorCodes[0].builtinCode,
                     schema_fb.BuiltinOperator.ADD)
    self.assertEqual(model.operatorCodes[1].builtinCode,
                     schema_fb.BuiltinOperator.CUSTOM)
    self.assertEqual(model.operatorCodes[1].customCode, b'FlexStringFormat')
    self.assertEqual(model.operatorCodes[2].builtinCode,
                     schema_fb.BuiltinOperator.CUSTOM)
    self.assertEqual(model.operatorCodes[2].customCode, b'FlexPrintV2')

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([5, 5], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    # A control output produces no data outputs.
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 0)

  def testModifyIOToUint8(self):
    # Tests the object detection model that cannot be loaded in TensorFlow.
    self._initObjectDetectionArgs()

    def representative_dataset_gen():
      # Two calibration samples are enough for this test's purposes.
      for _ in range(2):
        yield [
            np.random.uniform(low=0, high=1,
                              size=(1, 300, 300, 3)).astype(np.float32)
        ]

    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
                                                       self._input_arrays,
                                                       self._output_arrays,
                                                       self._input_shapes)
    converter.representative_dataset = representative_dataset_gen
    converter.target_spec.supported_ops = {lite.OpsSet.TFLITE_BUILTINS_INT8}
    converter.inference_type = dtypes.int8
    converter.inference_input_type = dtypes.uint8
    converter.inference_output_type = dtypes.uint8
    converter.experimental_new_quantizer = True
    converter.quantized_input_stats = {
        'normalized_input_image_tensor': (0., 1.)
    }  # mean, std_dev
    converter.allow_custom_ops = True
    tflite_model = converter.convert()

    self.assertIsNotNone(tflite_model)

    model = util._convert_model_from_bytearray_to_object(tflite_model)
    quant_opcode_idxs = util.get_quantize_opcode_idx(model)

    subgraph = model.subgraphs[0]
    tensors = subgraph.tensors
    operators = subgraph.operators
    # Quantize ops feeding a graph output must still take float input here.
    for op in operators:
      if op.opcodeIndex in quant_opcode_idxs:
        input_type = util._convert_tflite_enum_type_to_tf_type(
            tensors[op.inputs[0]].type)
        if op.outputs[0] in subgraph.outputs:
          self.assertEqual(input_type, dtypes.float32)


class FromSavedModelTest(TestModels):
  """Tests for TFLiteConverter.from_saved_model."""

  def _createSavedModel(self, shape):
    """Create a simple SavedModel."""
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor_1 = array_ops.placeholder(
            shape=shape, dtype=dtypes.float32, name='inputB')
        in_tensor_2 = array_ops.placeholder(
            shape=shape, dtype=dtypes.float32, name='inputA')
        out_tensor = in_tensor_1 + in_tensor_2
        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def testSimpleModel(self):
    """Test a SavedModel."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    # NOTE(review): inputA appears first even though inputB was created first
    # — presumably name ordering; confirm against converter behavior.
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testNoneBatchSize(self):
    """Test a SavedModel, with None in input tensor's shape."""
    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model. The None batch dim defaults to 1.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testOrderInputArrays(self):
    """Test a SavedModel ordering of input arrays."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputB', 'inputA'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testShapeOverriding(self):
    """Test a SavedModel with the input_shapes argument."""
    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_shapes={
            'inputA': [2, 16, 16, 3],
            'inputB': [2, 16, 16, 3]
        })
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testWrongInputShapes(self):
    """Test a SavedModel with a wrong name in the input_shapes argument."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Check case where input shape is given.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_saved_model(
          saved_model_dir,
          input_arrays=['inputA'],
          input_shapes={'wrong_input': [1, 16, 16, 3]})

  # TODO: method name has a typo ('Shaapes'); kept as-is since renaming
  # changes the test id.
  def testSubsetInputShaapes(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Check case where input shape is given.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_arrays=['inputA'],
        input_shapes={'inputA': [1, 16, 16, 3]})

    # Since we only partially specify the input, this is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

    # Check case where input shape is None.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})

    # Since we only partially specify the input, this is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

  def testSimpleModelTocoConverter(self):
    """Test a SavedModel with deprecated TocoConverter."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
2213 interpreter = Interpreter(model_content=tflite_model) 2214 interpreter.allocate_tensors() 2215 2216 def testGraphDebugInfo(self): 2217 """Test a SavedModel has debug info captured.""" 2218 self.skipTest( 2219 'b/221093690: The debug info is not from self._createSavedModel(), ' 2220 'but from saved_model.loader_impl().') 2221 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 2222 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 2223 converter.convert() 2224 self.assertValidDebugInfo(converter._debug_info) 2225 2226 2227class MyAddLayer(keras.layers.Layer): 2228 2229 def __init__(self, increment, **kwargs): 2230 super(MyAddLayer, self).__init__(**kwargs) 2231 self._increment = increment 2232 2233 def call(self, inputs): 2234 return inputs + self._increment 2235 2236 def get_config(self): 2237 config = super(MyAddLayer, self).get_config() 2238 config['increment'] = self._increment 2239 return config 2240 2241 2242class FromKerasFile(TestModels, parameterized.TestCase): 2243 2244 def setUp(self): 2245 super(FromKerasFile, self).setUp() 2246 self._keras_file = None 2247 self._custom_objects = None 2248 if not context.executing_eagerly(): 2249 keras.backend.clear_session() 2250 2251 def tearDown(self): 2252 if self._keras_file: 2253 os.remove(self._keras_file) 2254 super(FromKerasFile, self).tearDown() 2255 2256 def _getSequentialModel(self, include_custom_layer=False): 2257 model = keras.models.Sequential() 2258 model.add(keras.layers.Dense(2, input_shape=(3,))) 2259 if include_custom_layer: 2260 model.add(MyAddLayer(1.0)) 2261 model.add(keras.layers.RepeatVector(3)) 2262 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 2263 model.compile( 2264 loss=keras.losses.MSE, 2265 optimizer='sgd', 2266 metrics=[keras.metrics.categorical_accuracy], 2267 sample_weight_mode='temporal') 2268 x = np.random.random((1, 3)) 2269 y = np.random.random((1, 3, 3)) 2270 model.train_on_batch(x, y) 2271 model.predict(x) 2272 2273 try: 2274 fd, 
self._keras_file = tempfile.mkstemp('.h5') 2275 keras.models.save_model(model, self._keras_file) 2276 finally: 2277 os.close(fd) 2278 2279 if include_custom_layer: 2280 self._custom_objects = {'MyAddLayer': MyAddLayer} 2281 2282 @parameterized.named_parameters(('_graph', context.graph_mode), 2283 ('_eager', context.eager_mode)) 2284 def testSequentialModel(self, test_context): 2285 """Test a Sequential tf.keras model with default inputs.""" 2286 with test_context(): 2287 self._getSequentialModel() 2288 2289 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2290 tflite_model = converter.convert() 2291 self.assertIsNotNone(tflite_model) 2292 2293 # Check tensor details of converted model. 2294 interpreter = Interpreter(model_content=tflite_model) 2295 interpreter.allocate_tensors() 2296 2297 input_details = interpreter.get_input_details() 2298 self.assertLen(input_details, 1) 2299 self.assertEndsWith(input_details[0]['name'], 'dense_input') 2300 self.assertEqual(np.float32, input_details[0]['dtype']) 2301 self.assertAllEqual([1, 3], input_details[0]['shape']) 2302 self.assertEqual((0., 0.), input_details[0]['quantization']) 2303 2304 output_details = interpreter.get_output_details() 2305 self.assertLen(output_details, 1) 2306 self.assertEqual(np.float32, output_details[0]['dtype']) 2307 self.assertAllEqual([1, 3, 3], output_details[0]['shape']) 2308 self.assertEqual((0., 0.), output_details[0]['quantization']) 2309 2310 # Check inference of converted model. 
      input_data = np.array([[1, 2, 3]], dtype=np.float32)
      interpreter.set_tensor(input_details[0]['index'], input_data)
      interpreter.invoke()
      tflite_result = interpreter.get_tensor(output_details[0]['index'])

      keras_model = keras.models.load_model(self._keras_file)
      keras_result = keras_model.predict(input_data)

      # TFLite and Keras outputs must agree to 5 decimal places.
      np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testCustomLayer(self, test_context):
    """Test a Sequential tf.keras model with default inputs."""
    with test_context():
      self._getSequentialModel(include_custom_layer=True)

      converter = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, custom_objects=self._custom_objects)
      tflite_model = converter.convert()
      self.assertIsNotNone(tflite_model)

      # Check tensor details of converted model.
      interpreter = Interpreter(model_content=tflite_model)
      interpreter.allocate_tensors()

      input_details = interpreter.get_input_details()
      output_details = interpreter.get_output_details()

      # Check inference of converted model.
      input_data = np.array([[1, 2, 3]], dtype=np.float32)
      interpreter.set_tensor(input_details[0]['index'], input_data)
      interpreter.invoke()
      tflite_result = interpreter.get_tensor(output_details[0]['index'])

      keras_model = keras.models.load_model(
          self._keras_file, custom_objects=self._custom_objects)
      keras_result = keras_model.predict(input_data)

      np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def testSequentialModelInputArray(self):
    """Test a Sequential tf.keras model testing input arrays argument."""
    ops.disable_eager_execution()
    self._getSequentialModel()

    # Invalid input array raises error.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_arrays=['invalid-input'])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

    # Valid input array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_arrays=['dense_input'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

  def testSequentialModelInputShape(self):
    """Test a Sequential tf.keras model testing input shapes argument."""
    self._getSequentialModel()

    # Passing in shape of invalid input array raises error.
    with self.assertRaises(ValueError) as error:
      converter = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_shapes={'invalid-input': [2, 3]})
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.",
        str(error.exception))

    # Passing in shape of valid input array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={'dense_input': [2, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check input shape from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertAllEqual([2, 3], input_details[0]['shape'])

  def testSequentialModelOutputArray(self):
    """Test a Sequential tf.keras model testing output arrays argument."""
    ops.disable_eager_execution()
    self._getSequentialModel()

    # Invalid output array raises error.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, output_arrays=['invalid-output'])
    self.assertEqual("Invalid tensors 'invalid-output' were found.",
                     str(error.exception))

    # Valid output array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, output_arrays=['time_distributed/Reshape_1'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testFunctionalModel(self, test_context):
    """Test a Functional tf.keras model with default inputs."""
    with test_context():
      inputs = keras.layers.Input(shape=(3,), name='input')
      x = keras.layers.Dense(2)(inputs)
      output = keras.layers.Dense(3)(x)

      model = keras.models.Model(inputs, output)
      model.compile(
          loss=keras.losses.MSE,
          optimizer='sgd',
          metrics=[keras.metrics.categorical_accuracy])
      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      model.train_on_batch(x, y)

      model.predict(x)
      fd, self._keras_file = tempfile.mkstemp('.h5')
      try:
        keras.models.save_model(model, self._keras_file)
      finally:
        os.close(fd)

      # Convert to TFLite model.
      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
      tflite_model = converter.convert()
      self.assertIsNotNone(tflite_model)

      # Check tensor details of converted model.
      interpreter = Interpreter(model_content=tflite_model)
      interpreter.allocate_tensors()

      input_details = interpreter.get_input_details()
      self.assertLen(input_details, 1)
      self.assertEqual('input', input_details[0]['name'])
      self.assertEqual(np.float32, input_details[0]['dtype'])
      self.assertAllEqual([1, 3], input_details[0]['shape'])
      self.assertEqual((0., 0.), input_details[0]['quantization'])

      output_details = interpreter.get_output_details()
      self.assertLen(output_details, 1)
      self.assertEqual(np.float32, output_details[0]['dtype'])
      self.assertAllEqual([1, 3], output_details[0]['shape'])
      self.assertEqual((0., 0.), output_details[0]['quantization'])

      # Check inference of converted model.
      input_data = np.array([[1, 2, 3]], dtype=np.float32)
      interpreter.set_tensor(input_details[0]['index'], input_data)
      interpreter.invoke()
      tflite_result = interpreter.get_tensor(output_details[0]['index'])

      keras_model = keras.models.load_model(self._keras_file)
      keras_result = keras_model.predict(input_data)

      np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def _getFunctionalModelMultipleInputs(self):
    # Builds a two-input / two-output functional model with a shared Dense
    # layer; used by the multiple-inputs tests below.
    a = keras.layers.Input(shape=(3,), name='input_a')
    b = keras.layers.Input(shape=(3,), name='input_b')
    dense = keras.layers.Dense(4, name='dense')
    c = dense(a)
    d = dense(b)
    e = keras.layers.Dropout(0.5, name='dropout')(c)

    model = keras.models.Model([a, b], [d, e])
    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.mae],
        loss_weights=[1., 0.5])

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    output_d_np = np.random.random((10, 4))
    output_e_np = np.random.random((10, 4))
    model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])

    model.predict([input_a_np, input_b_np],
batch_size=5) 2495 fd, self._keras_file = tempfile.mkstemp('.h5') 2496 try: 2497 keras.models.save_model(model, self._keras_file) 2498 finally: 2499 os.close(fd) 2500 2501 def testFunctionalModelMultipleInputs(self): 2502 """Test a Functional tf.keras model with multiple inputs and outputs.""" 2503 self._getFunctionalModelMultipleInputs() 2504 2505 # Convert to TFLite model. 2506 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2507 tflite_model = converter.convert() 2508 self.assertIsNotNone(tflite_model) 2509 2510 # Check values from converted model. 2511 interpreter = Interpreter(model_content=tflite_model) 2512 interpreter.allocate_tensors() 2513 2514 input_details = interpreter.get_input_details() 2515 self.assertLen(input_details, 2) 2516 self.assertEndsWith(input_details[0]['name'], 'input_a') 2517 self.assertEqual(np.float32, input_details[0]['dtype']) 2518 self.assertAllEqual([1, 3], input_details[0]['shape']) 2519 self.assertEqual((0., 0.), input_details[0]['quantization']) 2520 2521 self.assertEndsWith(input_details[1]['name'], 'input_b') 2522 self.assertEqual(np.float32, input_details[1]['dtype']) 2523 self.assertAllEqual([1, 3], input_details[1]['shape']) 2524 self.assertEqual((0., 0.), input_details[1]['quantization']) 2525 2526 output_details = interpreter.get_output_details() 2527 self.assertLen(output_details, 2) 2528 self.assertEqual(np.float32, output_details[0]['dtype']) 2529 self.assertAllEqual([1, 4], output_details[0]['shape']) 2530 self.assertEqual((0., 0.), output_details[0]['quantization']) 2531 2532 self.assertEqual(np.float32, output_details[1]['dtype']) 2533 self.assertAllEqual([1, 4], output_details[1]['shape']) 2534 self.assertEqual((0., 0.), output_details[1]['quantization']) 2535 2536 def testShapeOverriding(self): 2537 """Test a Functional tf.keras model with input shape overriding.""" 2538 self._getFunctionalModelMultipleInputs() 2539 2540 # Convert to TFLite model. 
2541 converter = lite.TFLiteConverter.from_keras_model_file( 2542 self._keras_file, input_shapes={ 2543 'input_a': {2, 3}, 2544 'input_b': {2, 3} 2545 }) 2546 tflite_model = converter.convert() 2547 self.assertIsNotNone(tflite_model) 2548 2549 # Check values from converted model. 2550 interpreter = Interpreter(model_content=tflite_model) 2551 interpreter.allocate_tensors() 2552 2553 input_details = interpreter.get_input_details() 2554 self.assertLen(input_details, 2) 2555 self.assertEndsWith(input_details[0]['name'], 'input_a') 2556 self.assertEqual(np.float32, input_details[0]['dtype']) 2557 self.assertAllEqual([2, 3], input_details[0]['shape']) 2558 self.assertEqual((0., 0.), input_details[0]['quantization']) 2559 2560 self.assertEndsWith(input_details[1]['name'], 'input_b') 2561 self.assertEqual(np.float32, input_details[1]['dtype']) 2562 self.assertAllEqual([2, 3], input_details[1]['shape']) 2563 self.assertEqual((0., 0.), input_details[1]['quantization']) 2564 2565 output_details = interpreter.get_output_details() 2566 self.assertLen(output_details, 2) 2567 self.assertEqual(np.float32, output_details[0]['dtype']) 2568 self.assertAllEqual([2, 4], output_details[0]['shape']) 2569 self.assertEqual((0., 0.), output_details[0]['quantization']) 2570 2571 self.assertEqual(np.float32, output_details[1]['dtype']) 2572 self.assertAllEqual([2, 4], output_details[1]['shape']) 2573 self.assertEqual((0., 0.), output_details[1]['quantization']) 2574 2575 def testPartialShapeOverriding(self): 2576 """Test a Functional tf.keras model with partial input shape overriding.""" 2577 self._getFunctionalModelMultipleInputs() 2578 2579 # Convert to TFLite model. 2580 converter = lite.TFLiteConverter.from_keras_model_file( 2581 self._keras_file, input_shapes={'input_a': {2, 3}}) 2582 tflite_model = converter.convert() 2583 self.assertIsNotNone(tflite_model) 2584 2585 # Check values from converted model. 
2586 interpreter = Interpreter(model_content=tflite_model) 2587 interpreter.allocate_tensors() 2588 2589 input_details = interpreter.get_input_details() 2590 self.assertLen(input_details, 2) 2591 self.assertEndsWith(input_details[0]['name'], 'input_a') 2592 self.assertEqual(np.float32, input_details[0]['dtype']) 2593 self.assertAllEqual([2, 3], input_details[0]['shape']) 2594 self.assertEqual((0., 0.), input_details[0]['quantization']) 2595 2596 self.assertEndsWith(input_details[1]['name'], 'input_b') 2597 self.assertEqual(np.float32, input_details[1]['dtype']) 2598 self.assertAllEqual([1, 3], input_details[1]['shape']) 2599 self.assertEqual((0., 0.), input_details[1]['quantization']) 2600 2601 output_details = interpreter.get_output_details() 2602 self.assertLen(output_details, 2) 2603 self.assertEqual(np.float32, output_details[0]['dtype']) 2604 self.assertAllEqual([1, 4], output_details[0]['shape']) 2605 self.assertEqual((0., 0.), output_details[0]['quantization']) 2606 2607 self.assertEqual(np.float32, output_details[1]['dtype']) 2608 self.assertAllEqual([2, 4], output_details[1]['shape']) 2609 self.assertEqual((0., 0.), output_details[1]['quantization']) 2610 2611 def testWrongShapeOverriding(self): 2612 """Test a Functional tf.keras model with wrong input shape overriding.""" 2613 self._getFunctionalModelMultipleInputs() 2614 2615 # Convert to TFLite model. 
2616 with self.assertRaises(ValueError): 2617 lite.TFLiteConverter.from_keras_model_file( 2618 self._keras_file, input_shapes={'wrong_input': {2, 3}}) 2619 2620 def testFunctionalSequentialModel(self): 2621 """Test a Functional tf.keras model containing a Sequential model.""" 2622 model = keras.models.Sequential() 2623 model.add(keras.layers.Dense(2, input_shape=(3,))) 2624 model.add(keras.layers.RepeatVector(3)) 2625 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 2626 model = keras.models.Model(model.input, model.output) 2627 2628 model.compile( 2629 loss=keras.losses.MSE, 2630 optimizer='sgd', 2631 metrics=[keras.metrics.categorical_accuracy], 2632 sample_weight_mode='temporal') 2633 x = np.random.random((1, 3)) 2634 y = np.random.random((1, 3, 3)) 2635 model.train_on_batch(x, y) 2636 model.predict(x) 2637 2638 model.predict(x) 2639 fd, self._keras_file = tempfile.mkstemp('.h5') 2640 try: 2641 keras.models.save_model(model, self._keras_file) 2642 finally: 2643 os.close(fd) 2644 2645 # Convert to TFLite model. 2646 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2647 tflite_model = converter.convert() 2648 self.assertIsNotNone(tflite_model) 2649 2650 # Check tensor details of converted model. 
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    # (0., 0.) quantization params mean the tensor is not quantized.
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(self._keras_file)
    keras_result = keras_model.predict(input_data)

    # Outputs must agree to 5 decimal places (float32 conversion tolerance).
    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def testSequentialModelTocoConverter(self):
    """Test a Sequential tf.keras model with deprecated TocoConverter."""
    self._getSequentialModel()

    converter = lite.TocoConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testGraphDebugInfo(self, test_context):
    """Test a Sequential tf.keras model has debug info captured."""
    with test_context():
      self._getSequentialModel()
      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
      converter.convert()
      # _debug_info is populated as a side effect of convert(); validated by
      # the TestModels.assertValidDebugInfo helper.
      self.assertValidDebugInfo(converter._debug_info)


class SparsityTest(TestModels):

  def _getSparsificableModel(self, matrix_b_values):
    """Builds a [16, 4] x [4, 8] matmul whose weight values are caller-chosen.

    The caller controls `matrix_b_values` to produce the sparsity pattern
    (random vs. block) that the converter is expected to detect.

    Returns:
      A `(session, input_tensors, output_tensors)` tuple for from_session().
    """
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[16, 4], dtype=dtypes.float32, name='input1')
      in_tensor_2 = constant_op.constant(
          matrix_b_values, shape=[4, 8], dtype=dtypes.float32)
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
      sess = session.Session()

    return (sess, [in_tensor_1], [out_tensor])

  def testRandomSparsity(self):
    """Converts a randomly-sparse weight matrix and checks the metadata."""
    # Non-zeros are scattered, so the converter should tag RANDOM_SPARSITY.
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1
    ]
    sess, inputs, outputs = self._getSparsificableModel(matrix_b_values)
    float_converter = lite.TFLiteConverter.from_session(sess, inputs, outputs)
    float_converter.optimizations = [lite.Optimize.EXPERIMENTAL_SPARSITY]
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(float_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY],
                        metadata.options.modelOptimizationModes)

  def testSparsifyModel(self):
    """Converts a block-sparse weight matrix and checks the metadata."""
    # A single non-zero yields a block-sparse layout → BLOCK_SPARSITY tag.
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0
    ]
    sess, inputs, outputs = self._getSparsificableModel(matrix_b_values)
    converter = lite.TFLiteConverter.from_session(sess, inputs, outputs)
    converter.optimizations = {lite.Optimize.EXPERIMENTAL_SPARSITY}
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([
        metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY,
    ], metadata.options.modelOptimizationModes)

  def testSparsifyQuantizedModel(self):
    """Combines DEFAULT quantization with sparsity and checks the metadata."""
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0
    ]
    sess, inputs, outputs = self._getSparsificableModel(matrix_b_values)
    converter = lite.TFLiteConverter.from_session(sess, inputs, outputs)
    converter.optimizations = {
        lite.Optimize.DEFAULT, lite.Optimize.EXPERIMENTAL_SPARSITY
    }
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check the conversion metadata.
2760 metadata = get_conversion_metadata(tflite_model) 2761 self.assertIsNotNone(metadata) 2762 self.assertAllEqual([ 2763 metadata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE, 2764 metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY, 2765 ], metadata.options.modelOptimizationModes) 2766 2767 2768class GrapplerTest(TestModels, parameterized.TestCase): 2769 2770 def testConstantFolding(self): 2771 ops.disable_eager_execution() 2772 # Constant folding handles the tf.broadcast_to operation which was not 2773 # supported by the TFLite at the time this test was added. 2774 with ops.Graph().as_default(): 2775 in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32) 2776 y_const = constant_op.constant([1., 2., 3.]) 2777 y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3]) 2778 out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output') 2779 sess = session.Session() 2780 2781 # Convert model. 2782 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 2783 [out_tensor]) 2784 tflite_model = converter.convert() 2785 2786 # Check values from converted model. 2787 interpreter = Interpreter(model_content=tflite_model) 2788 interpreter.allocate_tensors() 2789 2790 input_details = interpreter.get_input_details() 2791 self.assertLen(input_details, 1) 2792 self.assertEqual('Placeholder', input_details[0]['name']) 2793 self.assertEqual(np.float32, input_details[0]['dtype']) 2794 self.assertAllEqual([3, 3], input_details[0]['shape']) 2795 2796 output_details = interpreter.get_output_details() 2797 self.assertLen(output_details, 1) 2798 self.assertEqual('output', output_details[0]['name']) 2799 self.assertEqual(np.float32, output_details[0]['dtype']) 2800 self.assertAllEqual([3, 3], output_details[0]['shape']) 2801 2802 def testInputNodeIsNotFolded(self): 2803 ops.disable_eager_execution() 2804 # Constant folding handles the tf.broadcast_to operation which was not 2805 # supported by the TFLite at the time this test was added. 
2806 with ops.Graph().as_default(): 2807 in_tensor = array_ops.placeholder(shape=[3], dtype=dtypes.float32) 2808 y_const = constant_op.constant([1., 2., 3.]) 2809 y_add = y_const + y_const 2810 out_tensor = in_tensor * y_add 2811 sess = session.Session() 2812 2813 # Convert model. 2814 converter = lite.TFLiteConverter.from_session(sess, [in_tensor, y_const], 2815 [out_tensor]) 2816 tflite_model = converter.convert() 2817 2818 # Check values from converted model. 2819 interpreter = Interpreter(model_content=tflite_model) 2820 interpreter.allocate_tensors() 2821 2822 input_details = interpreter.get_input_details() 2823 self.assertLen(input_details, 2) 2824 self.assertEqual('Placeholder', input_details[0]['name']) 2825 self.assertEqual('Const', input_details[1]['name']) 2826 2827 def testGrapplerConstFolding(self): 2828 # Constant folding converts the following add operation to tf.broadcast_to 2829 # operation which was not supported by the TFLite at the time this test was 2830 # added. 2831 @def_function.function 2832 def plus_placeholder(x, placeholder): 2833 return x + placeholder 2834 2835 with ops.Graph().as_default(): 2836 in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32) 2837 out_tensor = plus_placeholder( 2838 array_ops.zeros([2, 2, 2]), 2839 array_ops.reshape(in_tensor, shape=[2, 2])) 2840 sess = session.Session() 2841 2842 # Convert model. 2843 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 2844 [out_tensor]) 2845 tflite_model = converter.convert() 2846 2847 # Check values from converted model. 
2848 interpreter = Interpreter(model_content=tflite_model) 2849 interpreter.allocate_tensors() 2850 2851 input_details = interpreter.get_input_details() 2852 self.assertLen(input_details, 1) 2853 self.assertEqual('Placeholder', input_details[0]['name']) 2854 2855 2856class DefaultConverterAttrsTest(LiteTest): 2857 2858 def testAttrs(self): 2859 with ops.Graph().as_default(): 2860 in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32) 2861 out_tensor = in_tensor + in_tensor 2862 sess = session.Session() 2863 2864 # Convert model. 2865 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 2866 [out_tensor]) 2867 2868 # Assert output format. 2869 self.assertEqual(converter.output_format, lite_constants.TFLITE) 2870 2871 # Assert the default inference type is float. 2872 self.assertEqual(converter.inference_type, dtypes.float32) 2873 2874 # Assert the default inference type overrides are None. 2875 self.assertIsNone(converter.inference_input_type) 2876 self.assertIsNone(converter.inference_output_type) 2877 2878 # Assert the default quantization options are not set. 2879 self.assertEqual(converter.quantized_input_stats, {}) 2880 self.assertIsNone(converter.default_ranges_stats) 2881 self.assertFalse(converter.reorder_across_fake_quant) 2882 self.assertFalse(converter.change_concat_input_ranges) 2883 2884 # Assert dropping control dependency is enabled by default. 2885 self.assertIsNotNone(converter.drop_control_dependency) 2886 2887 # Assert dumping extra information is disabled by default. 
2888 self.assertIsNone(converter.dump_graphviz_dir) 2889 self.assertFalse(converter.dump_graphviz_video) 2890 self.assertIsNone(converter.conversion_summary_dir) 2891 2892 2893class ControlFlowV1OpsTest(LiteTest): 2894 2895 def testConverterErrorOnControlFlowV1Ops(self): 2896 graph_def_file = resource_loader.get_path_to_datafile( 2897 'testdata/control_flow_v1.pbtxt') 2898 input_arrays = ['a', 'b', 'c', 'd'] 2899 output_arrays = ['Merge'] 2900 2901 converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file, 2902 input_arrays, 2903 output_arrays) 2904 with self.assertRaises(ConverterError) as error: 2905 converter.convert() 2906 self.assertIn( 2907 'Failed to functionalize Control Flow V1 ops. Consider using Control ' 2908 'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/' 2909 'tf/compat/v1/enable_control_flow_v2.', str(error.exception)) 2910 2911 2912class QuantizationModeTest(LiteTest, parameterized.TestCase): 2913 2914 @parameterized.named_parameters( 2915 ('size', lite.Optimize.OPTIMIZE_FOR_SIZE), 2916 ('latency', lite.Optimize.OPTIMIZE_FOR_LATENCY)) 2917 def testDeprecatedOptionWarning(self, optimization): 2918 """Test if the warning message when using TOCO is logged.""" 2919 log = io.StringIO() 2920 handler = logging.StreamHandler(log) 2921 logging.root.addHandler(handler) 2922 warning_message = 'please use optimizations=[Optimize.DEFAULT] instead.' 2923 lite.QuantizationMode([optimization], lite.TargetSpec(), None, None) 2924 self.assertIn(warning_message, log.getvalue()) 2925 logging.root.removeHandler(handler) 2926 2927 2928if __name__ == '__main__': 2929 test.main() 2930