# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for representative_dataset.py."""
import random

import numpy as np

from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.types import core


def _contains_tensor(sample: repr_dataset.RepresentativeSample) -> bool:
  """Determines whether `sample` contains any tf.Tensors.

  Args:
    sample: A `RepresentativeSample`.

  Returns:
    True iff `sample` contains at least one tf.Tensor.
  """
  return any(isinstance(value, core.Tensor) for value in sample.values())


class RepresentativeDatasetTest(test.TestCase):
  """Tests functions for representative datasets."""

  def _assert_tensorlike_all_close(self, sess: session.Session,
                                   tensorlike_value_1: core.TensorLike,
                                   tensorlike_value_2: core.TensorLike) -> None:
    """Asserts that two different TensorLike values are "all close".

    Any value that is a tf.Tensor is evaluated in `sess` before the comparison.

    Args:
      sess: Session instance used to evaluate any tf.Tensors.
      tensorlike_value_1: A TensorLike value.
      tensorlike_value_2: A TensorLike value.
    """
    if isinstance(tensorlike_value_1, core.Tensor):
      tensorlike_value_1 = tensorlike_value_1.eval(session=sess)

    if isinstance(tensorlike_value_2, core.Tensor):
      tensorlike_value_2 = tensorlike_value_2.eval(session=sess)

    self.assertAllClose(tensorlike_value_1, tensorlike_value_2)

  def _assert_sample_values_all_close(
      self, sess: session.Session,
      repr_ds_1: repr_dataset.RepresentativeDataset,
      repr_ds_2: repr_dataset.RepresentativeDataset) -> None:
    """Asserts that the sample values are "all close" between the two datasets.

    The two datasets must contain the same number of samples; this is asserted
    explicitly. The order of corresponding samples is assumed to be preserved.

    Args:
      sess: Session instance used to evaluate any tf.Tensors.
      repr_ds_1: A RepresentativeDataset.
      repr_ds_2: A RepresentativeDataset.
    """
    # Materialize both datasets (they may be generators) so that their sizes
    # can be compared. A bare `zip` stops at the shorter input and would
    # silently hide samples missing from either dataset.
    samples_1 = list(repr_ds_1)
    samples_2 = list(repr_ds_2)
    self.assertEqual(len(samples_1), len(samples_2))

    for sample_1, sample_2 in zip(samples_1, samples_2):
      self.assertCountEqual(sample_1.keys(), sample_2.keys())

      for input_key in sample_1:
        self._assert_tensorlike_all_close(sess, sample_1[input_key],
                                          sample_2[input_key])

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_with_tensor_list(self):
    """Checks a list-based dataset whose sample values are all tf.Tensors."""
    num_samples = 8
    samples = [
        np.random.uniform(low=-1., high=1., size=(3, 3)).astype('f4')
        for _ in range(num_samples)
    ]

    repr_ds: repr_dataset.RepresentativeDataset = [{
        'input_tensor': ops.convert_to_tensor(sample),
    } for sample in samples]

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          repr_ds, sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      self._assert_sample_values_all_close(sess, repr_ds, new_repr_ds)

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_with_tensor_generator(self):
    """Checks a generator-based dataset whose sample values are tf.Tensors."""
    num_samples = 8
    samples = [
        np.random.uniform(low=-1., high=1., size=(1, 4)).astype('f4')
        for _ in range(num_samples)
    ]

    def data_gen() -> repr_dataset.RepresentativeDataset:
      for sample in samples:
        yield {'input_tensor': ops.convert_to_tensor(sample)}

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          data_gen(), sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      # A fresh generator is needed here: the one passed above has already
      # been handed to the function under test.
      self._assert_sample_values_all_close(sess, data_gen(), new_repr_ds)

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_is_noop_when_no_tensor(self):
    """Checks a dataset that holds only np.ndarrays keeps the same values."""
    # Fill the representative dataset with np.ndarrays only.
    repr_ds: repr_dataset.RepresentativeDataset = [{
        'input_tensor': np.random.uniform(low=-1., high=1., size=(4, 3)),
    } for _ in range(8)]

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          repr_ds, sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      self._assert_sample_values_all_close(sess, repr_ds, new_repr_ds)

  @test_util.deprecated_graph_mode_only
  def test_replace_tensors_by_numpy_ndarrays_mixed_tensor_and_ndarray(self):
    """Checks a shuffled dataset mixing tf.Tensor and np.ndarray samples."""
    num_tensors = 4
    samples = [
        np.random.uniform(low=-1., high=1., size=(3, 3)).astype('f4')
        for _ in range(num_tensors)
    ]

    repr_ds: repr_dataset.RepresentativeDataset = [{
        'tensor_key': ops.convert_to_tensor(sample),
    } for sample in samples]

    # Extend the representative dataset with np.ndarrays.
    repr_ds.extend([{
        'tensor_key': np.random.uniform(low=-1., high=1., size=(3, 3))
    } for _ in range(4)])

    # Interleave tensor samples and ndarray samples in arbitrary order.
    random.shuffle(repr_ds)

    with self.session() as sess:
      new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays(
          repr_ds, sess)

      # The resulting dataset should not contain any tf.Tensors.
      self.assertFalse(any(map(_contains_tensor, new_repr_ds)))
      self._assert_sample_values_all_close(sess, repr_ds, new_repr_ds)


if __name__ == '__main__':
  test.main()