1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Defines types required for representative datasets for quantization."""
16
17from typing import Iterable, Mapping, Union
18
19from tensorflow.python.client import session
20from tensorflow.python.types import core
21
# A representative sample is a map of: input_key -> input_value.
# Ex.: {'dense_input': tf.constant([1, 2, 3])}
# Ex.: {'x1': np.array([4, 5, 6])}
RepresentativeSample = Mapping[str, core.TensorLike]

# A representative dataset is an iterable of representative samples.
# Ex.: [{'x': tf.constant([1, 2, 3])}, {'x': tf.constant([4, 5, 6])}]
RepresentativeDataset = Iterable[RepresentativeSample]

# A type representing a map from: signature key -> representative dataset.
# Each value is itself an iterable of samples (input_key -> input_value maps),
# not an iterable of bare tensors.
# Ex.: {'serving_default': [{'x': tf.constant([1, 2, 3])},
#                           {'x': tf.constant([4, 5, 6])}],
#       'other_signature_key': [{'y': tf.constant([[2, 2], [9, 9]])}]}
RepresentativeDatasetMapping = Mapping[str, RepresentativeDataset]

# A type alias expressing that it can be either a RepresentativeDataset or
# a mapping of signature key to RepresentativeDataset.
RepresentativeDatasetOrMapping = Union[RepresentativeDataset,
                                       RepresentativeDatasetMapping]
39
40
def replace_tensors_by_numpy_ndarrays(
    repr_ds: RepresentativeDataset,
    sess: session.Session) -> RepresentativeDataset:
  """Replaces tf.Tensors in samples by their evaluated numpy arrays.

  Note: This should be run in graph mode (default in TF1) only.

  All tf.Tensor inputs of a single sample are fetched together in one
  `sess.run` call. Tensors within a sample that share upstream stateful ops
  (e.g. random ops) are therefore computed in the same execution and stay
  consistent with each other. This also avoids one session round-trip per
  input. Non-tensor values are passed through unchanged.

  Args:
    repr_ds: Representative dataset to replace the tf.Tensors with their
      evaluated values. `repr_ds` is iterated through, so it may not be reusable
      (e.g. if it is a generator object).
    sess: Session instance used to evaluate tf.Tensors.

  Returns:
    The new representative dataset: a list of new sample dicts, in the same
    order as `repr_ds`, where each tf.Tensor is replaced by its evaluated
    numpy ndarray. The input samples themselves are not mutated.
  """
  new_repr_ds = []
  for sample in repr_ds:
    # Copy first: this preserves the sample's key order and leaves the
    # caller's sample object untouched.
    new_sample = dict(sample)

    # Only tf.Tensors are fetched. Other values (e.g. numpy arrays) must not
    # be handed to `sess.run`, which would try to interpret them as fetches
    # (a str, for instance, would be looked up as a tensor name).
    tensor_inputs = {
        input_key: input_data
        for input_key, input_data in new_sample.items()
        if isinstance(input_data, core.Tensor)
    }

    # Evaluate every tensor of this sample in a single run; a dict fetch
    # returns a dict with the same keys mapped to the evaluated values.
    if tensor_inputs:
      new_sample.update(sess.run(tensor_inputs))

    new_repr_ds.append(new_sample)
  return new_repr_ds
70