# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reference tests check that a function is compiled correctly."""

import io
import numbers
import os
import sys
import traceback

import numpy as np
import tensorflow as tf


class TestCase(tf.test.TestCase):
  """Base class for the reference tests.

  Subclasses call one of the `assert*Matches*` helpers with a plain Python
  function and its arguments. The helper runs the function twice, once
  natively and once after conversion (`tf.function` or
  `tf.autograph.to_graph`), and fails unless return values, captured stdout
  and the type of any raised exception agree.

  Attributes (reset in `setUp`):
    autograph_opts: Forwarded as `experimental_autograph_options` to
      `tf.function` and as `experimental_optional_features` to
      `tf.autograph.to_graph`.
    all_inputs_tensors: When true, `assertFunctionMatchesEager` converts its
      arguments with `_as_tensors` before calling the function.
    allow_exceptions: When false, `assertFunctionMatchesEager` fails as soon
      as the compiled call raises.
  """

  def setUp(self):
    """Resets the per-test configuration to its defaults."""
    super().setUp()
    # NOTE(review): presumably makes AutoGraph raise on conversion failures
    # instead of falling back to the unconverted function; confirm.
    os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
    self.autograph_opts = None
    self.all_inputs_tensors = False
    self.allow_exceptions = False

  # TODO(mdan): Consider rewriting as a context manager.
  def _run_with_output_capture(self, func):
    """Executes `func`, capturing stdout.

    Args:
      func: Zero-argument callable to execute.

    Returns:
      A tuple `(results, captured_out, captured_err)`. On success `results`
      is the return value of `func`, `captured_out` is everything it wrote
      to `sys.stdout`, and `captured_err` is None. If `func` raised an
      `Exception`, `results` and `captured_out` are None and `captured_err`
      is the exception instance.
    """
    out_capturer = io.StringIO()
    results = None
    captured_out = None
    captured_err = None
    # Remember the stream that is active right now, which may itself be a
    # redirection installed by the test runner, so that it is restored
    # afterwards rather than being replaced with `sys.__stdout__`.
    previous_stdout = sys.stdout
    try:
      sys.stdout = out_capturer
      results = func()
      captured_out = out_capturer.getvalue()
    except Exception as e:  # pylint:disable=broad-except
      # Restore first so the traceback below is not written into the capture
      # buffer.
      sys.stdout = previous_stdout
      captured_err = e
      print('*** Capturing exception:\n{}\n'.format(traceback.format_exc()))
    finally:
      sys.stdout = previous_stdout
      out_capturer.close()
    return results, captured_out, captured_err

  def _as_tensors(self, args):
    """Converts args to tensors.

    Numbers, lists and NumPy arrays are wrapped with `tf.constant`. Dict
    values are converted recursively, with keys left unchanged. Any other
    object is passed through untouched.

    Args:
      args: Iterable of arguments.

    Returns:
      A list with one entry per element of `args`.
    """
    tensor_args = []
    for a in args:
      if isinstance(a, (numbers.Number, list, np.ndarray)):
        tensor_arg = tf.constant(a)
      elif isinstance(a, dict):
        keys = tuple(a.keys())
        tensor_arg = dict(zip(keys, self._as_tensors([a[k] for k in keys])))
      else:
        tensor_arg = a
      tensor_args.append(tensor_arg)
    return tensor_args

  def run_native(self, f, *args):
    """Calls `f(*args)` and returns `(results, stdout, exception)`."""
    return self._run_with_output_capture(lambda: f(*args))

  def _deep_equal(self, left, right):
    """Compares two possibly-nested structures.

    Tensors are compared through their `.numpy()` value, sparse tensors
    through their indices, values and shape, and NumPy arrays with
    `np.array_equal`. Lists and tuples are treated interchangeably and
    compared element-wise. Everything else falls back to `==`.

    Args:
      left: First structure.
      right: Second structure.

    Returns:
      Whether the two structures are considered equal.
    """
    if isinstance(left, tf.Tensor):
      return self._deep_equal(left.numpy(), right)
    if isinstance(right, tf.Tensor):
      return self._deep_equal(left, right.numpy())
    if isinstance(left, tf.SparseTensor) and isinstance(right, tf.SparseTensor):
      return (self._deep_equal(left.indices, right.indices)
              and self._deep_equal(left.values, right.values)
              and self._deep_equal(left.shape, right.shape))
    if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
      return np.array_equal(left, right)
    if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
      # `zip` stops at the shorter sequence, so without the length check a
      # sequence would compare equal to any longer one sharing its prefix.
      return (len(left) == len(right)
              and all(self._deep_equal(l, r) for l, r in zip(left, right)))
    return left == right

  def assertResultsMatch(self,
                         f,
                         args,
                         native_data,
                         compiled_data):
    """Asserts that native_data matches compiled_data.

    Args:
      f: The function under test; only its `__name__` is used, for the
        failure message.
      args: The arguments of the call; only used for the failure message.
      native_data: `(results, stdout, exception)` tuple from the native run.
      compiled_data: `(results, stdout, exception)` tuple from the compiled
        run.
    """
    native_results, native_out, native_err = native_data
    compiled_results, compiled_out, compiled_err = compiled_data
    str_args = '(%s)' % ', '.join(str(a) for a in args)
    # Using a manual verification to avoid a second compilation on success.
    # For exceptions, we don't enforce that they are the same, only that
    # both paths raised.
    # TODO(mdan): Add an API that returns both object and source code instead.
    outputs_equal = (
        self._deep_equal(native_results, compiled_results) and
        native_out == compiled_out)
    errors_equivalent = type(native_err) == type(compiled_err)  # pylint:disable=unidiomatic-typecheck
    if not outputs_equal or not errors_equivalent:
      self.fail('Native and compiled functions are not equivalent.\n\n'
                'Native results: %s\n'
                'Compiled results: %s\n'
                'Native out: %s\n'
                'Compiled out: %s\n'
                'Native error: %s: %s\n'
                'Compiled error: %s: %s\n'
                'Native call: %s%s\n'
                'Check the logs for the generated code.' % (
                    native_results,
                    compiled_results,
                    native_out,
                    compiled_out,
                    type(native_err).__name__,
                    native_err,
                    type(compiled_err).__name__,
                    compiled_err,
                    f.__name__,
                    str_args,
                ))

  def function(self, f, xla=False):
    """Wraps `f` in `tf.function` using this test's AutoGraph options.

    Args:
      f: Python function to wrap.
      xla: Forwarded as `experimental_compile`.

    Returns:
      The object returned by `tf.function`.
    """
    # NOTE(review): newer TensorFlow releases appear to call this argument
    # `jit_compile`; confirm against the TensorFlow version in use.
    return tf.function(
        f,
        experimental_autograph_options=self.autograph_opts,
        experimental_compile=xla)

  def convert(self, f):
    """Converts `f` with `tf.autograph.to_graph` using this test's options."""
    return tf.autograph.to_graph(
        f, experimental_optional_features=self.autograph_opts)

  def assertFunctionMatchesEagerStatefulInput(self, f, args):
    """Like assertFunctionMatchesEager but creates new inputs each time.

    Args:
      f: Function under test.
      args: Zero-argument callable returning a fresh sequence of arguments;
        it is invoked once per run and once more for the failure message.
    """
    compiled_data = self.run_native(self.function(f), *args())
    native_data = self.run_native(f, *args())
    self.assertResultsMatch(f, args(), native_data, compiled_data)

  def assertFunctionMatchesEager(self, f, *args, xla=False):
    """Asserts that `f` behaves the same natively and under `tf.function`.

    Args:
      f: Function under test.
      *args: Arguments for `f`; converted with `_as_tensors` first when
        `self.all_inputs_tensors` is set.
      xla: Forwarded to `self.function`.
    """
    if self.all_inputs_tensors:
      args = self._as_tensors(args)
    compiled_data = self.run_native(self.function(f, xla=xla), *args)
    if not self.allow_exceptions:
      # Fail early, before the native run, with the compiled error itself.
      _, _, compiled_err = compiled_data
      if compiled_err is not None:
        self.fail(str(compiled_err))
    native_data = self.run_native(f, *args)
    self.assertResultsMatch(f, args, native_data, compiled_data)

  def assertConvertedMatchesNative(self, f, *args):
    """Asserts that `f` behaves the same natively and after `convert`."""
    compiled_data = self.run_native(self.convert(f), *args)
    native_data = self.run_native(f, *args)
    self.assertResultsMatch(f, args, native_data, compiled_data)


if __name__ == '__main__':
  tf.test.main()