xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/tests/reference_test_base.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
16#
17# Licensed under the Apache License, Version 2.0 (the "License");
18# you may not use this file except in compliance with the License.
19# You may obtain a copy of the License at
20#
21#     http://www.apache.org/licenses/LICENSE-2.0
22#
23# Unless required by applicable law or agreed to in writing, software
24# distributed under the License is distributed on an "AS IS" BASIS,
25# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26# See the License for the specific language governing permissions and
27# limitations under the License.
28# ==============================================================================
29"""Reference tests check that a function is compiled correctly."""
30
31import io
32import numbers
33import os
34import sys
35import traceback
36
37import numpy as np
38import tensorflow as tf
39
40
class TestCase(tf.test.TestCase):
  """Base class for the reference tests.

  The `assert*Matches*` helpers run a plain Python function both natively and
  through `tf.function` / `tf.autograph.to_graph`. Each run records the
  return value, the text printed to stdout and any raised exception. The
  helpers fail the test if the two runs disagree.

  Attributes (initialized in `setUp`; tests may override them):
    autograph_opts: Passed as `experimental_autograph_options` to
      `tf.function` and as `experimental_optional_features` to
      `tf.autograph.to_graph`.
    all_inputs_tensors: If True, `assertFunctionMatchesEager` converts its
      arguments with `_as_tensors` before calling the function.
    allow_exceptions: If False, `assertFunctionMatchesEager` fails as soon as
      the compiled function raises.
  """

  def setUp(self):
    super().setUp()
    # NOTE(review): presumably makes AutoGraph surface conversion errors
    # instead of falling back to the unconverted function - confirm.
    env_key = 'AUTOGRAPH_STRICT_CONVERSION'
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = '1'
    # Undo the process-wide environment change after each test so it does
    # not leak into unrelated tests running in the same process.
    self.addCleanup(self._restore_env_var, env_key, previous_value)
    self.autograph_opts = None
    self.all_inputs_tensors = False
    self.allow_exceptions = False

  @staticmethod
  def _restore_env_var(key, value):
    """Sets `os.environ[key]` back to `value`, or removes it if None."""
    if value is None:
      os.environ.pop(key, None)
    else:
      os.environ[key] = value

  # TODO(mdan): Consider rewriting as a context manager.
  def _run_with_output_capture(self, func):
    """Executes `func`, capturing what it prints to stdout.

    Args:
      func: Zero-argument callable to execute.

    Returns:
      A `(results, captured_out, captured_err)` tuple.
      - On success: the return value of `func`, the captured stdout text,
        and None.
      - If `func` raises an `Exception`: `(None, None, exception)`. The
        traceback is printed to the restored stdout.
    """
    # Remember the stdout that was active on entry. It may itself be a
    # redirection installed by a test runner, so restoring to
    # `sys.__stdout__` would clobber it.
    saved_stdout = sys.stdout
    out_capturer = io.StringIO()
    results = None
    captured_out = None
    captured_err = None
    try:
      sys.stdout = out_capturer
      results = func()
      captured_out = out_capturer.getvalue()
    except Exception as e:  # pylint:disable=broad-except
      # Restore stdout before printing so the traceback is not captured.
      sys.stdout = saved_stdout
      captured_err = e
      print('*** Capturing exception:\n{}\n'.format(traceback.format_exc()))
    finally:
      sys.stdout = saved_stdout
      out_capturer.close()
    return results, captured_out, captured_err

  def _as_tensors(self, args):
    """Converts args to tensors.

    Numbers, lists and NumPy arrays are converted with `tf.constant`. Dict
    values are converted recursively, keeping the keys. Anything else
    (e.g. tuples, tensors, strings) is passed through unchanged.

    Args:
      args: Iterable of argument values.

    Returns:
      A list with the converted arguments, in order.
    """
    tensor_args = []
    for a in args:
      if isinstance(a, (numbers.Number, list, np.ndarray)):
        tensor_arg = tf.constant(a)
      elif isinstance(a, dict):
        keys = tuple(a.keys())
        tensor_arg = dict(zip(keys, self._as_tensors([a[k] for k in keys])))
      else:
        tensor_arg = a
      tensor_args.append(tensor_arg)
    return tensor_args

  def run_native(self, f, *args):
    """Calls `f(*args)`, returning `(results, captured_out, captured_err)`."""
    return self._run_with_output_capture(lambda: f(*args))

  def _deep_equal(self, left, right):
    """Compares two possibly-nested structures.

    Comparison rules:
    - Dense tensors are converted to NumPy values first.
    - `SparseTensor`s are compared component-wise.
    - NumPy arrays are compared with `np.array_equal`.
    - Lists and tuples are compared element-wise; they must have the same
      length, and a list may match a tuple.
    - Dicts must have the same keys and deep-equal values.
    - Anything else is compared with `==`.
    """
    if isinstance(left, tf.Tensor):
      return self._deep_equal(left.numpy(), right)
    if isinstance(right, tf.Tensor):
      return self._deep_equal(left, right.numpy())
    if isinstance(left, tf.SparseTensor) and isinstance(right, tf.SparseTensor):
      return (self._deep_equal(left.indices, right.indices)
              and self._deep_equal(left.values, right.values)
              and self._deep_equal(left.shape, right.shape))
    if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
      return np.array_equal(left, right)
    if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
      # zip() stops at the shorter input, so compare lengths explicitly;
      # otherwise [1] and [1, 2] would be reported as equal.
      return (len(left) == len(right) and
              all(self._deep_equal(lhs, rhs) for lhs, rhs in zip(left, right)))
    if isinstance(left, dict) and isinstance(right, dict):
      # Recurse so tensor/array values are compared properly rather than
      # through dict.__eq__ (which would call bool() on elementwise results).
      return (left.keys() == right.keys() and
              all(self._deep_equal(left[k], right[k]) for k in left))
    return left == right

  def assertResultsMatch(self,
                         f,
                         args,
                         native_data,
                         compiled_data):
    """Asserts that native_data matches compiled_data.

    Args:
      f: The original function; its `__name__` is used in the failure message.
      args: The call arguments, only used to render the failure message.
      native_data: `(results, out, err)` tuple from the native run.
      compiled_data: `(results, out, err)` tuple from the compiled run.
    """
    native_results, native_out, native_err = native_data
    compiled_results, compiled_out, compiled_err = compiled_data
    str_args = '(%s)' % ', '.join(str(a) for a in args)
    # Using a manual verification to avoid a second compilation on success.
    # For exceptions, we don't enforce that they are the same, only that
    # both paths raised an exception of the exact same type.
    # TODO(mdan): Add an API that returns both object and source code instead.
    outputs_equal = (
        self._deep_equal(native_results, compiled_results) and
        native_out == compiled_out)
    errors_equivalent = type(native_err) is type(compiled_err)
    if not outputs_equal or not errors_equivalent:
      self.fail('Native and compiled functions are not equivalent.\n\n'
                'Native results: %s\n'
                'Compiled results: %s\n'
                'Native out: %s\n'
                'Compiled out: %s\n'
                'Native error: %s: %s\n'
                'Compiled error: %s: %s\n'
                'Native call: %s%s\n'
                'Check the logs for the generated code.'
                '' % (
                    native_results,
                    compiled_results,
                    native_out,
                    compiled_out,
                    type(native_err).__name__,
                    native_err,
                    type(compiled_err).__name__,
                    compiled_err,
                    f.__name__,
                    str_args,
                ))

  def function(self, f, xla=False):
    """Wraps `f` in `tf.function`, optionally XLA-compiled when `xla`."""
    # `jit_compile` replaces the deprecated `experimental_compile` argument.
    return tf.function(
        f,
        experimental_autograph_options=self.autograph_opts,
        jit_compile=xla)

  def convert(self, f):
    """Converts `f` with `tf.autograph.to_graph` using `autograph_opts`."""
    return tf.autograph.to_graph(
        f, experimental_optional_features=self.autograph_opts)

  def assertFunctionMatchesEagerStatefulInput(self, f, args):
    """Like assertFunctionMatchesEager but creates new inputs each time.

    Args:
      f: Function under test.
      args: Zero-argument callable returning a fresh sequence of arguments.
        It is called once for the compiled run, once for the native run, and
        once more to render the arguments in a failure message.
    """
    compiled_data = self.run_native(self.function(f), *args())
    native_data = self.run_native(f, *args())
    self.assertResultsMatch(f, args(), native_data, compiled_data)

  def assertFunctionMatchesEager(self, f, *args, xla=False):
    """Asserts `tf.function(f)(*args)` behaves like `f(*args)`.

    If `all_inputs_tensors` is set, `args` are first converted with
    `_as_tensors`. Unless `allow_exceptions` is set, an exception raised by
    the compiled function fails the test immediately.
    """
    if self.all_inputs_tensors:
      args = self._as_tensors(args)
    compiled_data = self.run_native(self.function(f, xla=xla), *args)
    if not self.allow_exceptions:
      _, _, compiled_err = compiled_data
      if compiled_err is not None:
        self.fail(str(compiled_err))
    native_data = self.run_native(f, *args)
    self.assertResultsMatch(f, args, native_data, compiled_data)

  def assertConvertedMatchesNative(self, f, *args):
    """Asserts the `tf.autograph.to_graph` version of `f` matches `f`."""
    compiled_data = self.run_native(self.convert(f), *args)
    native_data = self.run_native(f, *args)
    self.assertResultsMatch(f, args, native_data, compiled_data)
175
176
177if __name__ == '__main__':
178  tf.test.main()
179