1# Lint as: python2, python3
2# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
"""Tests to improve the consistency of TensorFlow.

Here we would like to include high-level tests that stress tf.function and
autograph in ways users have discovered. Not everything here has to work;
some things just need to have good error messages. Some things currently
have bugs assigned to them because they neither work nor produce sufficient
error messages.
"""
24
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28
29import collections
30import enum
31import tempfile
32import numpy as np
33import tensorflow as tf
34from tensorflow.python.util import tf_inspect
35
36
# Example to run a test case against.
#
# arg: Argument tuple passed to the function callable.
# out: Expected output.
# failure:
#   List of `RunMode` enums that are expected to fail, or
#   Dict of {'<`RunMode` enum>': 'error message', ...} where keys are `RunMode`
#     enums that are expected to fail and values are the corresponding
#     'error message' of the failure.
# bugs: List of bugs that are related to this test case.
# A single test case: call arguments, expected output, expected failures and
# related bugs.
Example = collections.namedtuple('Example', 'arg out failure bugs')
48
49
class RunMode(enum.Enum):
  """Execution modes in which a function under test is exercised."""

  RAW = 0  # Called directly, without tf.function.
  FUNCTION = 1  # Wrapped with tf.function.
  SAVED = 2  # Saved with tf.saved_model and reloaded before calling.
  XLA = 3  # Wrapped with tf.function using XLA compilation.
55
56
# Module-level mapping; it is not read or written in this section of the file
# (presumably populated elsewhere -- verify against users of this module).
dashboard_data = dict()
58
59
class ConsistencyTestBase(tf.test.TestCase):
  """Tests that attempt to use Python functions in the 4 use-examples.

  The example kinds are:
  raw, tf.function'ified, tf.function xlaified, and loaded from saved model.
  """

  def recordProperty(self, property_name, property_value):
    """Wrapper to handle recording properties.

    Forwards to the parent class's `recordProperty` when the parent provides
    one; otherwise this is a no-op (open source does not have record
    property).

    Args:
      property_name: Name of property to record.
      property_value: Value to record associated with `property_name`.
    """
    base = super(ConsistencyTestBase, self)
    if hasattr(base, 'recordProperty'):
      getattr(base, 'recordProperty')(property_name, property_value)

  def _deep_equal(self, left, right):
    """Recursively compares two possibly nested values for equality.

    TensorArrays are stacked and Tensors are converted to numpy before
    comparing. Two SparseTensors are compared component-wise (indices,
    values, shape). Numpy arrays are compared with `np.array_equal`. Lists
    and tuples (interchangeably) are compared element-wise and must have the
    same length. Dicts must have the same keys and element-wise equal values.
    Anything else falls back to `==`.

    Args:
      left: First value.
      right: Second value.

    Returns:
      True if the two values are considered equal.
    """
    if isinstance(left, tf.TensorArray):
      return self._deep_equal(left.stack(), right)
    if isinstance(right, tf.TensorArray):
      return self._deep_equal(left, right.stack())
    if isinstance(left, tf.Tensor):
      return self._deep_equal(left.numpy(), right)
    if isinstance(right, tf.Tensor):
      return self._deep_equal(left, right.numpy())
    if isinstance(left, tf.SparseTensor) and isinstance(right, tf.SparseTensor):
      return (self._deep_equal(left.indices, right.indices)
              and self._deep_equal(left.values, right.values)
              and self._deep_equal(left.shape, right.shape))
    if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
      return np.array_equal(left, right)
    if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
      # `zip` silently truncates to the shorter sequence, so the lengths must
      # be checked explicitly; otherwise e.g. [1, 2] would equal [1].
      return (len(left) == len(right)
              and all(self._deep_equal(a, b) for a, b in zip(left, right)))
    if isinstance(left, dict) and isinstance(right, dict):
      # Recurse into values: a plain `==` on dicts holding non-scalar tensors
      # would try to take the truth value of an element-wise comparison.
      return (set(left) == set(right)
              and all(self._deep_equal(left[k], right[k]) for k in left))

    return left == right

  def _run_and_check(self, f, mode, examples):
    """Runs `f` on every example and checks the outcome for `mode`.

    If `mode` is listed in an example's `failure`, the call (plus the
    comparison) must raise an exception whose message matches the associated
    error regex ('.*' when none is given). Otherwise the result must be
    `_deep_equal` to the expected output.

    Args:
      f: Callable under test.
      mode: The `RunMode` being exercised.
      examples: Iterable of `Example` named tuples.
    """
    for arg, out, failure, bugs in examples:
      del bugs
      # Treat a missing `failure` as "expected to succeed in every mode".
      if failure is None:
        failure = []
      # `failure` can be a list of `RunMode` enums or a dict of `RunMode` enum
      # and corresponding error message as key-value pairs:
      # `{'<`RunMode` enum>': 'error message', ...}`. If `failure` is a dict,
      # retrieve the error message corresponding to the `RunMode`. The `in`
      # test below works for both a list and a dict's keys.
      err_msg = '.*'
      if isinstance(failure, dict):
        err_msg = failure.get(mode, err_msg)

      if mode in failure:
        with self.assertRaisesWithPredicateMatch(BaseException, err_msg):
          self._deep_equal(f(*arg), out)
      else:
        # Make sure `_deep_equal` returns True. Otherwise, mismatching results
        # (between `f(*arg)` and `out`) will not be caught.
        self.assertTrue(self._deep_equal(f(*arg), out))

  def _generic_test(self,
                    f_raw,
                    examples,
                    input_signature=None,
                    skip_modes=None):
    """Test a function `f_raw` against all tests `examples`.

    Each example is checked in every `RunMode` not listed in `skip_modes`:
    calling `f_raw` directly, through `tf.function`, through `tf.function`
    with XLA compilation, and through a SavedModel round trip.

    Args:
      f_raw: a callable.
      examples: A list of `Example` named tuples.
      input_signature: Input signature to tf.function.
      skip_modes: A list of `RunMode` enums to entirely skip testing in the
        specified `RunMode`s. This is necessary when things fail in a certain
        `RunMode` even before executing the function (e.g. during saving or
        loading in `RunMode.SAVED` mode).
    """
    f_tf = None
    if not skip_modes:
      skip_modes = []

    if tf_inspect.isfunction(f_raw):
      self.recordProperty('f', tf_inspect.getsource(f_raw))
    else:
      self.recordProperty('f', tf_inspect.getdoc(f_raw))

    for arg, out, failure, bugs in examples:
      del out
      self.recordProperty('Input "{}"'.format(arg), {
          'not-working': failure,
          'bugs': bugs
      })

    # Run the function without tf.function
    if RunMode.RAW not in skip_modes:
      self._run_and_check(f_raw, RunMode.RAW, examples)

    # TF Function
    if RunMode.FUNCTION not in skip_modes:
      f_tf = tf.function(f_raw, input_signature=input_signature)
      self._run_and_check(f_tf, RunMode.FUNCTION, examples)

    # XLA Function
    # NOTE(review): newer TF releases spell this `jit_compile`;
    # `experimental_compile` is kept for compatibility with the TF version
    # this file targets.
    if RunMode.XLA not in skip_modes:
      f_xla = tf.function(
          f_raw, input_signature=input_signature, experimental_compile=True)
      self._run_and_check(f_xla, RunMode.XLA, examples)

    # Write a saved model and try to run it
    if RunMode.SAVED not in skip_modes:
      module = tf.Module()
      # Reuse the tf.function from the FUNCTION mode when it was created.
      if f_tf is not None:
        module.f = f_tf
      else:
        module.f = tf.function(f_raw, input_signature=input_signature)

      # Use a fresh directory under this test's private temp dir. Saving
      # directly into the shared system temp dir would litter it and let
      # concurrent or repeated runs overwrite each other's SavedModels.
      saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
      tf.saved_model.save(module, saved_model_dir)
      module_loaded = tf.saved_model.load(saved_model_dir)
      self._run_and_check(module_loaded.f, RunMode.SAVED, examples)
183
184
if __name__ == '__main__':
  # When executed as a script, hand control to TensorFlow's test runner.
  tf.test.main()
187