xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/optimize/debugging/python/debugger_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for QuantizationDebugger."""
16
17import csv
18import io
19import re
20from unittest import mock
21
22from absl.testing import parameterized
23import numpy as np
24import tensorflow as tf
25
26from tensorflow.lite.python import convert
27from tensorflow.lite.python import lite
28from tensorflow.lite.python.metrics import metrics
29from tensorflow.lite.tools.optimize.debugging.python import debugger
30from tensorflow.python.framework import test_util
31from tensorflow.python.platform import test
32from tensorflow.python.trackable import autotrackable
33
34
def _get_model():
  """Returns a simple Conv2D + ReLU model and its traced concrete function.

  Returns:
    A `(root, concrete_function)` tuple. `root` is an `AutoTrackable` whose
    attribute `f` is a `tf.function` accepting a float32 tensor of shape
    [1, 3, 3, 1]; `concrete_function` is `root.f` traced for that signature.
  """
  root = autotrackable.AutoTrackable()
  # Fixed 2x2 kernel reshaped to the [height, width, in_channels,
  # out_channels] filter layout tf.nn.conv2d expects: one channel in and out.
  kernel_in = np.array([-2, -1, 1, 2], dtype=np.float32).reshape((2, 2, 1, 1))

  @tf.function(
      input_signature=[tf.TensorSpec(shape=[1, 3, 3, 1], dtype=tf.float32)])
  def func(inp):
    kernel = tf.constant(kernel_in, dtype=tf.float32)
    # Stride 1 with 'SAME' padding keeps the 3x3 spatial size of the input.
    conv = tf.nn.conv2d(inp, kernel, strides=1, padding='SAME')
    output = tf.nn.relu(conv, name='output')
    return output

  root.f = func
  to_save = root.f.get_concrete_function()
  return (root, to_save)
51
52
def _calibration_gen():
  """Yields five single-input samples for calibration and debugging.

  Each sample is a one-element list holding a float32 array of shape
  (1, 3, 3, 1) equal to `arange(9)` multiplied by the sample index 0..4.
  """
  ramp = np.arange(9).reshape((1, 3, 3, 1)).astype(np.float32)
  for multiplier in range(5):
    # `ramp * multiplier` allocates a new array, so samples never alias.
    yield [ramp * multiplier]
56
57
def _convert_model(model, func):
  """Converts the concrete function `func` of `model` to a float TFLite model.

  Args:
    model: The trackable object that owns `func`.
    func: The concrete function to convert.

  Returns:
    Whatever `TFLiteConverterV2.convert()` produces for the float model.
  """
  float_converter = lite.TFLiteConverterV2.from_concrete_functions(
      [func], model)
  # TODO(b/191205988): Explicitly disable saved model lowering in conversion.
  float_converter.experimental_lower_to_saved_model = False
  tflite_float_model = float_converter.convert()
  return tflite_float_model
64
65
def _quantize_converter(model, func, calibration_gen, debug=True):
  """Builds an int8-quantizing TFLite converter for `func`.

  Args:
    model: The trackable object that owns `func`.
    func: The concrete function to convert.
    calibration_gen: Representative dataset generator used for calibration.
    debug: When true, the converter is put in calibrate-only mode so its
      output can be post-processed into a debug model.

  Returns:
    The configured (not yet run) `TFLiteConverterV2`.
  """
  quant_converter = lite.TFLiteConverterV2.from_concrete_functions(
      [func], model)
  quant_converter.target_spec.supported_ops = [
      lite.OpsSet.TFLITE_BUILTINS_INT8
  ]
  quant_converter.representative_dataset = calibration_gen

  # TODO(b/191205988): Explicitly disable saved model lowering in conversion.
  quant_converter.experimental_lower_to_saved_model = False

  # Default optimizations with the new (MLIR) quantizer.
  quant_converter.optimizations = [lite.Optimize.DEFAULT]
  quant_converter.experimental_new_quantizer = True
  if debug:
    # Stop after calibration; quantization is applied separately.
    quant_converter._experimental_calibrate_only = True
  return quant_converter
81
82
def _quantize_model(model,
                    func,
                    calibration_gen,
                    quantized_io=False,
                    debug=True):
  """Quantizes `func` either as a debug model or as a normal quantized model.

  Args:
    model: The trackable object that owns `func`.
    func: The concrete function to convert.
    calibration_gen: Representative dataset generator used for calibration.
    quantized_io: Passed as `fully_quantize` when building the debug model.
    debug: When true, the calibrated model is quantized with numeric
      verification enabled; otherwise the converter output is returned as is.

  Returns:
    The converted model content.
  """
  quant_converter = _quantize_converter(model, func, calibration_gen, debug)
  if not debug:
    return quant_converter.convert()

  calibrated_model = quant_converter.convert()
  return convert.mlir_quantize(
      calibrated_model,
      enable_numeric_verify=True,
      fully_quantize=quantized_io)
96
97
def _dummy_fn(*unused_args):
  """Placeholder metric: ignores all positional arguments and returns 0.0."""
  del unused_args  # Accepted only so any metric arity can be satisfied.
  result = 0.0
  return result
100
101
class QuantizationDebugOptionsTest(test_util.TensorFlowTestCase,
                                   parameterized.TestCase):
  """Tests argument validation in debugger.QuantizationDebugOptions."""

  @test_util.run_v2_only
  def test_init_duplicate_keys_raises_ValueError(self):
    """A metric name shared by two metric dicts must be rejected."""
    # Key 'a' appears in both the layer and the direct-compare metrics; the
    # error is expected whether or not model-level metrics are also given.
    with self.assertRaises(ValueError):
      debugger.QuantizationDebugOptions(
          layer_debug_metrics=dict(a=_dummy_fn, b=_dummy_fn),
          model_debug_metrics=dict(c=_dummy_fn, d=_dummy_fn),
          layer_direct_compare_metrics=dict(a=_dummy_fn, e=_dummy_fn))

    with self.assertRaises(ValueError):
      debugger.QuantizationDebugOptions(
          layer_debug_metrics=dict(a=_dummy_fn, b=_dummy_fn),
          layer_direct_compare_metrics=dict(a=_dummy_fn, e=_dummy_fn))
132
133
class QuantizationDebuggerTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Tests debugger.QuantizationDebugger on a single Conv2D + ReLU model.

  The float model and two debug models (float I/O and fully quantized I/O)
  are built once in `setUpClass` and shared by the test methods.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    cls.tf_model_root, cls.tf_model = _get_model()
    cls.float_model = _convert_model(cls.tf_model_root, cls.tf_model)
    # Debug models (numeric verification enabled) with float and with
    # quantized model inputs/outputs.
    cls.debug_model_float = _quantize_model(
        cls.tf_model_root, cls.tf_model, _calibration_gen, quantized_io=False)
    cls.debug_model_int8 = _quantize_model(
        cls.tf_model_root, cls.tf_model, _calibration_gen, quantized_io=True)

  @parameterized.named_parameters(
      ('float_io', False, False),
      ('quantized_io', True, False),
      ('float_io_from_converter', False, True),
      ('quantized_io_from_converter', True, True),
  )
  @test_util.run_v2_only
  def test_layer_metrics(self, quantized_io, from_converter):
    """Checks per-layer statistics and their CSV dump.

    Args:
      quantized_io: Whether the debug model is fully quantized, I/O included.
      from_converter: Whether the debugger builds the debug model from a
        converter instead of receiving prebuilt debug model content.
    """
    options = debugger.QuantizationDebugOptions(
        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))})
    if not from_converter:
      if quantized_io:
        debug_model = QuantizationDebuggerTest.debug_model_int8
      else:
        debug_model = QuantizationDebuggerTest.debug_model_float
      quant_debugger = debugger.QuantizationDebugger(
          quant_debug_model_content=debug_model,
          debug_dataset=_calibration_gen,
          debug_options=options)
    else:
      options.fully_quantize = quantized_io
      quant_debugger = debugger.QuantizationDebugger(
          converter=_quantize_converter(self.tf_model_root, self.tf_model,
                                        _calibration_gen),
          debug_dataset=_calibration_gen,
          debug_options=options)

    quant_debugger.run()

    expected_quant_io_metrics = {
        'num_elements': 9,
        'stddev': 0.03850026,
        'mean_error': 0.01673192,
        'max_abs_error': 0.10039272,
        'mean_squared_error': 0.0027558778,
        'l1_norm': 0.023704167,
    }
    expected_float_io_metrics = {
        'num_elements': 9,
        'stddev': 0.050998904,
        'mean_error': 0.007843441,
        'max_abs_error': 0.105881885,
        'mean_squared_error': 0.004357292,
        'l1_norm': 0.035729896,
    }
    expected_metrics = (
        expected_quant_io_metrics
        if quantized_io else expected_float_io_metrics)
    self.assertLen(quant_debugger.layer_statistics, 1)
    actual_metrics = next(iter(quant_debugger.layer_statistics.values()))

    self.assertCountEqual(expected_metrics.keys(), actual_metrics.keys())
    for key, value in expected_metrics.items():
      self.assertAlmostEqual(value, actual_metrics[key], places=5)

    buffer = io.StringIO()
    quant_debugger.layer_statistics_dump(buffer)
    # Read the dump as a CSV stream. Splitting on arbitrary whitespace would
    # tear apart any cell containing a space (e.g. a multi-element array).
    reader = csv.DictReader(io.StringIO(buffer.getvalue()))
    actual_values = next(iter(reader))

    expected_values = expected_metrics.copy()
    expected_values.update({
        'op_name': 'CONV_2D',
        'tensor_idx': 7,
        'scale': 0.15686275,
        'zero_point': -128,
        'tensor_name': r'Identity[1-9]?$'
    })
    for key, value in expected_values.items():
      if isinstance(value, str):
        # String expectations are regexes matched at the start of the cell.
        self.assertIsNotNone(
            re.match(value, actual_values[key]),
            'String is different from expected string. Please fix test code if'
            " it's being affected by graph manipulation changes.")
      elif isinstance(value, list):
        # One-element list expectations are compared against a bracketed
        # cell such as '[0.5]'.
        self.assertAlmostEqual(
            value[0], float(actual_values[key][1:-1]), places=5)
      else:
        self.assertAlmostEqual(value, float(actual_values[key]), places=5)

  @parameterized.named_parameters(
      ('float_io', False),
      ('quantized_io', True),
  )
  @test_util.run_v2_only
  def test_model_metrics(self, quantized_io):
    """Checks a user-defined model-level metric computed with the float model.

    Args:
      quantized_io: Whether to use the debug model with quantized I/O.
    """
    if quantized_io:
      debug_model = QuantizationDebuggerTest.debug_model_int8
    else:
      debug_model = QuantizationDebuggerTest.debug_model_float
    options = debugger.QuantizationDebugOptions(
        model_debug_metrics={'stdev': lambda x, y: np.std(x[0] - y[0])})
    quant_debugger = debugger.QuantizationDebugger(
        quant_debug_model_content=debug_model,
        float_model_content=QuantizationDebuggerTest.float_model,
        debug_dataset=_calibration_gen,
        debug_options=options)
    quant_debugger.run()

    expected_metrics = {'stdev': 0.050998904}
    actual_metrics = quant_debugger.model_statistics

    self.assertCountEqual(expected_metrics.keys(), actual_metrics.keys())
    for key, value in expected_metrics.items():
      self.assertAlmostEqual(value, actual_metrics[key], places=5)

  @parameterized.named_parameters(
      ('float_io', False),
      ('quantized_io', True),
  )
  @test_util.run_v2_only
  def test_layer_direct_compare_metrics(self, quantized_io):
    """Checks a direct-compare metric that receives quantization parameters.

    Args:
      quantized_io: Whether to use the debug model with quantized I/O.
    """

    def _corr(float_values, quant_values, scale, zero_point):
      # Dequantize, then take the Pearson correlation with the float values.
      dequant_values = (quant_values.astype(np.int32) - zero_point) * scale
      return np.corrcoef(float_values.flatten(), dequant_values.flatten())[0, 1]

    if quantized_io:
      debug_model = QuantizationDebuggerTest.debug_model_int8
    else:
      debug_model = QuantizationDebuggerTest.debug_model_float

    options = debugger.QuantizationDebugOptions(
        layer_direct_compare_metrics={'corr': _corr})
    quant_debugger = debugger.QuantizationDebugger(
        quant_debug_model_content=debug_model,
        debug_dataset=_calibration_gen,
        debug_options=options)
    quant_debugger.run()

    expected_metrics = {
        'corr': 0.99999,
    }
    self.assertLen(quant_debugger.layer_statistics, 1)
    actual_metrics = next(iter(quant_debugger.layer_statistics.values()))

    for key, value in expected_metrics.items():
      self.assertAlmostEqual(value, actual_metrics[key], places=4)

  @test_util.run_v2_only
  def test_wrong_input_raises_ValueError(self):
    """Feeding two inputs to the single-input model must raise ValueError."""

    def wrong_calibration_gen():
      for _ in range(5):
        yield [
            np.ones((1, 3, 3, 1), dtype=np.float32),
            np.ones((1, 3, 3, 1), dtype=np.float32)
        ]

    quant_debugger = debugger.QuantizationDebugger(
        quant_debug_model_content=QuantizationDebuggerTest.debug_model_float,
        debug_dataset=wrong_calibration_gen)
    with self.assertRaisesRegex(
        ValueError, r'inputs provided \(2\).+inputs to the model \(1\)'):
      quant_debugger.run()

  @test_util.run_v2_only
  def test_non_debug_model_raises_ValueError(self):
    """A model quantized without debug mode is rejected at construction."""
    normal_quant_model = _quantize_model(
        QuantizationDebuggerTest.tf_model_root,
        QuantizationDebuggerTest.tf_model,
        _calibration_gen,
        debug=False)

    with self.assertRaisesRegex(
        ValueError, 'Please check if the quantized model is in debug mode'):
      debugger.QuantizationDebugger(
          quant_debug_model_content=normal_quant_model,
          debug_dataset=_calibration_gen)

  @parameterized.named_parameters(
      ('empty quantization parameter', {
          'quantization_parameters': {}
      }, None),
      ('empty scales/zero points', {
          'quantization_parameters': {
              'scales': [],
              'zero_points': []
          }
      }, None),
      ('invalid scales/zero points', {
          'quantization_parameters': {
              'scales': [1.0],
              'zero_points': []
          }
      }, None),
      ('correct case', {
          'quantization_parameters': {
              'scales': [0.5, 1.0],
              'zero_points': [42, 7]
          }
      }, (0.5, 42)),
  )
  def test_get_quant_params(self, tensor_detail, expected_value):
    """Checks the (scale, zero_point) pair, or None, from a tensor detail.

    Args:
      tensor_detail: Dict holding a 'quantization_parameters' entry.
      expected_value: The first (scale, zero_point) pair, or None when the
        parameters are missing, empty, or incomplete.
    """
    self.assertEqual(debugger._get_quant_params(tensor_detail), expected_value)

  @parameterized.named_parameters(
      ('float_io', False),
      ('quantized_io', True),
  )
  @test_util.run_v2_only
  def test_denylisted_ops_from_option_setter(self, quantized_io):
    """Denylisting CONV_2D through the `options` setter skips quantization."""
    options = debugger.QuantizationDebugOptions(
        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))},
        fully_quantize=quantized_io)
    quant_debugger = debugger.QuantizationDebugger(
        converter=_quantize_converter(self.tf_model_root, self.tf_model,
                                      _calibration_gen),
        debug_dataset=_calibration_gen,
        debug_options=options)

    options.denylisted_ops = ['CONV_2D']
    # TODO(b/195084873): The exception is expected to check whether selective
    # quantization was done properly, since after the selective quantization
    # the model will have no quantized layers thus have no NumericVerify ops,
    # resulted in this exception. Marked with a bug to fix this in more
    # straightforward way.
    with self.assertRaisesRegex(
        ValueError, 'Please check if the quantized model is in debug mode'):
      quant_debugger.options = options

  @parameterized.named_parameters(
      ('float_io', False),
      ('quantized_io', True),
  )
  @test_util.run_v2_only
  def test_denylisted_ops_from_option_constructor(self, quantized_io):
    """Denylisting CONV_2D at construction leaves no layer to debug."""
    options = debugger.QuantizationDebugOptions(
        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))},
        fully_quantize=quantized_io,
        denylisted_ops=['CONV_2D'])
    # TODO(b/195084873): Count the number of NumericVerify op.
    with self.assertRaisesRegex(
        ValueError, 'Please check if the quantized model is in debug mode'):
      _ = debugger.QuantizationDebugger(
          converter=_quantize_converter(self.tf_model_root, self.tf_model,
                                        _calibration_gen),
          debug_dataset=_calibration_gen,
          debug_options=options)

  @parameterized.named_parameters(
      ('float_io', False),
      ('quantized_io', True),
  )
  @test_util.run_v2_only
  def test_denylisted_nodes_from_option_setter(self, quantized_io):
    """Denylisting the 'Identity' node through the `options` setter."""
    options = debugger.QuantizationDebugOptions(
        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))},
        fully_quantize=quantized_io)
    quant_debugger = debugger.QuantizationDebugger(
        converter=_quantize_converter(self.tf_model_root, self.tf_model,
                                      _calibration_gen),
        debug_dataset=_calibration_gen,
        debug_options=options)

    options.denylisted_nodes = ['Identity']
    # TODO(b/195084873): Count the number of NumericVerify op.
    with self.assertRaisesRegex(
        ValueError, 'Please check if the quantized model is in debug mode'):
      quant_debugger.options = options

  @parameterized.named_parameters(
      ('float_io', False),
      ('quantized_io', True),
  )
  @test_util.run_v2_only
  def test_denylisted_nodes_from_option_constructor(self, quantized_io):
    """Denylisting the 'Identity' node at construction."""
    options = debugger.QuantizationDebugOptions(
        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))},
        fully_quantize=quantized_io,
        denylisted_nodes=['Identity'])
    # TODO(b/195084873): Count the number of NumericVerify op.
    with self.assertRaisesRegex(
        ValueError, 'Please check if the quantized model is in debug mode'):
      _ = debugger.QuantizationDebugger(
          converter=_quantize_converter(self.tf_model_root, self.tf_model,
                                        _calibration_gen),
          debug_dataset=_calibration_gen,
          debug_options=options)

  @mock.patch.object(metrics.TFLiteMetrics,
                     'increase_counter_debugger_creation')
  def test_creation_counter(self, increase_call):
    """Constructing a debugger bumps the creation counter exactly once."""
    debug_model = QuantizationDebuggerTest.debug_model_float
    debugger.QuantizationDebugger(
        quant_debug_model_content=debug_model, debug_dataset=_calibration_gen)
    increase_call.assert_called_once()
434
435
if __name__ == '__main__':
  # Hand off to the main entry point of tensorflow.python.platform.test.
  test.main()
438