1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for representative_dataset.py."""
16import random
17
18import numpy as np
19
20from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset
21from tensorflow.python.client import session
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import test_util
24from tensorflow.python.platform import test
25from tensorflow.python.types import core
26
27
def _contains_tensor(sample: repr_dataset.RepresentativeSample) -> bool:
  """Checks whether any value stored in `sample` is a tf.Tensor.

  Args:
    sample: A `RepresentativeSample` whose values are inspected.

  Returns:
    True iff at least one value of `sample` is a tf.Tensor.
  """
  for value in sample.values():
    # Stop at the first tensor found; the remaining values need not be checked.
    if isinstance(value, core.Tensor):
      return True
  return False
38
39
class RepresentativeDatasetTest(test.TestCase):
  """Tests functions for representative datasets."""

  def _assert_tensorlike_all_close(self, sess: session.Session,
                                   tensorlike_value_1: core.TensorLike,
                                   tensorlike_value_2: core.TensorLike) -> None:
    """Asserts that two different TensorLike values are "all close".

    Any argument that is a tf.Tensor is evaluated with `sess` before the
    comparison is made.

    Args:
      sess: Session instance used to evaluate any tf.Tensors.
      tensorlike_value_1: A TensorLike value.
      tensorlike_value_2: A TensorLike value.
    """
    if isinstance(tensorlike_value_1, core.Tensor):
      tensorlike_value_1 = tensorlike_value_1.eval(session=sess)

    if isinstance(tensorlike_value_2, core.Tensor):
      tensorlike_value_2 = tensorlike_value_2.eval(session=sess)

    self.assertAllClose(tensorlike_value_1, tensorlike_value_2)

  def _assert_sample_values_all_close(
      self, sess: session.Session,
      repr_ds_1: repr_dataset.RepresentativeDataset,
      repr_ds_2: repr_dataset.RepresentativeDataset) -> None:
    """Asserts that the sample values are "all close" between the two datasets.

    The two datasets must contain the same number of samples, and corresponding
    samples are expected to appear in the same order.

    Args:
      sess: Session instance used to evaluate any tf.Tensors.
      repr_ds_1: A RepresentativeDataset.
      repr_ds_2: A RepresentativeDataset.
    """
    # Materialize both datasets (either one may be a generator) so that their
    # sizes can be compared. A bare `zip` stops at the shorter input, which
    # would let missing or extra samples go unnoticed.
    samples_1 = list(repr_ds_1)
    samples_2 = list(repr_ds_2)
    self.assertEqual(
        len(samples_1), len(samples_2),
        'The representative datasets have different numbers of samples.')

    for sample_1, sample_2 in zip(samples_1, samples_2):
      self.assertCountEqual(sample_1.keys(), sample_2.keys())

      for input_key in sample_1:
        self._assert_tensorlike_all_close(sess, sample_1[input_key],
                                          sample_2[input_key])

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_with_tensor_list(self):
    """Tensors in a list-based dataset are replaced without changing values."""
    num_samples = 8
    samples = [
        np.random.uniform(low=-1., high=1., size=(3, 3)).astype('f4')
        for _ in range(num_samples)
    ]

    repr_ds: repr_dataset.RepresentativeDataset = [{
        'input_tensor': ops.convert_to_tensor(sample),
    } for sample in samples]

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          repr_ds, sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      self._assert_sample_values_all_close(sess, repr_ds, new_repr_ds)

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_with_tensor_generator(self):
    """Tensors yielded by a generator dataset are replaced as well."""
    num_samples = 8
    samples = [
        np.random.uniform(low=-1., high=1., size=(1, 4)).astype('f4')
        for _ in range(num_samples)
    ]

    def data_gen() -> repr_dataset.RepresentativeDataset:
      for sample in samples:
        yield {'input_tensor': ops.convert_to_tensor(sample)}

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          data_gen(), sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      # A fresh generator is created for the comparison because the one passed
      # to `replace_tensors_by_numpy_ndarrays` is a separate, one-shot iterator.
      self._assert_sample_values_all_close(sess, data_gen(), new_repr_ds)

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_is_noop_when_no_tensor(self):
    """A dataset holding only np.ndarrays keeps its values unchanged."""
    # Fill the representative dataset with np.ndarrays only.
    repr_ds: repr_dataset.RepresentativeDataset = [{
        'input_tensor': np.random.uniform(low=-1., high=1., size=(4, 3)),
    } for _ in range(8)]

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          repr_ds, sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      self._assert_sample_values_all_close(sess, repr_ds, new_repr_ds)

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_mixed_tensor_and_ndarray(self):
    """A dataset mixing tf.Tensors and np.ndarrays ends up tensor-free."""
    num_tensors = 4
    num_ndarrays = 4
    samples = [
        np.random.uniform(low=-1., high=1., size=(3, 3)).astype('f4')
        for _ in range(num_tensors)
    ]

    repr_ds: repr_dataset.RepresentativeDataset = [{
        'tensor_key': ops.convert_to_tensor(sample),
    } for sample in samples]

    # Extend the representative dataset with np.ndarrays.
    repr_ds.extend([{
        'tensor_key': np.random.uniform(low=-1., high=1., size=(3, 3))
    } for _ in range(num_ndarrays)])

    # Interleave the tensor samples and the ndarray samples in random order.
    random.shuffle(repr_ds)

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          repr_ds, sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      self._assert_sample_values_all_close(sess, repr_ds, new_repr_ds)
163
164
if __name__ == '__main__':
  # Run this module's tests via `test.main()` when it is executed as a script.
  test.main()
167