# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines types required for representative datasets for quantization."""

from typing import Iterable, Mapping, Union

from tensorflow.python.client import session
from tensorflow.python.types import core

# A representative sample is a map of: input_key -> input_value.
# Ex.: {'dense_input': tf.constant([1, 2, 3])}
# Ex.: {'x1': np.array([4, 5, 6])}
RepresentativeSample = Mapping[str, core.TensorLike]

# A representative dataset is an iterable of representative samples.
# Ex.: [{'x1': tf.constant([1, 2, 3])}, {'x1': tf.constant([4, 5, 6])}]
RepresentativeDataset = Iterable[RepresentativeSample]

# A type representing a map from: signature key -> representative dataset.
# Each value is a representative dataset, i.e. an iterable of samples (maps),
# not an iterable of bare tensors.
# Ex.: {'serving_default': [{'x1': tf.constant([1, 2, 3])},
#                           {'x1': tf.constant([4, 5, 6])}],
#       'other_signature_key': [{'x2': tf.constant([[2, 2], [9, 9]])}]}
RepresentativeDatasetMapping = Mapping[str, RepresentativeDataset]

# A type alias expressing that it can be either a RepresentativeDataset or
# a mapping of signature key to RepresentativeDataset.
RepresentativeDatasetOrMapping = Union[RepresentativeDataset,
                                       RepresentativeDatasetMapping]


def replace_tensors_by_numpy_ndarrays(
    repr_ds: RepresentativeDataset,
    sess: session.Session) -> RepresentativeDataset:
  """Replaces tf.Tensors in samples by their evaluated numpy arrays.

  Note: This should be run in graph mode (default in TF1) only.

  Args:
    repr_ds: Representative dataset to replace the tf.Tensors with their
      evaluated values. `repr_ds` is iterated through, so it may not be reusable
      (e.g. if it is a generator object).
    sess: Session instance used to evaluate tf.Tensors.

  Returns:
    A new list of samples (plain dicts), one per sample of `repr_ds` and in
    the same order. Each tf.Tensor value is replaced by its evaluated numpy
    ndarray; every non-Tensor value (e.g. an existing numpy array) is kept
    as-is. The samples in `repr_ds` are not modified.
  """

  def _to_numpy_if_tensor(value):
    # Only tf.Tensors need evaluation; other values are already concrete and
    # are passed through unchanged.
    if isinstance(value, core.Tensor):
      return value.eval(session=sess)
    return value

  # Materialize into a list so the result is reusable even when `repr_ds` is
  # a one-shot iterator such as a generator.
  return [{input_key: _to_numpy_if_tensor(input_data)
           for input_key, input_data in sample.items()}
          for sample in repr_ds]