# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to improve the consistency of tensorflow.

Here we would like to include high level tests that stress tf.function and
autograph in ways users have discovered. Not everything here has to work;
some things just need to have good error messages. Some things currently
have bugs assigned to them but do not work and do not have sufficient error
messages.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.python.util import tf_inspect


# Example to run a test case against.
#
# arg: Argument tuple to pass to the function callable.
# out: Expected output.
# failure:
#   List of `RunMode` enums that are expected to fail, or
#   Dict of {'<`RunMode` enum>': 'error message', ...} where keys are `RunMode`
#   enums that are expected to fail and values are the corresponding
#   'error message' of the failure.
# bugs: List of bugs that are related to this test case.
Example = collections.namedtuple('Example', ['arg', 'out', 'failure', 'bugs'])


class RunMode(enum.Enum):
  """The ways `ConsistencyTestBase._generic_test` executes a function."""
  RAW = 0  # Plain Python call, no tf.function.
  FUNCTION = 1  # Wrapped in tf.function.
  SAVED = 2  # tf.function saved to a SavedModel, loaded back, then called.
  XLA = 3  # tf.function with experimental_compile=True.


dashboard_data = {}


class ConsistencyTestBase(tf.test.TestCase):
  """Base class for tests that run Python functions in the 4 run modes.

  The run modes are:
    raw, tf.function'ified, tf.function xlaified, and loaded from saved model.
  """

  def recordProperty(self, property_name, property_value):
    """Wrapper to handle recording properties.

    Open source does not have record property, so this is a no-op when the
    base test class does not provide `recordProperty`.

    Args:
      property_name: Name of property to record.
      property_value: Value to record associated with `property_name`.
    """
    base = super(ConsistencyTestBase, self)
    if hasattr(base, 'recordProperty'):
      base.recordProperty(property_name, property_value)

  def _deep_equal(self, left, right):
    """Recursively compares `left` and `right`.

    `tf.TensorArray`s are stacked and `tf.Tensor`s are converted with
    `.numpy()` before comparing. `tf.SparseTensor`s compare indices, values
    and shape. NumPy arrays use `np.array_equal`. Lists/tuples compare
    element-wise and must have the same length. Anything else falls back to
    `==`.

    Args:
      left: First value to compare.
      right: Second value to compare.

    Returns:
      A truthy value iff `left` and `right` are considered equal.
    """
    if isinstance(left, tf.TensorArray):
      return self._deep_equal(left.stack(), right)
    if isinstance(right, tf.TensorArray):
      return self._deep_equal(left, right.stack())
    if isinstance(left, tf.Tensor):
      return self._deep_equal(left.numpy(), right)
    if isinstance(right, tf.Tensor):
      return self._deep_equal(left, right.numpy())
    if isinstance(left, tf.SparseTensor) and isinstance(right, tf.SparseTensor):
      return (self._deep_equal(left.indices, right.indices)
              and self._deep_equal(left.values, right.values)
              and self._deep_equal(left.shape, right.shape))
    if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
      return np.array_equal(left, right)
    if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
      # `zip` silently stops at the shorter sequence, so without the length
      # check e.g. [1, 2] would compare equal to [1, 2, 3].
      return (len(left) == len(right)
              and all(self._deep_equal(l, r) for l, r in zip(left, right)))

    return left == right

  def _run_and_check(self, f, mode, examples):
    """Calls `f` on every example and checks the outcome for `mode`.

    Args:
      f: The callable under test (raw, tf.function, XLA or loaded function).
      mode: The `RunMode` that `f` represents.
      examples: A list of `Example` named tuples.
    """
    for arg, out, failure, bugs in examples:
      del bugs
      # A missing `failure` means the example is expected to pass everywhere.
      if failure is None:
        failure = []
      err_msg = '.*'
      # `failure` can be a list of `RunMode` enums or a dict of `RunMode` enum
      # and corresponding error message as key-value pairs:
      # `{'<`RunMode` enum>': 'error message', ...}`. If `failure` is a dict,
      # retrieve the error message corresponding to the `RunMode`, then keep
      # only its keys so it can be treated like the list form below.
      if isinstance(failure, dict):
        err_msg = failure.get(mode, err_msg)
        failure = failure.keys()

      if mode in failure:
        with self.assertRaisesWithPredicateMatch(BaseException, err_msg):
          self._deep_equal(f(*arg), out)
      else:
        # Make sure `_deep_equal` returns True. Otherwise, mismatching results
        # (between `f(*arg)` and `out`) will not be caught.
        self.assertTrue(self._deep_equal(f(*arg), out))

  def _generic_test(self,
                    f_raw,
                    examples,
                    input_signature=None,
                    skip_modes=None):
    """Test a function `f_raw` against all tests `examples`.

    Args:
      f_raw: a callable.
      examples: A list of `Example` named tuples.
      input_signature: Input signature to tf.function.
      skip_modes: A list of `RunMode` enums to entirely skip testing in the
        specified `RunMode`s. This is necessary when things fail in a certain
        `RunMode` even before executing the function (e.g. during saving or
        loading in `RunMode.SAVED` mode).
    """
    f_tf = None
    if not skip_modes:
      skip_modes = []

    if tf_inspect.isfunction(f_raw):
      self.recordProperty('f', tf_inspect.getsource(f_raw))
    else:
      self.recordProperty('f', tf_inspect.getdoc(f_raw))

    for arg, out, failure, bugs in examples:
      del out
      self.recordProperty('Input "{}"'.format(arg), {
          'not-working': failure,
          'bugs': bugs
      })

    # Run the function without tf.function
    if RunMode.RAW not in skip_modes:
      self._run_and_check(f_raw, RunMode.RAW, examples)

    # TF Function
    if RunMode.FUNCTION not in skip_modes:
      f_tf = tf.function(f_raw, input_signature=input_signature)
      self._run_and_check(f_tf, RunMode.FUNCTION, examples)

    # XLA Function
    if RunMode.XLA not in skip_modes:
      f_xla = tf.function(
          f_raw, input_signature=input_signature, experimental_compile=True)
      self._run_and_check(f_xla, RunMode.XLA, examples)

    # Write a saved model and try to run it
    if RunMode.SAVED not in skip_modes:
      module = tf.Module()
      # Reuse the tf.function from FUNCTION mode when it was created.
      if f_tf is not None:
        module.f = f_tf
      else:
        module.f = tf.function(f_raw, input_signature=input_signature)

      # Save into a fresh directory under this test's temp dir instead of the
      # shared system temp dir. Otherwise models from different tests would
      # overwrite each other.
      saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
      tf.saved_model.save(module, saved_model_dir)
      module_loaded = tf.saved_model.load(saved_model_dir)
      self._run_and_check(module_loaded.f, RunMode.SAVED, examples)


if __name__ == '__main__':
  tf.test.main()