xref: /aosp_15_r20/external/tensorflow/tensorflow/python/platform/googletest.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Imports absltest as a replacement for testing.pybase.googletest."""
17import atexit
18import os
19import sys
20import tempfile
21
22# go/tf-wildcard-import
23# pylint: disable=wildcard-import,redefined-builtin
24from absl import app
25from absl.testing.absltest import *
26# pylint: enable=wildcard-import,redefined-builtin
27
28from tensorflow.python.framework import errors
29from tensorflow.python.lib.io import file_io
30from tensorflow.python.platform import benchmark
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.util import tf_decorator
33from tensorflow.python.util import tf_inspect
34from tensorflow.python.util.tf_export import tf_export
35
36
37Benchmark = benchmark.TensorFlowBenchmark  # pylint: disable=invalid-name
38
39absltest_main = main
40
41# We keep a global variable in this module to make sure we create the temporary
42# directory only once per test binary invocation.
43_googletest_temp_dir = ''
44
45
46# pylint: disable=invalid-name
47# pylint: disable=undefined-variable
def g_main(argv):
  """Runs the absltest entry point with the given argument list.

  Args:
    argv: Argument list, forwarded unchanged as the `argv` keyword argument of
      `absltest_main` (the `main` wildcard-imported from absltest and saved
      under that name before this module redefines `main`).
  """
  absltest_main(argv=argv)
52
53
54# Redefine main to allow running benchmarks
def main(argv=None):  # pylint: disable=function-redefined
  """Test entry point that routes through the benchmark runner.

  The regular test run is wrapped in a closure and handed to
  `benchmark.benchmarks_main` as `true_main`; that function decides whether
  and when to invoke it.

  Args:
    argv: Optional argument list. It is passed to `benchmarks_main` as given
      (possibly None); the wrapped test run substitutes `sys.argv` when it is
      None.
  """
  def main_wrapper():
    # `sys.argv` is read when the wrapper runs, not when `main` is called.
    effective_argv = sys.argv if argv is None else argv
    return app.run(main=g_main, argv=effective_argv)

  benchmark.benchmarks_main(true_main=main_wrapper, argv=argv)
63
64
def GetTempDir():
  """Return a temporary directory for tests to use.

  The directory is created on the first call and remembered in the
  module-level `_googletest_temp_dir`, so later calls in the same process
  return the same path. A cleanup callback that deletes the directory
  recursively is registered with `atexit` when the directory is created.

  Returns:
    The path of the temporary directory, as a string.
  """
  global _googletest_temp_dir
  if _googletest_temp_dir:
    return _googletest_temp_dir

  test_tmpdir = os.environ.get('TEST_TMPDIR')
  if test_tmpdir:
    # NOTE: TEST_TMPDIR is used as the *prefix* of the new directory's path,
    # so the directory name is TEST_TMPDIR followed by a random suffix.
    temp_dir = tempfile.mkdtemp(prefix=test_tmpdir)
  else:
    # Name the directory after the file of the outermost stack frame.
    first_frame = tf_inspect.stack()[-1][0]
    script_name = os.path.basename(tf_inspect.getfile(first_frame))
    # Remove only a literal '.py' extension. str.rstrip('.py') would instead
    # strip every trailing '.', 'p' or 'y' character ('happy.py' -> 'ha').
    if script_name.endswith('.py'):
      script_name = script_name[:-len('.py')]
    temp_dir = tempfile.mkdtemp(
        prefix=os.path.join(tempfile.gettempdir(), script_name))

  # Make sure we have the correct path separators.
  temp_dir = temp_dir.replace('/', os.sep)

  def delete_temp_dir(dirname=temp_dir):
    """Best-effort removal at exit; failures are logged, not raised."""
    try:
      file_io.delete_recursively(dirname)
    except errors.OpError as e:
      logging.error('Error removing %s: %s', dirname, e)

  atexit.register(delete_temp_dir)

  _googletest_temp_dir = temp_dir
  return _googletest_temp_dir
91
92
93def test_src_dir_path(relative_path):
94  """Creates an absolute test srcdir path given a relative path.
95
96  Args:
97    relative_path: a path relative to tensorflow root.
98      e.g. "contrib/session_bundle/example".
99
100  Returns:
101    An absolute path to the linked in runfiles.
102  """
103  return os.path.join(os.environ['TEST_SRCDIR'],
104                      'org_tensorflow/tensorflow', relative_path)
105
106
def StatefulSessionAvailable():
  """Reports whether stateful sessions are available; always False here."""
  return False
109
110
@tf_export(v1=['test.StubOutForTesting'])
class StubOutForTesting(object):
  """Support class for stubbing methods out for unit testing.

  Sample Usage:

  You want os.path.exists() to always return true during testing.

     stubs = StubOutForTesting()
     stubs.Set(os.path, 'exists', lambda x: 1)
       ...
     stubs.CleanUp()

  The above changes os.path.exists into a lambda that returns 1.  Once
  the ... part of the code finishes, the CleanUp() looks up the old
  value of os.path.exists and restores it.

  Instances also work as context managers; CleanUp() is called when the
  `with` block exits.
  """

  def __init__(self):
    # (parent, old_child, child_name) tuples recorded by Set().
    self.cache = []
    # (obj, attr_name, orig_attr) tuples recorded by SmartSet().
    self.stubs = []

  def __del__(self):
    """Do not rely on the destructor to undo your stubs.

    You cannot guarantee exactly when the destructor will get called without
    relying on implementation details of a Python VM that may change.
    """
    self.CleanUp()

  # __enter__ and __exit__ allow use as a context manager.
  def __enter__(self):
    return self

  def __exit__(self, unused_exc_type, unused_exc_value, unused_tb):
    self.CleanUp()

  def CleanUp(self):
    """Undoes all SmartSet() & Set() calls, restoring original definitions."""
    self.SmartUnsetAll()
    self.UnsetAll()

  @staticmethod
  def _RawDefinition(target, attr_name):
    """Returns the `__dict__` entry defining `attr_name`, without binding it.

    getattr() applies the descriptor protocol, which turns a staticmethod into
    a plain function and a classmethod into a method bound to one class.
    Writing such a value back with setattr() would not reinstate the original
    definition, so callers use this lookup to recover the raw descriptor.

    Args:
      target: The object being stubbed. For a class, its MRO is searched so
        inherited definitions are found; for anything else only
        `target.__dict__` is consulted.
      attr_name: The attribute name to look up.

    Returns:
      The raw `__dict__` value, or None if none of the searched dictionaries
      contains `attr_name`.
    """
    if tf_inspect.isclass(target):
      for klass in tf_inspect.getmro(target):
        if attr_name in klass.__dict__:
          return klass.__dict__[attr_name]
      return None
    return target.__dict__.get(attr_name)

  def SmartSet(self, obj, attr_name, new_attr):
    """Replace obj.attr_name with new_attr.

    This method works at the module, class, and instance level. It will not
    stub out C types however unless that has been explicitly allowed by the
    type.

    This method supports the case where attr_name is a staticmethod or a
    classmethod of obj, including when it is inherited from a base class.

    Notes:
      - If obj is an instance that does not hold attr_name in its own
        __dict__, then it is its class that will actually be stubbed. Note
        that the method Set() does not do that: if obj is an instance, it
        (and not its class) will be stubbed.
      - When a class is stubbed, the new attribute is set on that class itself
        (not on the base class that defines it), and the original value is
        written back to that same class on cleanup.
      - The stubbing is using the builtin getattr and setattr. So, the __get__
        and __set__ will be called when stubbing (TODO: A better idea would
        probably be to manipulate obj.__dict__ instead of getattr() and
        setattr()).

    Args:
      obj: The object whose attributes we want to modify.
      attr_name: The name of the attribute to modify.
      new_attr: The new value for the attribute.

    Raises:
      AttributeError: If the attribute cannot be found.
    """
    _, obj = tf_decorator.unwrap(obj)
    if (tf_inspect.ismodule(obj) or
        (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)):
      orig_obj = obj
      orig_attr = getattr(obj, attr_name)
    else:
      # The stub is installed on the class itself, or on the instance's class.
      orig_obj = obj if tf_inspect.isclass(obj) else obj.__class__

      # Read the original through the class being stubbed. Reading it through
      # an instance would yield a bound method (or a property's value), which
      # must not be written back onto the class during cleanup.
      missing = object()
      orig_attr = getattr(orig_obj, attr_name, missing)
      if orig_attr is missing:
        # Fall back to the instance for attributes that are only reachable
        # through it (e.g. supplied by __getattr__).
        orig_attr = getattr(obj, attr_name, missing)
      if orig_attr is missing:
        raise AttributeError('Attribute not found.')

    # getattr() on a staticmethod/classmethod returns a plain function or a
    # bound method. Restore the raw descriptor instead so that the attribute
    # behaves as it did before stubbing, including for subclasses.
    raw_attr = self._RawDefinition(orig_obj, attr_name)
    if isinstance(raw_attr, (staticmethod, classmethod)):
      orig_attr = raw_attr

    self.stubs.append((orig_obj, attr_name, orig_attr))
    setattr(orig_obj, attr_name, new_attr)

  def SmartUnsetAll(self):
    """Reverses SmartSet() calls, restoring things to original definitions.

    This method is automatically called when the StubOutForTesting()
    object is deleted; there is no need to call it explicitly.

    It is okay to call SmartUnsetAll() repeatedly, as later calls have
    no effect if no SmartSet() calls have been made.
    """
    # Undo in reverse order so that, when the same attribute was stubbed more
    # than once, the value saved by the first SmartSet() is written last.
    for orig_obj, attr_name, orig_attr in reversed(self.stubs):
      setattr(orig_obj, attr_name, orig_attr)

    self.stubs = []

  def Set(self, parent, child_name, new_child):
    """In parent, replace child_name's old definition with new_child.

    The parent could be a module when the child is a function at
    module scope.  Or the parent could be a class when a class' method
    is being replaced.  The named child is set to new_child, while the
    prior definition is saved away for later, when UnsetAll() is
    called.

    This method supports the case where child_name is a staticmethod or a
    classmethod of parent, including when it is inherited from a base class.

    Args:
      parent: The context in which the attribute child_name is to be changed.
      child_name: The name of the attribute to change.
      new_child: The new value of the attribute.

    Raises:
      AttributeError: If parent has no attribute named child_name.
    """
    old_child = getattr(parent, child_name)

    # getattr() unwraps staticmethod/classmethod; save the raw descriptor so
    # UnsetAll() reinstates the original definition rather than a plain
    # function or a method bound to `parent`.
    raw_child = self._RawDefinition(parent, child_name)
    if isinstance(raw_child, (staticmethod, classmethod)):
      old_child = raw_child

    self.cache.append((parent, old_child, child_name))
    setattr(parent, child_name, new_child)

  def UnsetAll(self):
    """Reverses Set() calls, restoring things to their original definitions.

    This method is automatically called when the StubOutForTesting()
    object is deleted; there is no need to call it explicitly.

    It is okay to call UnsetAll() repeatedly, as later calls have no
    effect if no Set() calls have been made.
    """
    # Undo calls to Set() in reverse order, in case Set() was called on the
    # same arguments repeatedly (want the original call to be last one undone)
    for (parent, old_child, child_name) in reversed(self.cache):
      setattr(parent, child_name, old_child)
    self.cache = []
270