# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py."""

import io
import logging
import os
import tempfile

from absl.testing import parameterized
import numpy as np
from tensorflow import keras

from tensorflow.lite.python import conversion_metadata_schema_py_generated as metadata_fb
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import util
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.util import get_conversion_metadata
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph


class LiteTest(test_util.TensorFlowTestCase):
  """Base class of all the tests in this module."""


class TestModels(LiteTest):

  def assertValidDebugInfo(self, debug_info):
    """Verify the DebugInfo is valid."""
    file_names = set()
    for file_path in debug_info.files:
      file_names.add(os.path.basename(file_path))
    # To make the test independent of how the nodes are created, we only
    # assert the name of this test file.
    self.assertIn('lite_test.py', file_names)
    self.assertNotIn('lite_v2_test.py', file_names)


class FromConstructor(TestModels):

  # Tests invalid constructors using a dummy value for the GraphDef.
  def testInvalidConstructor(self):
    message = (
        'If input_tensors and output_tensors are None, both '
        'input_arrays_with_shape and output_arrays|control_output_arrays must '
        'be defined.')

    # `output_arrays` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(
          None, None, [],
          input_arrays_with_shape=[('input', [3, 9])]).convert()
    self.assertEqual(message, str(error.exception))

    # `input_arrays_with_shape` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(None, [], None, output_arrays=['output']).convert()
    self.assertEqual(message, str(error.exception))

  # Tests valid constructors using a dummy value for the GraphDef.
  def testValidConstructor(self):
    converter = lite.TFLiteConverter(
        None,
        None,
        None,
        input_arrays_with_shape=[('input', [3, 9])],
        output_arrays=['output'])
    self.assertFalse(converter._has_valid_tensors())
    self.assertEqual(converter.get_input_arrays(), ['input'])

    with self.assertRaises(ValueError) as error:
      converter._set_batch_size(1)
    self.assertEqual(
        'The batch size cannot be set for this model. Please use '
        'input_shapes parameter.', str(error.exception))

    converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
    self.assertTrue(converter._has_valid_tensors())
  def testRedundantArgumentsWarning(self):
    """Tests that warnings are logged when there are redundant arguments."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      out_tensor = math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Capture logging output so the warnings can be checked below.
    log = io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    converter = lite.TFLiteConverter(frozen_graph_def, [in_tensor],
                                     [out_tensor],
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])

    input_warning_message = 'input_arrays_with_shape will be ignored'
    output_warning_message = 'output_arrays will be ignored'

    # Convert model and ensure model is not None.
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    self.assertIn(input_warning_message, log.getvalue())
    self.assertIn(output_warning_message, log.getvalue())
    logging.root.removeHandler(handler)
  def testShapeOverriding(self):
    """Test a shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('in_tensor', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testPartialShapeOverriding(self):
    """Test a partial shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_a')
      in_tensor_b = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_b')
      math_ops.add(in_tensor_a, in_tensor_b, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Override the shape of only one of the two inputs.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor_a', [2, 16, 16, 3])], ['add'])
    # Conversion fails because the other Placeholder op is left unhandled.
    with self.assertRaises(ConverterError):
      converter.convert()

  def testInvalidShapeOverriding(self):
    """Test an invalid shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Override the shape of a tensor that does not exist in the graph;
    # conversion should fail.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('wrong_tensor', [2, 16, 16, 3])],
                                     ['add'])
    with self.assertRaises(ConverterError):
      converter.convert()


class FromSessionTest(TestModels, parameterized.TestCase):

  def testFloatModel(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatModelQuantizedInput(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_input_type = dtypes.uint8
    converter.inference_type = dtypes.float32
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
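    # Per the converter docs, quantized_input_stats encodes
    # real_value = (quantized_value - mean) / std_dev, so (0., 1.) maps to
    # quantization params of scale 1. and zero point 0, asserted as (1., 0.)
    # below.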
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])  # float

  def testForgottenCallToAllocateTensors(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_index = interpreter.get_input_details()[0]['index']
    dummy_tensor = np.ones(shape=[1, 16, 16, 3], dtype=np.float32)
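    # set_tensor() should fail because allocate_tensors() was never called.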
    with self.assertRaises(ValueError):
      interpreter.set_tensor(input_index, dummy_tensor)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  def testIntegerQuantizationWithUnsupportedOps(self,
                                                is_int_only,
                                                is_int16_quantize,
                                                inference_input_output_type,
                                                enable_mlir_quantizer=False):
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      # The ceil kernel supports neither int8 nor int16 types.
      left = math_ops.ceil(in_tensor_a)
      out_tensor_b = math_ops.tanh(in_tensor_b)
      add = math_ops.add(left, out_tensor_b)
      # The ceil kernel supports neither int8 nor int16 types.
      out_tensor_a = math_ops.ceil(add)
      sess = session.Session()

    def calibration_gen():
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
        ]
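    # calibration_gen yields one sample per model input (in_tensor_a and
    # in_tensor_b); the converter runs these samples through the model to
    # calibrate activation ranges for full-integer quantization.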

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_a, in_tensor_b], [out_tensor_a, out_tensor_b])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    if is_int_only:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
        ]
    else:
      if is_int16_quantize:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet
            .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
            lite.OpsSet.TFLITE_BUILTINS
        ]
      else:
        quantized_converter.target_spec.supported_ops = [
            lite.OpsSet.TFLITE_BUILTINS
        ]

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype
    # Allow float32 for fallback on non-quantizable op.
    expected_ceil_dtype = (
        expected_dtype if enable_mlir_quantizer else dtypes.float32)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(input_details[1]['dtype'], expected_dtype)
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
    self.assertEqual(output_details[1]['dtype'], expected_dtype)

  @parameterized.named_parameters(
      ('_PerChannelQuant', False, False),
      ('_PerChannelMlirQuant', False, True),
      ('_PerTensorQuant', True, False),
      ('_PerTensorMlirQuant', True, True),
      ('_PerChannelMlirDynamicRangeQuant', False, False, False),
      ('_PerTensorMlirDynamicRangeQuant', True, False, False))
  def testDisablePerChannelQuantization(self,
                                        disable_per_channel=False,
                                        enable_mlir_quantizer=False,
                                        representative_dataset=True):
    k_conv_name = 'Conv2D1'
    # Dynamic range quant requires total num elements of filters > 1024.
    k_num_filters = 38
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel(
          k_num_filters)
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    if representative_dataset:
      quantized_converter.representative_dataset = calibration_gen
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    detail = next((d for d in interpreter.get_tensor_details()
                   if d['name'] == k_conv_name))
    quant_params = detail['quantization_parameters']
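    # Per-channel quantization keeps one (scale, zero_point) pair per output
    # channel of the convolution filter, while per-tensor quantization keeps a
    # single pair for the whole tensor.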
    expected_num_params = 1 if disable_per_channel else k_num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

  def testString(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.string_, input_details[0]['dtype'])
    self.assertAllEqual([4], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('Reshape', output_details[0]['name'])
    self.assertEqual(np.string_, output_details[0]['dtype'])
    self.assertAllEqual([2, 2], output_details[0]['shape'])
    # TODO(b/122659643): Test setting/getting string data via the python
    # interpreter API after support has been added.

  def testIntermediateInputArray(self):
    """Convert a model from an intermediate input array."""
    with ops.Graph().as_default():
      in_tensor_init = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      in_tensor_final = in_tensor_init + in_tensor_init
      out_tensor = in_tensor_final + in_tensor_final
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('add', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add_1', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testSizeNoneInvalid(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
    # order to invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
                     str(error.exception))

  def testScalarValid(self):
    # Construct a graph using a scalar (empty shape) input.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test conversion with the scalar input shape.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEmpty(input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertEmpty(output_details[0]['shape'])

    # Validate inference using the scalar inputs/outputs.
    test_input = np.array(4.0, dtype=np.float32)
    expected_output = np.array(8.0, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_output, output_data)

  def testSizeInvalid(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, None, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
    # invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'None is only supported in the 1st dimension. Tensor '
        '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
        str(error.exception))

  def testSizeNone(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, None, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test None after 1st dimension.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 1, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    # Resize tensor with strict checking.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    self.assertIn(
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.', str(error.exception))

    # Resize tensor and invoke.
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

    test_input = np.full([1, 16, 16, 3], 1.0, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])

    output_details = interpreter.get_output_details()
    self.assertAllEqual([1, -1, 16, 3], output_details[0]['shape_signature'])

  def testResizeTensorInputStrict(self):
    # Ensures that resize_tensor_input(strict=True) works as expected.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)

    # Resize incorrect value.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    self.assertIn(
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.', str(error.exception))

    # Resize correct value.
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

  def testBatchSizeValid(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model. The unknown batch size defaults to 1.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testBatchSizeNonZero(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[None, 4], dtype=dtypes.float32, name='input1')
      in_tensor_2 = array_ops.placeholder(
          shape=[4, 10], dtype=dtypes.float32, name='input2')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('input1', input_details[0]['name'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual('input2', input_details[1]['name'])
    self.assertAllEqual([4, 10], input_details[1]['shape'])

  def testFreezeGraph(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      var = variable_scope.get_variable(
          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
      # Get the second output to ensure freezing properly processes tensor
      # names like 'X:1'.
      out_tensor = nn_ops.top_k(in_tensor + var, name='top_k')[1]
      sess = session.Session()
      sess.run(_global_variables_initializer())
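      # Initializing the variable gives the freezing step a concrete value
      # that can be folded into the graph as a constant.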

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('top_k:1', output_details[0]['name'])
    self.assertEqual(np.int32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 1], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testGraphviz(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.output_format = lite_constants.GRAPHVIZ_DOT
    graphviz_output = converter.convert()
    self.assertIsNotNone(graphviz_output)

  def testDumpGraphviz(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    graphviz_dir = self.get_temp_dir()
    converter.dump_graphviz_dir = graphviz_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure interpreter is able to allocate and check graphviz data.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    num_items_graphviz = len(os.listdir(graphviz_dir))
    self.assertGreater(num_items_graphviz, 0)
    self.assertTrue(
        os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
    self.assertTrue(
        os.path.exists(
            os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))

  def testDumpConversionSummary(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    self.assertNotEmpty(os.listdir(log_dir))

  def testDumpConversionSummaryWithOldConverter(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check nothing is generated under the conversion summary path.
    num_items_conversion_summary = len(os.listdir(log_dir))
    self.assertEqual(num_items_conversion_summary, 0)

  def testQuantizeDynamicRange(self):
    np.random.seed(0)
    with ops.Graph().as_default():
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
                                                        [out_tensor])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized weights model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])

    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
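    # With Optimize.DEFAULT and no representative dataset, the converter
    # performs dynamic range quantization: weights are stored as int8 while
    # activations stay in float, which is what makes the model smaller below.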
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testQuantizeDynamicRangeDeprecatedPostTrainingQuantizeAttribute(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    self.assertFalse(quantized_converter.post_training_quantize)

    quantized_converter.post_training_quantize = True
    self.assertTrue(quantized_converter.post_training_quantize)
    self.assertEqual(quantized_converter.optimizations, [lite.Optimize.DEFAULT])

    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

  def _getIntegerQuantizeModel(self, num_filters=16):
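    """Builds a small float conv2d+relu model for the quantization tests.

    Returns:
      A tuple of (input placeholder, output tensor, calibration data
      generator).
    """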
    np.random.seed(0)
    inp = array_ops.placeholder(
        dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input')
    conv = nn_ops.conv2d(
        inp,
        filter=array_ops.ones([3, 3, 3, num_filters]),
        strides=[1, 1, 1, 1],
        padding='SAME')
    output = nn_ops.relu(conv, name='output')

    def calibration_gen():
      for _ in range(5):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    return (inp, output, calibration_gen)

  def testQuantizeInt8AllowFloat(self):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(float_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(
        metadata.environment.tensorflowVersion.decode('utf-8'),
        versions.__version__)
    self.assertEqual(metadata.environment.apiVersion, 1)
    self.assertEqual(metadata.environment.modelType,
                     metadata_fb.ModelType.TF_SESSION)
    self.assertEqual(metadata.options.allowCustomOps, False)
    self.assertEqual(metadata.options.enableSelectTfOps, False)
    self.assertEqual(metadata.options.forceSelectTfOps, False)
    self.assertAllEqual([], metadata.options.modelOptimizationModes)

    # Convert quantized model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER],
                        metadata.options.modelOptimizationModes)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      # Quantize model to Int8.
      ('UseTfliteBuiltinsInt', [lite.OpsSet.TFLITE_BUILTINS_INT8],
       [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]),
      ('UseTfliteBuiltinsInt16', [
          lite.OpsSet
          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
      ], [metadata_fb.ModelOptimizationMode.PTQ_INT16]))
  def testQuantizeInt8And16x8(self, supported_ops, expected_opt_modes):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert model by specifying target spec (instead of optimizations), since
    # when targeting an integer only backend, quantization is mandatory.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_ops = supported_ops
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertEqual(
        metadata.environment.tensorflowVersion.decode('utf-8'),
        versions.__version__)
    self.assertEqual(metadata.environment.apiVersion, 1)
    self.assertEqual(metadata.environment.modelType,
                     metadata_fb.ModelType.TF_SESSION)
    self.assertEqual(metadata.options.allowCustomOps, False)
    self.assertEqual(metadata.options.enableSelectTfOps, False)
    self.assertEqual(metadata.options.forceSelectTfOps, False)
    self.assertAllEqual(expected_opt_modes,
                        metadata.options.modelOptimizationModes)

    # The default input and output types should be float.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testQuantizeInt8InputOutput(self):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized weights model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.inference_input_type = dtypes.int8
    quantized_converter.inference_output_type = dtypes.int8
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # The input and output types should be int8.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.int8, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.int8, output_details[0]['dtype'])

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testInvalidQuantizeInt8(self):
    np.random.seed(0)
    with ops.Graph().as_default():
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    # Attempt to convert to quantized weights model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    # Restricting to int8 type only.
    quantized_converter.target_spec.supported_types = [dtypes.int8]
    # A representative dataset is required for full fixed point quantization.
    with self.assertRaises(ValueError) as error:
      quantized_converter.convert()
    self.assertEqual(
        'For full integer quantization, a `representative_dataset` '
        'must be specified.', str(error.exception))

  def testQuantizeUInt8(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {
        'inputA': (0., 1.),
        'inputB': (0., 1.)
    }  # mean, std_dev
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('inputA', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    self.assertEqual('inputB', input_details[1]['name'])
    self.assertEqual(np.uint8, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((1., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale

  def testQuantizeUInt8UsingDefaultRangeStats(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    converter.default_ranges_stats = (0, 6)  # min, max
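    # default_ranges_stats supplies a fallback (min, max) range for activation
    # tensors without recorded ranges, enabling "dummy quantization" of graphs
    # that are not fully annotated with fake-quant ops.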
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale

  @parameterized.named_parameters(
      # Quantize to Float16 even if rep data provided.
      ('UseRepresentativeData', True, False, True, False, False, False,
       [metadata_fb.ModelOptimizationMode.PTQ_FLOAT16]),
      # Quantize to Float16 if no rep data provided.
      ('NoRepresentativeData', False, False, True, False, False, False,
       [metadata_fb.ModelOptimizationMode.PTQ_FLOAT16]),
      # Post training quantization if both rep data and int8 included.
      ('SampleDataIncludeInt8', True, True, False, False, True, False,
       [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]),
      # Same as above, but using MLIR quantizer.
      ('SampleDataIncludeInt8Quant', True, True, False, False, True, True,
       [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]))
  def testQuantizeFloat16(self, use_rep_data, include_int8,
                          is_float16_quantized, is_float16_accumulation,
                          is_post_training_quantized, enable_mlir_quantizer,
                          expected_opt_modes):
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    bias_idx = 1
    bias_name = 'Conv2D'

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)
    interpreter = Interpreter(model_content=float_tflite_model)
    interpreter.allocate_tensors()
    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
                     bias_name)
    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
                     dtypes.float32)

    # Convert model to quantized version.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_types = [dtypes.float16]
    if include_int8:
      quantized_converter.target_spec.supported_types.append(dtypes.int8)
    if use_rep_data:
      quantized_converter.representative_dataset = calibration_gen
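    # As the parameterized cases above note, when int8 is included in
    # supported_types and a representative dataset is provided, full integer
    # post-training quantization takes precedence over float16 quantization.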
    if is_float16_accumulation:
      quantized_converter.target_spec.experimental_supported_accumulation_type = dtypes.float16  # pylint: disable=line-too-long
    else:
      quantized_tflite_model = quantized_converter.convert()
      self.assertIsNotNone(quantized_tflite_model)
      metadata = get_conversion_metadata(quantized_tflite_model)
      self.assertIsNotNone(metadata)
      self.assertAllEqual(expected_opt_modes,
                          metadata.options.modelOptimizationModes)
      interpreter = Interpreter(model_content=quantized_tflite_model)
      interpreter.allocate_tensors()

      # MLIR quantizer has different bias index.
      bias_tensor = [
          tensor for tensor in interpreter.get_tensor_details()
          if tensor['name'] == bias_name
      ]
      self.assertLen(bias_tensor, 1)

      if is_float16_quantized:
        # Verify that the bias constant is float16 type.
        self.assertEqual(bias_tensor[0]['dtype'], dtypes.float16)
      elif is_post_training_quantized:
        # Verify that the bias constant is int32 type.
        self.assertEqual(bias_tensor[0]['dtype'], dtypes.int32)
      else:
        raise ValueError('Invalid test options.')
1201
1202  def testInvalidQuantizeFloat16(self):
1203    with ops.Graph().as_default():
1204      inp, output, _ = self._getIntegerQuantizeModel()
1205      sess = session.Session()
1206
1207    # Specify float16 quantization
1208    quantized_converter = lite.TFLiteConverter.from_session(
1209        sess, [inp], [output])
1210    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1211    quantized_converter.target_spec.supported_types = [dtypes.float16]
1212    # Specify only int8 builtin ops
1213    quantized_converter.target_spec.supported_ops = [
1214        lite.OpsSet.TFLITE_BUILTINS_INT8
1215    ]
1216    with self.assertRaises(ValueError) as error:
1217      quantized_converter.convert()
1218    self.assertEqual(
1219        'As full integer quantization has been enabled by setting '
1220        '`target_spec.supported_ops`={tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, '
1221        'thus `target_spec.supported_types` should be left uninitizalized '
1222        'or set to {tf.int8}.', str(error.exception))
1223
1224  @parameterized.named_parameters(('InferenceType_INT8', dtypes.int8),
1225                                  ('InferenceType_UINT8', dtypes.uint8))
1226  def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
1227    with ops.Graph().as_default():
1228      in_tensor = array_ops.placeholder(
1229          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1230      out_tensor = array_ops.fake_quant_with_min_max_args(
1231          in_tensor + in_tensor, min=0., max=1.)
1232      sess = session.Session()
1233
1234    quantized_converter = lite.TFLiteConverter.from_session(
1235        sess, [in_tensor], [out_tensor])
1236
1237    with self.assertRaises(ValueError) as error:
1238      quantized_converter.inference_type = quantized_type
1239      quantized_converter.convert()
1240    self.assertEqual(
1241        'The `quantized_input_stats` flag must be defined when either '
1242        '`inference_type` flag or `inference_input_type` flag is set to '
1243        'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
1244        '`inference_input_type=None`.'.format(quantized_type.name),
1245        str(error.exception))
1246
1247    with self.assertRaises(ValueError) as error:
1248      quantized_converter.inference_type = dtypes.float32
1249      quantized_converter.inference_input_type = quantized_type
1250      quantized_converter.convert()
1251    self.assertEqual(
1252        'The `quantized_input_stats` flag must be defined when either '
1253        '`inference_type` flag or `inference_input_type` flag is set to '
1254        'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
1255        '`inference_input_type=tf.{}`.'.format(quantized_type.name),
1256        str(error.exception))
1257
1258    quantized_converter.inference_type = quantized_type
1259    quantized_converter.inference_input_type = quantized_type
1260
1261    input_arrays = quantized_converter.get_input_arrays()
1262    quantized_converter.quantized_input_stats = {input_arrays[0]: (0., 1.)}
1263    quantized_converter.convert()
1264
1265  def testInvalidQuantizeQATModelMissingInputStats(self):
1266    with ops.Graph().as_default():
1267      in_tensor_1 = array_ops.placeholder(
1268          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
1269      in_tensor_2 = array_ops.placeholder(
1270          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
1271      out_tensor = array_ops.fake_quant_with_min_max_args(
1272          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
1273      sess = session.Session()
1274
    # Create the converter; quantization stats are provided for only one of
    # the two inputs, so conversion should fail.
1276    converter = lite.TFLiteConverter.from_session(sess,
1277                                                  [in_tensor_1, in_tensor_2],
1278                                                  [out_tensor])
1279    converter.inference_type = dtypes.uint8
1280    converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
1281    with self.assertRaises(ValueError) as error:
1282      converter.convert()
1283    self.assertEqual(
1284        'Quantization input stats are not available for input tensors '
1285        '\'inputB\'.', str(error.exception))
1286
1287  def testTrainingTimeAndPostTrainingCalibrateAndQuantize(self):
1288    with ops.Graph().as_default():
1289      inp, output, calibration_gen = self._getIntegerQuantizeModel()
1290      sess = session.Session()
1291
1292    # Convert float model.
1293    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1294    float_tflite_model = float_converter.convert()
1295    self.assertIsNotNone(float_tflite_model)
1296
1297    converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1298
    # Extra flags to trigger training-time quantization conversion.
1300    converter.inference_type = dtypes.int8
1301    converter.inference_input_type = dtypes.float32
1302    converter.inference_output_type = dtypes.float32
1303    input_arrays = converter.get_input_arrays()
1304    converter.quantized_input_stats = {input_arrays[0]: (0., 1.)}
    # Trigger post-training quantization.
1306    converter.optimizations = [lite.Optimize.DEFAULT]
1307    converter.representative_dataset = calibration_gen
1308    converter.experimental_new_quantizer = True
1309    quantized_tflite_model = converter.convert()
1310    self.assertIsNotNone(quantized_tflite_model)
1311    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1312
    # Calibration-only API: run calibration without applying quantization.
1314    converter._experimental_calibrate_only = True
1315    calibrated_tflite = converter.convert()
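    # The calibrated model is still float32 but carries the collected min/max
    # statistics; mlir_quantize below turns it into a fully quantized model.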
1316    quantized_tflite_model = mlir_quantize(
1317        calibrated_tflite, fully_quantize=True)
1318    interpreter = Interpreter(model_content=quantized_tflite_model)
1319    interpreter.allocate_tensors()
1320    input_details = interpreter.get_input_details()
1321    self.assertEqual(np.int8, input_details[0]['dtype'])
1322    self.assertEqual((1., 0.), input_details[0]['quantization'])
1323
1324    output_details = interpreter.get_output_details()
1325    self.assertEqual(np.int8, output_details[0]['dtype'])
1326
1327  def testFloatTocoConverter(self):
1328    """Tests deprecated test TocoConverter."""
1329    with ops.Graph().as_default():
1330      in_tensor = array_ops.placeholder(
1331          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1332      out_tensor = in_tensor + in_tensor
1333      sess = session.Session()
1334
1335    # Convert model and ensure model is not None.
1336    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
1337    tflite_model = converter.convert()
1338    self.assertIsNotNone(tflite_model)
1339
1340    # Ensure the interpreter is able to load.
1341    interpreter = Interpreter(model_content=tflite_model)
1342    interpreter.allocate_tensors()
1343
1344  def testMultipleOutputNodeNames(self):
1345    """Tests converting a graph with an op that have multiple outputs."""
1346    with ops.Graph().as_default():
1347      input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
1348      out0, out1, out2, out3 = array_ops.split(
1349          input_tensor, [1, 1, 1, 1], axis=0)
1350      sess = session.Session()
1351
1352    # Convert model and ensure model is not None.
1353    converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
1354                                                  [out0, out1, out2, out3])
1355    tflite_model = converter.convert()
1356    self.assertIsNotNone(tflite_model)
1357
1358    # Check values from converted model.
1359    interpreter = Interpreter(model_content=tflite_model)
1360    interpreter.allocate_tensors()
1361
1362    input_details = interpreter.get_input_details()
1363    self.assertLen(input_details, 1)
1364    interpreter.set_tensor(input_details[0]['index'],
1365                           np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
1366    interpreter.invoke()
1367
1368    output_details = interpreter.get_output_details()
1369    self.assertLen(output_details, 4)
1370    self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
1371    self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
1372    self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
1373    self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
1374
1375  @test_util.run_in_graph_and_eager_modes
1376  def testFunctions(self):
1377    """Tests tf.function in 1.X."""
1378
1379    @def_function.function
1380    def plus_placeholder(x, placeholder):
1381      return x + placeholder
1382
1383    with ops.Graph().as_default():
1384      placeholder = array_ops.placeholder(
1385          dtype=dtypes.float32, shape=[1], name='input')
1386      variable_node = variables.Variable(1.0, name='variable_node')
1387      defun_node = plus_placeholder(variable_node, placeholder)
1388      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
1389
1390      # Initialize variables in the model.
1391      sess = session.Session()
1392      sess.run(variables.variables_initializer([variable_node]))
1393
1394    # Convert model and ensure model is not None.
1395    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
1396                                                  [output_node])
1397    tflite_model = converter.convert()
1398    self.assertIsNotNone(tflite_model)
1399
1400    # Check values from converted model.
1401    interpreter = Interpreter(model_content=tflite_model)
1402    interpreter.allocate_tensors()
1403
1404    input_details = interpreter.get_input_details()
1405    self.assertLen(input_details, 1)
1406    self.assertEqual('input', input_details[0]['name'])
1407    self.assertEqual(np.float32, input_details[0]['dtype'])
1408    self.assertAllEqual([1], input_details[0]['shape'])
1409    self.assertEqual((0., 0.), input_details[0]['quantization'])
1410
1411    output_details = interpreter.get_output_details()
1412    self.assertLen(output_details, 1)
1413    self.assertEqual('output_node', output_details[0]['name'])
1414    self.assertEqual(np.float32, output_details[0]['dtype'])
1415    self.assertAllEqual([1], output_details[0]['shape'])
1416    self.assertEqual((0., 0.), output_details[0]['quantization'])
1417
1418  def testInferenceInputOutputTypeFloatDefault(self):
1419    with ops.Graph().as_default():
1420      in_tensor = array_ops.placeholder(
1421          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1422      out_tensor = in_tensor + in_tensor
1423      sess = session.Session()
1424
1425    # Convert model and ensure model is not None.
1426    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1427                                                  [out_tensor])
1428    tflite_model = converter.convert()
1429    self.assertIsNotNone(tflite_model)
1430
1431    # Check values from converted model.
1432    interpreter = Interpreter(model_content=tflite_model)
1433    interpreter.allocate_tensors()
1434
1435    input_details = interpreter.get_input_details()
1436    self.assertLen(input_details, 1)
1437    self.assertEqual('Placeholder', input_details[0]['name'])
1438    self.assertEqual(np.float32, input_details[0]['dtype'])
1439    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1440
1441    output_details = interpreter.get_output_details()
1442    self.assertLen(output_details, 1)
1443    self.assertEqual('add', output_details[0]['name'])
1444    self.assertEqual(np.float32, output_details[0]['dtype'])
1445    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1446
1447  def testInferenceInputOutputTypeQuantizedUint8Default(self):
1448    with ops.Graph().as_default():
1449      in_tensor = array_ops.placeholder(
1450          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1451      out_tensor = array_ops.fake_quant_with_min_max_args(
1452          in_tensor + in_tensor, min=0., max=1., name='output')
1453      sess = session.Session()
1454
1455    # Convert model and ensure model is not None.
1456    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1457                                                  [out_tensor])
1458    converter.inference_type = dtypes.uint8
1459    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
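    # Each entry maps an input name to (mean, std_dev), where
    # real_value = (quantized_value - mean) / std_dev; (0., 1.) is identity.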
1460    tflite_model = converter.convert()
1461    self.assertIsNotNone(tflite_model)
1462
1463    # Check values from converted model.
1464    interpreter = Interpreter(model_content=tflite_model)
1465    interpreter.allocate_tensors()
1466
1467    input_details = interpreter.get_input_details()
1468    self.assertLen(input_details, 1)
1469    self.assertEqual('Placeholder', input_details[0]['name'])
1470    self.assertEqual(np.uint8, input_details[0]['dtype'])
1471    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1472
1473    output_details = interpreter.get_output_details()
1474    self.assertLen(output_details, 1)
1475    self.assertEqual('output', output_details[0]['name'])
1476    self.assertEqual(np.uint8, output_details[0]['dtype'])
1477    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1478
1479  def testReusingConverterWithDifferentPostTrainingQuantization(self):
1480    with ops.Graph().as_default():
1481      in_tensor = array_ops.placeholder(
1482          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1483      out_tensor = array_ops.fake_quant_with_min_max_args(
1484          in_tensor + in_tensor, min=0., max=1., name='output')
1485      sess = session.Session()
1486
1487    # Convert model and ensure model is not None.
1488    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1489                                                  [out_tensor])
1490
1491    converter.post_training_quantize = True
1492    tflite_model = converter.convert()
1493    self.assertIsNotNone(tflite_model)
1494
1495    converter.post_training_quantize = False
1496    tflite_model = converter.convert()
1497    self.assertIsNotNone(tflite_model)
1498
1499  def testResizeWithShape(self):
1500    with ops.Graph().as_default():
      # Construct a graph with a dynamically shaped input and an internal node
      # that relies on that input's shape.
1503      in_tensor = array_ops.placeholder(
1504          shape=[None, None], dtype=dtypes.float32)
1505      in_tensor2 = [[1, 2], [3, 4]]
1506      out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor))
1507      sess = session.Session()
1508
1509    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1510                                                  [out_tensor])
1511    tflite_model = converter.convert()
1512
1513    # Check values from converted model.
1514    interpreter = Interpreter(model_content=tflite_model)
1515    input_details = interpreter.get_input_details()
1516    self.assertLen(input_details, 1)
1517    self.assertAllEqual([1, 1], input_details[0]['shape'])
1518    self.assertAllEqual([-1, -1], input_details[0]['shape_signature'])
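    # Before resizing, unknown dimensions are reported as 1 in 'shape', while
    # 'shape_signature' preserves them as -1.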
1519
1520    # Resize tensor and invoke.
1521    interpreter.resize_tensor_input(0, [4])
1522    interpreter.allocate_tensors()
1523    interpreter.invoke()
1524
1525    # The output should be reshaped properly according to the resized input.
1526    output_details = interpreter.get_output_details()
1527    self.assertLen(output_details, 1)
1528    self.assertEqual(np.int32, output_details[0]['dtype'])
1529    self.assertAllEqual([4], output_details[0]['shape'])
1530    output_data = interpreter.get_tensor(output_details[0]['index'])
1531    self.assertAllEqual([1, 2, 3, 4], output_data)
1532
1533  def testResizingIntermediateDynamicTensor(self):
1534    # This is a regression test for the case where shape of dynamic output
1535    # tensors changes between invocations.
1536    # See also https://github.com/tensorflow/tensorflow/issues/26549
1537    with ops.Graph().as_default():
1538      input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32)
1539      input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32)
1540
      # The bug is triggered only when the dynamic tensor is intermediate, so
      # put some other ops around it.
1543      neg = math_ops.negative(input2_tensor)
1544      padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32)
1545      output_tensor = array_ops.pad(input_tensor, padding) + neg
1546
1547      sess = session.Session()
1548
1549    converter = lite.TFLiteConverter.from_session(
1550        sess, [input_tensor, padding, input2_tensor], [output_tensor])
1551    tflite_model = converter.convert()
1552
1553    interpreter = Interpreter(model_content=tflite_model)
1554    interpreter.allocate_tensors()
1555
1556    input_details = interpreter.get_input_details()
1557    interpreter.set_tensor(input_details[1]['index'],
1558                           np.array([[1, 1], [1, 1]], dtype=np.int32))
1559    interpreter.invoke()
1560
1561    # Without the fix, invocation will fail when changing the shape of
1562    # intermediate dynamic tensors.
1563    interpreter.set_tensor(input_details[1]['index'],
1564                           np.array([[2, 2], [2, 2]], dtype=np.int32))
1565    interpreter.invoke()
1566
1567  def testGraphDebugInfo(self):
1568    """Test a session has debug info captured."""
1569
1570    @def_function.function
1571    def plus_placeholder(x, placeholder):
1572      return x + placeholder
1573
1574    with ops.Graph().as_default():
1575      placeholder = array_ops.placeholder(
1576          dtype=dtypes.float32, shape=[1], name='input')
1577      variable_node = variables.Variable(1.0, name='variable_node')
1578      defun_node = plus_placeholder(variable_node, placeholder)
1579      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
1580
1581      # Initialize variables in the model.
1582      sess = session.Session()
1583      sess.run(variables.variables_initializer([variable_node]))
1584
1585    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
1586                                                  [output_node])
1587    converter.convert()
1588    self.assertValidDebugInfo(converter._debug_info)
1589
1590    # Check the add node in the inlined function is included.
1591    func = sess.graph.as_graph_def().library.function[0].signature.name
1592    self.assertIn(('add@' + func), converter._debug_info.traces)
1593
1594  def testOutputOnlyModel(self):
1595    with ops.Graph().as_default():
1596      out_tensor = random_ops.random_normal(shape=[3])
1597      sess = session.Session()
1598
1599    # Convert model and ensure model is not None.
1600    converter = lite.TFLiteConverter.from_session(sess, [], [out_tensor])
1601    converter.target_spec.supported_ops = [
1602        lite.OpsSet.TFLITE_BUILTINS,
1603        lite.OpsSet.SELECT_TF_OPS,
1604    ]
1605
1606    # Empty input array is a valid input.
1607    self.assertTrue(converter._has_valid_tensors())
1608
1609    tflite_model = converter.convert()
1610    self.assertIsNotNone(tflite_model)
1611
1612
1613class FromFrozenGraphFile(LiteTest):
1614
1615  def testFloat(self):
1616    with ops.Graph().as_default():
1617      in_tensor = array_ops.placeholder(
1618          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1619      _ = in_tensor + in_tensor
1620      sess = session.Session()
1621
1622    # Write graph to file.
1623    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1624    write_graph(sess.graph_def, '', graph_def_file, False)
1625    sess.close()
1626
1627    # Convert model and ensure model is not None.
1628    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
1629                                                       ['Placeholder'], ['add'])
1630    tflite_model = converter.convert()
1631    self.assertIsNotNone(tflite_model)
1632
1633    # Check values from converted model.
1634    interpreter = Interpreter(model_content=tflite_model)
1635    interpreter.allocate_tensors()
1636
1637    input_details = interpreter.get_input_details()
1638    self.assertLen(input_details, 1)
1639    self.assertEqual('Placeholder', input_details[0]['name'])
1640    self.assertEqual(np.float32, input_details[0]['dtype'])
1641    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1642    self.assertEqual((0., 0.), input_details[0]['quantization'])
1643
1644    output_details = interpreter.get_output_details()
1645    self.assertLen(output_details, 1)
1646    self.assertEqual('add', output_details[0]['name'])
1647    self.assertEqual(np.float32, output_details[0]['dtype'])
1648    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1649    self.assertEqual((0., 0.), output_details[0]['quantization'])
1650
1651  def testFloatWithShapesArray(self):
1652    """Test a shape overriding case."""
1653    with ops.Graph().as_default():
1654      in_tensor = array_ops.placeholder(
1655          shape=[None, 16, 16, 3], dtype=dtypes.float32)
1656      _ = in_tensor + in_tensor
1657      sess = session.Session()
1658
1659    # Write graph to file.
1660    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1661    write_graph(sess.graph_def, '', graph_def_file, False)
1662    sess.close()
1663
1664    # Convert model and ensure model is not None.
1665    converter = lite.TFLiteConverter.from_frozen_graph(
1666        graph_def_file, ['Placeholder'], ['add'],
1667        input_shapes={'Placeholder': [2, 16, 16, 3]})
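    # The unknown batch dimension in the placeholder is fixed to 2 here.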
1668    tflite_model = converter.convert()
1669    self.assertIsNotNone(tflite_model)
1670
1671    # Check values from converted model.
1672    interpreter = Interpreter(model_content=tflite_model)
1673    interpreter.allocate_tensors()
1674
1675    input_details = interpreter.get_input_details()
1676    self.assertLen(input_details, 1)
1677    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
1678
1679  def testInvalidShapesArray(self):
1680    """Test an invalid shape overriding case, which has a wrong input name."""
1681    with ops.Graph().as_default():
1682      in_tensor = array_ops.placeholder(
1683          shape=[None, 16, 16, 3], dtype=dtypes.float32)
1684      _ = in_tensor + in_tensor
1685      sess = session.Session()
1686
1687    # Write graph to file.
1688    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1689    write_graph(sess.graph_def, '', graph_def_file, False)
1690    sess.close()
1691
    # Converting with a wrong input name in input_shapes should raise.
1693    with self.assertRaises(ValueError):
1694      lite.TFLiteConverter.from_frozen_graph(
1695          graph_def_file, ['Placeholder'], ['add'],
1696          input_shapes={'wrong_input': [2, 16, 16, 3]})
1697
1698  def testPartialShapesArray(self):
1699    """Test a shape overriding case, with the only one input among two."""
1700    with ops.Graph().as_default():
1701      a = array_ops.placeholder(
1702          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='a')
1703      b = array_ops.placeholder(
1704          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='b')
1705      _ = math_ops.add(a, b, name='add')
1706      sess = session.Session()
1707
1708    # Write graph to file.
1709    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1710    write_graph(sess.graph_def, '', graph_def_file, False)
1711    sess.close()
1712
1713    # Convert model and ensure model is not None.
1714    converter = lite.TFLiteConverter.from_frozen_graph(
1715        graph_def_file, ['a', 'b'], ['add'], input_shapes={'a': [2, 16, 16, 3]})
1716    tflite_model = converter.convert()
1717    self.assertIsNotNone(tflite_model)
1718
1719    # Check values from converted model.
1720    interpreter = Interpreter(model_content=tflite_model)
1721    interpreter.allocate_tensors()
1722
1723    input_details = interpreter.get_input_details()
1724    self.assertLen(input_details, 2)
1725    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
1726    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
1727
1728  def testFreezeGraph(self):
1729    with ops.Graph().as_default():
1730      in_tensor = array_ops.placeholder(
1731          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1732      var = variable_scope.get_variable(
1733          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
1734      _ = in_tensor + var
1735      sess = session.Session()
1736
1737    # Write graph to file.
1738    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1739    write_graph(sess.graph_def, '', graph_def_file, False)
1740    sess.close()
1741
1742    # Ensure the graph with variables cannot be converted.
1743    with self.assertRaises(ValueError) as error:
1744      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
1745                                             ['add'])
1746    self.assertEqual('Please freeze the graph using freeze_graph.py.',
1747                     str(error.exception))
1748
1749  def testPbtxt(self):
1750    with ops.Graph().as_default():
1751      in_tensor = array_ops.placeholder(
1752          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1753      _ = in_tensor + in_tensor
1754      sess = session.Session()
1755
1756    # Write graph to file.
1757    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
1758    write_graph(sess.graph_def, '', graph_def_file, True)
1759    sess.close()
1760
1761    # Convert model and ensure model is not None.
1762    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
1763                                                       ['Placeholder'], ['add'])
1764    tflite_model = converter.convert()
1765    self.assertIsNotNone(tflite_model)
1766
1767    # Check values from converted model.
1768    interpreter = Interpreter(model_content=tflite_model)
1769    interpreter.allocate_tensors()
1770
1771    input_details = interpreter.get_input_details()
1772    self.assertLen(input_details, 1)
1773    self.assertEqual('Placeholder', input_details[0]['name'])
1774    self.assertEqual(np.float32, input_details[0]['dtype'])
1775    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1776    self.assertEqual((0., 0.), input_details[0]['quantization'])
1777
1778    output_details = interpreter.get_output_details()
1779    self.assertLen(output_details, 1)
1780    self.assertEqual('add', output_details[0]['name'])
1781    self.assertEqual(np.float32, output_details[0]['dtype'])
1782    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1783    self.assertEqual((0., 0.), output_details[0]['quantization'])
1784
1785  def testInvalidFileNotFound(self):
1786    with self.assertRaises(IOError) as error:
1787      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
1788                                             ['add'])
1789    self.assertEqual('File \'invalid_file\' does not exist.',
1790                     str(error.exception))
1791
1792  def testInvalidFileBadData(self):
1793    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
1794    with gfile.Open(graph_def_file, 'wb') as temp_file:
1795      temp_file.write('bad data')
1796      temp_file.flush()
1797
1798    # Attempts to convert the invalid model.
1799    with self.assertRaises(IOError) as error:
1800      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
1801                                             ['add'])
1802    self.assertEqual(
1803        'Unable to parse input file \'{}\'.'.format(graph_def_file),
1804        str(error.exception))
1805
1806  def testFloatTocoConverter(self):
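    """Tests the deprecated TocoConverter with a frozen graph file."""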
1807    with ops.Graph().as_default():
1808      in_tensor = array_ops.placeholder(
1809          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1810      _ = in_tensor + in_tensor
1811      sess = session.Session()
1812
1813    # Write graph to file.
1814    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1815    write_graph(sess.graph_def, '', graph_def_file, False)
1816    sess.close()
1817
1818    # Convert model and ensure model is not None.
1819    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
1820                                                     ['Placeholder'], ['add'])
1821    tflite_model = converter.convert()
1822    self.assertIsNotNone(tflite_model)
1823
1824    # Ensure the model is able to load.
1825    interpreter = Interpreter(model_content=tflite_model)
1826    interpreter.allocate_tensors()
1827
1828  def testGraphDebugInfo(self):
1829    """Test a frozen graph doesn't have debug info captured."""
1830    with ops.Graph().as_default():
1831      in_tensor = array_ops.placeholder(
1832          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1833      _ = in_tensor + in_tensor
1834      sess = session.Session()
1835
1836    # Write graph to file.
1837    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1838    write_graph(sess.graph_def, '', graph_def_file, False)
1839    sess.close()
1840
1841    # Convert model and ensure model is not None.
1842    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
1843                                                     ['Placeholder'], ['add'])
1844    converter.convert()
    # GraphDebugInfo should be None for a frozen graph.
1846    self.assertFalse(converter._debug_info)
1847
1848  def testExcludeConversionMetadata(self):
1849    with ops.Graph().as_default():
1850      in_tensor = array_ops.placeholder(
1851          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1852      _ = in_tensor + in_tensor
1853      sess = session.Session()
1854
1855    # Write graph to file.
1856    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1857    write_graph(sess.graph_def, '', graph_def_file, False)
1858    sess.close()
1859
1860    # Convert model and ensure model is not None.
1861    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
1862                                                       ['Placeholder'], ['add'])
1863    converter.exclude_conversion_metadata = True
1864    tflite_model = converter.convert()
1865    self.assertIsNotNone(tflite_model)
1866    # Check the conversion metadata.
1867    metadata = get_conversion_metadata(tflite_model)
1868    self.assertIsNone(metadata)
1869
1870
1871class FromFrozenGraphObjectDetection(LiteTest):
1872
1873  def _initObjectDetectionArgs(self):
    # Initializes the arguments required for the object detection model.
    # The model file is saved at different locations internally and
    # externally, so look for it in both places.
1877    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
1878    if not os.path.exists(filename):
1879      filename = os.path.join(
1880          resource_loader.get_root_dir_with_all_resources(),
1881          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
1882      if not os.path.exists(filename):
1883        raise IOError("File '{0}' does not exist.".format(filename))
1884
1885    self._graph_def_file = filename
1886    self._input_arrays = ['normalized_input_image_tensor']
1887    self._output_arrays = [
1888        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
1889        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
1890    ]
1891    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
1892
1893  def testTFLiteGraphDef(self):
1894    # Tests the object detection model that cannot be loaded in TensorFlow.
1895    self._initObjectDetectionArgs()
1896
1897    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
1898                                                       self._input_arrays,
1899                                                       self._output_arrays,
1900                                                       self._input_shapes)
1901    converter.allow_custom_ops = True
1902    tflite_model = converter.convert()
1903    self.assertIsNotNone(tflite_model)
1904
1905    # Check values from converted model.
1906    interpreter = Interpreter(model_content=tflite_model)
1907    interpreter.allocate_tensors()
1908
1909    input_details = interpreter.get_input_details()
1910    self.assertLen(input_details, 1)
1911    self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
1912    self.assertEqual(np.float32, input_details[0]['dtype'])
1913    self.assertAllEqual([1, 300, 300, 3], input_details[0]['shape'])
1914    self.assertEqual((0., 0.), input_details[0]['quantization'])
1915
1916    output_details = interpreter.get_output_details()
1917    self.assertLen(output_details, 4)
1918    self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
1919    self.assertEqual(np.float32, output_details[0]['dtype'])
1920    self.assertAllEqual([1, 10, 4], output_details[0]['shape'])
1921    self.assertEqual((0., 0.), output_details[0]['quantization'])
1922
1923    self.assertEqual('TFLite_Detection_PostProcess:1',
1924                     output_details[1]['name'])
1925    self.assertAllEqual([1, 10], output_details[1]['shape'])
1926    self.assertEqual('TFLite_Detection_PostProcess:2',
1927                     output_details[2]['name'])
1928    self.assertAllEqual([1, 10], output_details[2]['shape'])
1929    self.assertEqual('TFLite_Detection_PostProcess:3',
1930                     output_details[3]['name'])
1931    self.assertAllEqual([1], output_details[3]['shape'])
1932
1933  def testTFLiteGraphDefWithControlOutput(self):
1934    with ops.Graph().as_default():
1935      in_tensor = array_ops.placeholder(
1936          shape=[5, 5], dtype=dtypes.float32, name='input')
1937      out_tensor = in_tensor + in_tensor
1938      logging_ops.print_v2(out_tensor)
1939      sess = session.Session()
1940
1941    converter = lite.TFLiteConverter(
1942        sess.graph_def,
1943        input_tensors=None,
1944        output_tensors=None,
1945        input_arrays_with_shape=[('input', [5, 5])],
1946        output_arrays=None,
1947        experimental_debug_info_func=None)
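    # PrintV2 produces no data outputs; it is reachable only through a control
    # edge, so it is specified as a control output rather than in output_arrays.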
1948    converter._control_output_arrays = ['PrintV2']
1949    converter.target_spec.supported_ops = [
1950        lite.OpsSet.TFLITE_BUILTINS,
1951        lite.OpsSet.SELECT_TF_OPS,
1952    ]
1953    tflite_model = converter.convert()
1954    self.assertIsNotNone(tflite_model)
1955
1956    model = util._convert_model_from_bytearray_to_object(tflite_model)
1957    self.assertEqual(model.operatorCodes[0].builtinCode,
1958                     schema_fb.BuiltinOperator.ADD)
1959    self.assertEqual(model.operatorCodes[1].builtinCode,
1960                     schema_fb.BuiltinOperator.CUSTOM)
1961    self.assertEqual(model.operatorCodes[1].customCode, b'FlexStringFormat')
1962    self.assertEqual(model.operatorCodes[2].builtinCode,
1963                     schema_fb.BuiltinOperator.CUSTOM)
1964    self.assertEqual(model.operatorCodes[2].customCode, b'FlexPrintV2')
1965
1966    # Check values from converted model.
1967    interpreter = Interpreter(model_content=tflite_model)
1968    interpreter.allocate_tensors()
1969
1970    input_details = interpreter.get_input_details()
1971    self.assertLen(input_details, 1)
1972    self.assertEqual('input', input_details[0]['name'])
1973    self.assertEqual(np.float32, input_details[0]['dtype'])
1974    self.assertAllEqual([5, 5], input_details[0]['shape'])
1975    self.assertEqual((0., 0.), input_details[0]['quantization'])
1976
1977    output_details = interpreter.get_output_details()
1978    self.assertLen(output_details, 0)
1979
1980  def testModifyIOToUint8(self):
1981    # Tests the object detection model that cannot be loaded in TensorFlow.
1982    self._initObjectDetectionArgs()
1983
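    # Yields a couple of random samples of the model's input shape, used to
    # calibrate activation ranges for full-integer quantization.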
1984    def representative_dataset_gen():
1985      for _ in range(2):
1986        yield [
1987            np.random.uniform(low=0, high=1,
1988                              size=(1, 300, 300, 3)).astype(np.float32)
1989        ]
1990
1991    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
1992                                                       self._input_arrays,
1993                                                       self._output_arrays,
1994                                                       self._input_shapes)
1995    converter.representative_dataset = representative_dataset_gen
1996    converter.target_spec.supported_ops = {lite.OpsSet.TFLITE_BUILTINS_INT8}
1997    converter.inference_type = dtypes.int8
1998    converter.inference_input_type = dtypes.uint8
1999    converter.inference_output_type = dtypes.uint8
2000    converter.experimental_new_quantizer = True
2001    converter.quantized_input_stats = {
2002        'normalized_input_image_tensor': (0., 1.)
2003    }  # mean, std_dev
2004    converter.allow_custom_ops = True
2005    tflite_model = converter.convert()
2006
2007    self.assertIsNotNone(tflite_model)
2008
2009    model = util._convert_model_from_bytearray_to_object(tflite_model)
2010    quant_opcode_idxs = util.get_quantize_opcode_idx(model)
2011
2012    subgraph = model.subgraphs[0]
2013    tensors = subgraph.tensors
2014    operators = subgraph.operators
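    # Quantize ops that produce subgraph outputs should consume float32
    # tensors: the uint8 outputs are created by quantizing the float results.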
2015    for op in operators:
2016      if op.opcodeIndex in quant_opcode_idxs:
2017        input_type = util._convert_tflite_enum_type_to_tf_type(
2018            tensors[op.inputs[0]].type)
2019        if op.outputs[0] in subgraph.outputs:
2020          self.assertEqual(input_type, dtypes.float32)
2021
2022
2023class FromSavedModelTest(TestModels):
2024
2025  def _createSavedModel(self, shape):
2026    """Create a simple SavedModel."""
2027    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
2028    with ops.Graph().as_default():
2029      with session.Session() as sess:
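        # Note: in_tensor_1 is deliberately named 'inputB' and in_tensor_2
        # 'inputA', so name order differs from creation order.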
2030        in_tensor_1 = array_ops.placeholder(
2031            shape=shape, dtype=dtypes.float32, name='inputB')
2032        in_tensor_2 = array_ops.placeholder(
2033            shape=shape, dtype=dtypes.float32, name='inputA')
2034        out_tensor = in_tensor_1 + in_tensor_2
2035        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
2036        outputs = {'z': out_tensor}
2037        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
2038    return saved_model_dir
2039
2040  def testSimpleModel(self):
2041    """Test a SavedModel."""
2042    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2043
2044    # Convert model and ensure model is not None.
2045    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2046    tflite_model = converter.convert()
2047    self.assertIsNotNone(tflite_model)
2048
2049    interpreter = Interpreter(model_content=tflite_model)
2050    interpreter.allocate_tensors()
2051
2052    input_details = interpreter.get_input_details()
2053    self.assertLen(input_details, 2)
2054    self.assertStartsWith(input_details[0]['name'], 'inputA')
2055    self.assertEqual(np.float32, input_details[0]['dtype'])
2056    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2057    self.assertEqual((0., 0.), input_details[0]['quantization'])
2058
2059    self.assertStartsWith(input_details[1]['name'], 'inputB')
2060    self.assertEqual(np.float32, input_details[1]['dtype'])
2061    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2062    self.assertEqual((0., 0.), input_details[1]['quantization'])
2063
2064    output_details = interpreter.get_output_details()
2065    self.assertLen(output_details, 1)
2066    self.assertStartsWith(output_details[0]['name'], 'add')
2067    self.assertEqual(np.float32, output_details[0]['dtype'])
2068    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2069    self.assertEqual((0., 0.), output_details[0]['quantization'])
2070
2071  def testNoneBatchSize(self):
2072    """Test a SavedModel, with None in input tensor's shape."""
2073    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
2074
2075    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2076    tflite_model = converter.convert()
2077    self.assertIsNotNone(tflite_model)
2078
2079    # Check values from converted model.
2080    interpreter = Interpreter(model_content=tflite_model)
2081    interpreter.allocate_tensors()
2082
2083    input_details = interpreter.get_input_details()
2084    self.assertLen(input_details, 2)
2085    self.assertStartsWith(input_details[0]['name'], 'inputA')
2086    self.assertEqual(np.float32, input_details[0]['dtype'])
2087    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2088    self.assertEqual((0., 0.), input_details[0]['quantization'])
2089
2090    self.assertStartsWith(input_details[1]['name'], 'inputB')
2091    self.assertEqual(np.float32, input_details[1]['dtype'])
2092    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2093    self.assertEqual((0., 0.), input_details[1]['quantization'])
2094
2095    output_details = interpreter.get_output_details()
2096    self.assertLen(output_details, 1)
2097    self.assertStartsWith(output_details[0]['name'], 'add')
2098    self.assertEqual(np.float32, output_details[0]['dtype'])
2099    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2100    self.assertEqual((0., 0.), output_details[0]['quantization'])
2101
2102  def testOrderInputArrays(self):
2103    """Test a SavedModel ordering of input arrays."""
2104    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2105
2106    converter = lite.TFLiteConverter.from_saved_model(
2107        saved_model_dir, input_arrays=['inputB', 'inputA'])
2108    tflite_model = converter.convert()
2109    self.assertIsNotNone(tflite_model)
2110
2111    # Check values from converted model.
2112    interpreter = Interpreter(model_content=tflite_model)
2113    interpreter.allocate_tensors()
2114
2115    input_details = interpreter.get_input_details()
2116    self.assertLen(input_details, 2)
2117    self.assertStartsWith(input_details[0]['name'], 'inputA')
2118    self.assertEqual(np.float32, input_details[0]['dtype'])
2119    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2120    self.assertEqual((0., 0.), input_details[0]['quantization'])
2121
2122    self.assertStartsWith(input_details[1]['name'], 'inputB')
2123    self.assertEqual(np.float32, input_details[1]['dtype'])
2124    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2125    self.assertEqual((0., 0.), input_details[1]['quantization'])
2126
2127    output_details = interpreter.get_output_details()
2128    self.assertLen(output_details, 1)
2129    self.assertStartsWith(output_details[0]['name'], 'add')
2130    self.assertEqual(np.float32, output_details[0]['dtype'])
2131    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2132    self.assertEqual((0., 0.), output_details[0]['quantization'])
2133
2134  def testShapeOverriding(self):
2135    """Test a SavedModel with the input_shapes arugment."""
2136    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
2137
2138    # Convert model and ensure model is not None.
2139    converter = lite.TFLiteConverter.from_saved_model(
2140        saved_model_dir,
2141        input_shapes={
2142            'inputA': [2, 16, 16, 3],
2143            'inputB': [2, 16, 16, 3]
2144        })
2145    tflite_model = converter.convert()
2146    self.assertIsNotNone(tflite_model)
2147
2148    interpreter = Interpreter(model_content=tflite_model)
2149    interpreter.allocate_tensors()
2150
2151    input_details = interpreter.get_input_details()
2152    self.assertLen(input_details, 2)
2153    self.assertStartsWith(input_details[0]['name'], 'inputA')
2154    self.assertEqual(np.float32, input_details[0]['dtype'])
2155    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
2156    self.assertEqual((0., 0.), input_details[0]['quantization'])
2157
2158    self.assertStartsWith(input_details[1]['name'], 'inputB')
2159    self.assertEqual(np.float32, input_details[1]['dtype'])
2160    self.assertAllEqual([2, 16, 16, 3], input_details[1]['shape'])
2161    self.assertEqual((0., 0.), input_details[1]['quantization'])
2162
2163    output_details = interpreter.get_output_details()
2164    self.assertLen(output_details, 1)
2165    self.assertStartsWith(output_details[0]['name'], 'add')
2166    self.assertEqual(np.float32, output_details[0]['dtype'])
2167    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
2168    self.assertEqual((0., 0.), output_details[0]['quantization'])
2169
2170  def testWrongInputShapes(self):
2171    """Test a SavedModel with a wrong name in the input_shapes argument."""
2172    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2173
2174    # Check case where input shape is given.
2175    with self.assertRaises(ValueError):
2176      lite.TFLiteConverter.from_saved_model(
2177          saved_model_dir,
2178          input_arrays=['inputA'],
2179          input_shapes={'wrong_input': [1, 16, 16, 3]})
2180
  def testSubsetInputShapes(self):
2182    """Test a SavedModel with a subset of the input array names of the model."""
2183    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2184
2185    # Check case where input shape is given.
2186    converter = lite.TFLiteConverter.from_saved_model(
2187        saved_model_dir,
2188        input_arrays=['inputA'],
2189        input_shapes={'inputA': [1, 16, 16, 3]})
2190
2191    # Since we only partially specify the input, this is not allowed.
2192    with self.assertRaises(ConverterError):
2193      _ = converter.convert()
2194
2195    # Check case where input shape is None.
2196    converter = lite.TFLiteConverter.from_saved_model(
2197        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
2198
2199    # Since we only partially specify the input, this is not allowed.
2200    with self.assertRaises(ConverterError):
2201      _ = converter.convert()
2202
2203  def testSimpleModelTocoConverter(self):
2204    """Test a SavedModel with deprecated TocoConverter."""
2205    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2206
2207    # Convert model and ensure model is not None.
2208    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
2209    tflite_model = converter.convert()
2210    self.assertIsNotNone(tflite_model)
2211
2212    # Ensure the model is able to load.
2213    interpreter = Interpreter(model_content=tflite_model)
2214    interpreter.allocate_tensors()
2215
2216  def testGraphDebugInfo(self):
2217    """Test a SavedModel has debug info captured."""
2218    self.skipTest(
2219        'b/221093690: The debug info is not from self._createSavedModel(), '
2220        'but from saved_model.loader_impl().')
2221    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2222    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2223    converter.convert()
2224    self.assertValidDebugInfo(converter._debug_info)
2225
2226
2227class MyAddLayer(keras.layers.Layer):
2228
2229  def __init__(self, increment, **kwargs):
2230    super(MyAddLayer, self).__init__(**kwargs)
2231    self._increment = increment
2232
2233  def call(self, inputs):
2234    return inputs + self._increment
2235
2236  def get_config(self):
2237    config = super(MyAddLayer, self).get_config()
2238    config['increment'] = self._increment
2239    return config
2240
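# Illustrative usage of MyAddLayer (a sketch only; not executed by the tests):
#
#   layer = MyAddLayer(increment=1.0)
#   layer(np.array([1., 2.], dtype=np.float32))  # -> [2., 3.]
#   layer.get_config()  # includes 'increment', making the layer serializable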
2241
2242class FromKerasFile(TestModels, parameterized.TestCase):
2243
2244  def setUp(self):
2245    super(FromKerasFile, self).setUp()
2246    self._keras_file = None
2247    self._custom_objects = None
2248    if not context.executing_eagerly():
2249      keras.backend.clear_session()
2250
2251  def tearDown(self):
2252    if self._keras_file:
2253      os.remove(self._keras_file)
2254    super(FromKerasFile, self).tearDown()
2255
2256  def _getSequentialModel(self, include_custom_layer=False):
2257    model = keras.models.Sequential()
2258    model.add(keras.layers.Dense(2, input_shape=(3,)))
2259    if include_custom_layer:
2260      model.add(MyAddLayer(1.0))
2261    model.add(keras.layers.RepeatVector(3))
2262    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
2263    model.compile(
2264        loss=keras.losses.MSE,
2265        optimizer='sgd',
2266        metrics=[keras.metrics.categorical_accuracy],
2267        sample_weight_mode='temporal')
2268    x = np.random.random((1, 3))
2269    y = np.random.random((1, 3, 3))
2270    model.train_on_batch(x, y)
2271    model.predict(x)
2272
    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)
2278
2279    if include_custom_layer:
2280      self._custom_objects = {'MyAddLayer': MyAddLayer}
2281
2282  @parameterized.named_parameters(('_graph', context.graph_mode),
2283                                  ('_eager', context.eager_mode))
2284  def testSequentialModel(self, test_context):
2285    """Test a Sequential tf.keras model with default inputs."""
2286    with test_context():
2287      self._getSequentialModel()
2288
2289      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2290      tflite_model = converter.convert()
2291      self.assertIsNotNone(tflite_model)
2292
2293    # Check tensor details of converted model.
2294    interpreter = Interpreter(model_content=tflite_model)
2295    interpreter.allocate_tensors()
2296
2297    input_details = interpreter.get_input_details()
2298    self.assertLen(input_details, 1)
2299    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2300    self.assertEqual(np.float32, input_details[0]['dtype'])
2301    self.assertAllEqual([1, 3], input_details[0]['shape'])
2302    self.assertEqual((0., 0.), input_details[0]['quantization'])
2303
2304    output_details = interpreter.get_output_details()
2305    self.assertLen(output_details, 1)
2306    self.assertEqual(np.float32, output_details[0]['dtype'])
2307    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
2308    self.assertEqual((0., 0.), output_details[0]['quantization'])
2309
2310    # Check inference of converted model.
2311    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2312    interpreter.set_tensor(input_details[0]['index'], input_data)
2313    interpreter.invoke()
2314    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2315
2316    keras_model = keras.models.load_model(self._keras_file)
2317    keras_result = keras_model.predict(input_data)
2318
2319    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2320
2321  @parameterized.named_parameters(('_graph', context.graph_mode),
2322                                  ('_eager', context.eager_mode))
2323  def testCustomLayer(self, test_context):
2324    """Test a Sequential tf.keras model with default inputs."""
2325    with test_context():
2326      self._getSequentialModel(include_custom_layer=True)
2327
2328      converter = lite.TFLiteConverter.from_keras_model_file(
2329          self._keras_file, custom_objects=self._custom_objects)
2330      tflite_model = converter.convert()
2331      self.assertIsNotNone(tflite_model)
2332
2333    # Check tensor details of converted model.
2334    interpreter = Interpreter(model_content=tflite_model)
2335    interpreter.allocate_tensors()
2336
2337    input_details = interpreter.get_input_details()
2338    output_details = interpreter.get_output_details()
2339
2340    # Check inference of converted model.
2341    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2342    interpreter.set_tensor(input_details[0]['index'], input_data)
2343    interpreter.invoke()
2344    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2345
2346    keras_model = keras.models.load_model(
2347        self._keras_file, custom_objects=self._custom_objects)
2348    keras_result = keras_model.predict(input_data)
2349
2350    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2351
2352  def testSequentialModelInputArray(self):
2353    """Test a Sequential tf.keras model testing input arrays argument."""
2354    ops.disable_eager_execution()
2355    self._getSequentialModel()
2356
2357    # Invalid input array raises error.
2358    with self.assertRaises(ValueError) as error:
2359      lite.TFLiteConverter.from_keras_model_file(
2360          self._keras_file, input_arrays=['invalid-input'])
2361    self.assertEqual("Invalid tensors 'invalid-input' were found.",
2362                     str(error.exception))
2363
2364    # Valid input array.
2365    converter = lite.TFLiteConverter.from_keras_model_file(
2366        self._keras_file, input_arrays=['dense_input'])
2367    tflite_model = converter.convert()
2368    self.assertIsNotNone(tflite_model)
2369
2370  def testSequentialModelInputShape(self):
2371    """Test a Sequential tf.keras model testing input shapes argument."""
2372    self._getSequentialModel()
2373
2374    # Passing in shape of invalid input array raises error.
2375    with self.assertRaises(ValueError) as error:
2376      converter = lite.TFLiteConverter.from_keras_model_file(
2377          self._keras_file, input_shapes={'invalid-input': [2, 3]})
2378    self.assertEqual(
2379        "Invalid tensor 'invalid-input' found in tensor shapes map.",
2380        str(error.exception))
2381
2382    # Passing in shape of valid input array.
2383    converter = lite.TFLiteConverter.from_keras_model_file(
2384        self._keras_file, input_shapes={'dense_input': [2, 3]})
2385    tflite_model = converter.convert()
2386    self.assertIsNotNone(tflite_model)
2387
2388    # Check input shape from converted model.
2389    interpreter = Interpreter(model_content=tflite_model)
2390    interpreter.allocate_tensors()
2391
2392    input_details = interpreter.get_input_details()
2393    self.assertLen(input_details, 1)
2394    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2395    self.assertAllEqual([2, 3], input_details[0]['shape'])
2396
2397  def testSequentialModelOutputArray(self):
2398    """Test a Sequential tf.keras model testing output arrays argument."""
2399    ops.disable_eager_execution()
2400    self._getSequentialModel()
2401
2402    # Invalid output array raises error.
2403    with self.assertRaises(ValueError) as error:
2404      lite.TFLiteConverter.from_keras_model_file(
2405          self._keras_file, output_arrays=['invalid-output'])
2406    self.assertEqual("Invalid tensors 'invalid-output' were found.",
2407                     str(error.exception))
2408
2409    # Valid output array.
2410    converter = lite.TFLiteConverter.from_keras_model_file(
2411        self._keras_file, output_arrays=['time_distributed/Reshape_1'])
2412    tflite_model = converter.convert()
2413    self.assertIsNotNone(tflite_model)
2414
2415  @parameterized.named_parameters(('_graph', context.graph_mode),
2416                                  ('_eager', context.eager_mode))
2417  def testFunctionalModel(self, test_context):
2418    """Test a Functional tf.keras model with default inputs."""
2419    with test_context():
2420      inputs = keras.layers.Input(shape=(3,), name='input')
2421      x = keras.layers.Dense(2)(inputs)
2422      output = keras.layers.Dense(3)(x)
2423
2424      model = keras.models.Model(inputs, output)
2425      model.compile(
2426          loss=keras.losses.MSE,
2427          optimizer='sgd',
2428          metrics=[keras.metrics.categorical_accuracy])
2429      x = np.random.random((1, 3))
2430      y = np.random.random((1, 3))
2431      model.train_on_batch(x, y)
2432
2433      model.predict(x)
2434      fd, self._keras_file = tempfile.mkstemp('.h5')
2435      try:
2436        keras.models.save_model(model, self._keras_file)
2437      finally:
2438        os.close(fd)
2439
2440      # Convert to TFLite model.
2441      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2442      tflite_model = converter.convert()
2443      self.assertIsNotNone(tflite_model)
2444
2445    # Check tensor details of converted model.
2446    interpreter = Interpreter(model_content=tflite_model)
2447    interpreter.allocate_tensors()
2448
2449    input_details = interpreter.get_input_details()
2450    self.assertLen(input_details, 1)
2451    self.assertEqual('input', input_details[0]['name'])
2452    self.assertEqual(np.float32, input_details[0]['dtype'])
2453    self.assertAllEqual([1, 3], input_details[0]['shape'])
2454    self.assertEqual((0., 0.), input_details[0]['quantization'])
2455
2456    output_details = interpreter.get_output_details()
2457    self.assertLen(output_details, 1)
2458    self.assertEqual(np.float32, output_details[0]['dtype'])
2459    self.assertAllEqual([1, 3], output_details[0]['shape'])
2460    self.assertEqual((0., 0.), output_details[0]['quantization'])
2461
2462    # Check inference of converted model.
2463    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2464    interpreter.set_tensor(input_details[0]['index'], input_data)
2465    interpreter.invoke()
2466    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2467
2468    keras_model = keras.models.load_model(self._keras_file)
2469    keras_result = keras_model.predict(input_data)
2470
2471    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2472
2473  def _getFunctionalModelMultipleInputs(self):
    a = keras.layers.Input(shape=(3,), name='input_a')
    b = keras.layers.Input(shape=(3,), name='input_b')
    dense = keras.layers.Dense(4, name='dense')
    c = dense(a)
    d = dense(b)
    e = keras.layers.Dropout(0.5, name='dropout')(c)

    model = keras.models.Model([a, b], [d, e])
    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.mae],
        loss_weights=[1., 0.5])

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))
    output_d_np = np.random.random((10, 4))
    output_e_np = np.random.random((10, 4))
    model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])

    model.predict([input_a_np, input_b_np], batch_size=5)
    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

  def testFunctionalModelMultipleInputs(self):
    """Test a Functional tf.keras model with multiple inputs and outputs."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([1, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testShapeOverriding(self):
    """Test a Functional tf.keras model with input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={
            'input_a': [2, 3],
            'input_b': [2, 3]
        })
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([2, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([2, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testPartialShapeOverriding(self):
    """Test a Functional tf.keras model with partial input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model.
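    # Only 'input_a' is overridden; 'input_b' keeps its default batch
    # size of 1.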
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={'input_a': [2, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([2, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testWrongShapeOverriding(self):
    """Test shape overriding with an invalid input tensor name."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_shapes={'wrong_input': [2, 3]})

  def testFunctionalSequentialModel(self):
    """Test a Functional tf.keras model containing a Sequential model."""
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    model.add(keras.layers.RepeatVector(3))
    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
    model = keras.models.Model(model.input, model.output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.categorical_accuracy],
        sample_weight_mode='temporal')
    x = np.random.random((1, 3))
    y = np.random.random((1, 3, 3))
    model.train_on_batch(x, y)

    model.predict(x)
    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(self._keras_file)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def testSequentialModelTocoConverter(self):
    """Test a Sequential tf.keras model with deprecated TocoConverter."""
    self._getSequentialModel()

    converter = lite.TocoConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testGraphDebugInfo(self, test_context):
    """Test that a Sequential tf.keras model has debug info captured."""
    with test_context():
      self._getSequentialModel()
      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
      converter.convert()
      self.assertValidDebugInfo(converter._debug_info)


class SparsityTest(TestModels):

  def _getSparsificableModel(self, matrix_b_values):
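    """Returns (session, input tensors, output tensors) for a matmul whose
    constant 4x8 weight matrix is built from matrix_b_values."""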
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[16, 4], dtype=dtypes.float32, name='input1')
      in_tensor_2 = constant_op.constant(
          matrix_b_values, shape=[4, 8], dtype=dtypes.float32)
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
      sess = session.Session()

    return (sess, [in_tensor_1], [out_tensor])

  def testRandomSparsity(self):
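    # Non-zero entries are scattered without block structure, so the
    # converter should report RANDOM_SPARSITY in the conversion metadata.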
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1
    ]
    sess, inputs, outputs = self._getSparsificableModel(matrix_b_values)
    float_converter = lite.TFLiteConverter.from_session(sess, inputs, outputs)
    float_converter.optimizations = [lite.Optimize.EXPERIMENTAL_SPARSITY]
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(float_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY],
                        metadata.options.modelOptimizationModes)

  def testSparsifyModel(self):
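    # A single non-zero entry allows a block-sparse encoding; the converter
    # should report BLOCK_SPARSITY in the conversion metadata.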
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0
    ]
    sess, inputs, outputs = self._getSparsificableModel(matrix_b_values)
    converter = lite.TFLiteConverter.from_session(sess, inputs, outputs)
    converter.optimizations = {lite.Optimize.EXPERIMENTAL_SPARSITY}
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([
        metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY,
    ], metadata.options.modelOptimizationModes)

  def testSparsifyQuantizedModel(self):
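    # Optimize.DEFAULT adds dynamic-range quantization on top of sparsity,
    # so both PTQ_DYNAMIC_RANGE and BLOCK_SPARSITY should be recorded.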
    matrix_b_values = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0
    ]
    sess, inputs, outputs = self._getSparsificableModel(matrix_b_values)
    converter = lite.TFLiteConverter.from_session(sess, inputs, outputs)
    converter.optimizations = {
        lite.Optimize.DEFAULT, lite.Optimize.EXPERIMENTAL_SPARSITY
    }
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check the conversion metadata.
    metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual([
        metadata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE,
        metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY,
    ], metadata.options.modelOptimizationModes)


class GrapplerTest(TestModels, parameterized.TestCase):

  def testConstantFolding(self):
    ops.disable_eager_execution()
    # Constant folding handles the tf.broadcast_to operation, which was not
    # supported by TFLite at the time this test was added.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
      y_const = constant_op.constant([1., 2., 3.])
      y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
      out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([3, 3], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([3, 3], output_details[0]['shape'])

  def testInputNodeIsNotFolded(self):
    ops.disable_eager_execution()
    # A constant that is listed as a model input must survive constant
    # folding and remain a separate input tensor.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      y_const = constant_op.constant([1., 2., 3.])
      y_add = y_const + y_const
      out_tensor = in_tensor * y_add
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor, y_const],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual('Const', input_details[1]['name'])

  def testGrapplerConstFolding(self):
    # Constant folding converts the following add operation into a
    # tf.broadcast_to operation, which was not supported by TFLite at the
    # time this test was added.
    @def_function.function
    def plus_placeholder(x, placeholder):
      return x + placeholder

    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
      out_tensor = plus_placeholder(
          array_ops.zeros([2, 2, 2]),
          array_ops.reshape(in_tensor, shape=[2, 2]))
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])


class DefaultConverterAttrsTest(LiteTest):

  def testAttrs(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])

    # Assert output format.
    self.assertEqual(converter.output_format, lite_constants.TFLITE)

    # Assert the default inference type is float.
    self.assertEqual(converter.inference_type, dtypes.float32)

    # Assert the default inference type overrides are None.
    self.assertIsNone(converter.inference_input_type)
    self.assertIsNone(converter.inference_output_type)

    # Assert the default quantization options are not set.
    self.assertEqual(converter.quantized_input_stats, {})
    self.assertIsNone(converter.default_ranges_stats)
    self.assertFalse(converter.reorder_across_fake_quant)
    self.assertFalse(converter.change_concat_input_ranges)

    # Assert dropping control dependencies is enabled by default.
    self.assertIsNotNone(converter.drop_control_dependency)

    # Assert dumping extra information is disabled by default.
    self.assertIsNone(converter.dump_graphviz_dir)
    self.assertFalse(converter.dump_graphviz_video)
    self.assertIsNone(converter.conversion_summary_dir)


class ControlFlowV1OpsTest(LiteTest):

  def testConverterErrorOnControlFlowV1Ops(self):
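    # The test graph uses Control Flow V1 ops (note the 'Merge' output
    # below), which the converter cannot functionalize.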
    graph_def_file = resource_loader.get_path_to_datafile(
        'testdata/control_flow_v1.pbtxt')
    input_arrays = ['a', 'b', 'c', 'd']
    output_arrays = ['Merge']

    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
                                                       input_arrays,
                                                       output_arrays)
    with self.assertRaises(ConverterError) as error:
      converter.convert()
    self.assertIn(
        'Failed to functionalize Control Flow V1 ops. Consider using Control '
        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))


class QuantizationModeTest(LiteTest, parameterized.TestCase):

  @parameterized.named_parameters(
      ('size', lite.Optimize.OPTIMIZE_FOR_SIZE),
      ('latency', lite.Optimize.OPTIMIZE_FOR_LATENCY))
  def testDeprecatedOptionWarning(self, optimization):
    """Test that a warning is logged for deprecated Optimize options."""
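    # Temporarily attach a handler to the root logger to capture the
    # deprecation warning emitted while constructing QuantizationMode.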
    log = io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    warning_message = 'please use optimizations=[Optimize.DEFAULT] instead.'
    lite.QuantizationMode([optimization], lite.TargetSpec(), None, None)
    self.assertIn(warning_message, log.getvalue())
    logging.root.removeHandler(handler)


if __name__ == '__main__':
  test.main()