xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_test_wrapper.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Wrapper for Python TPU tests.
16
17The py_tpu_test macro will actually use this file as its main, building and
18executing the user-provided test file as a py_binary instead. This lets us do
19important work behind the scenes, without complicating the tests themselves.
20
The main responsibilities of this file are:
  - Define a standard set of model flags if the test did not. This allows us to
    safely set flags at the Bazel invocation level using --test_arg.
  - Pick a random directory under --test_dir_base (a GCS path, or by default
    $TEST_UNDECLARED_OUTPUTS_DIR) for each test run, and set it as the default
    value of --model_dir. This is similar to how Bazel provides each test with
    a fresh local directory in $TEST_TMPDIR.
27"""
28
29import ast
30import importlib
31import os
32import sys
33import uuid
34
35from tensorflow.python.platform import flags
36from tensorflow.python.util import tf_inspect
37
FLAGS = flags.FLAGS
# Relative module name of the user's test; it is resolved against the package
# of the running binary (see calculate_parent_python_path / import_user_module).
flags.DEFINE_string(
    'wrapped_tpu_test_module_relative', None,
    'The Python-style relative path to the user-given test. If test is in same '
    'directory as BUILD file as is common, then "test.py" would be ".test".')
# Root under which set_random_test_dir creates a per-run --model_dir. The
# default is read from the environment once, when this module is imported, and
# is None if TEST_UNDECLARED_OUTPUTS_DIR is not defined.
flags.DEFINE_string('test_dir_base',
                    os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'GCS path to root directory for temporary test files.')
# Path fragment marking where the Python absolute import path begins inside the
# invoked binary's filepath; consumed by calculate_parent_python_path.
flags.DEFINE_string(
    'bazel_repo_root', 'tensorflow/python',
    'Substring of a bazel filepath beginning the python absolute import path.')

# List of flags which all TPU tests should accept.
# maybe_define_flags() defines, as string flags, any of these that the user's
# test module did not define itself.
REQUIRED_FLAGS = ['tpu', 'zone', 'project', 'model_dir']
52
53
def maybe_define_flags():
  """Registers every flag in REQUIRED_FLAGS that does not exist yet.

  Each missing name is defined as a string flag with no default value. If a
  flag with that name is already registered (for example by the user's test
  module, which is imported before this runs), defining it again raises
  DuplicateFlagError; that error is ignored so the existing definition wins.
  """
  help_text = 'flag defined by test lib'
  for flag_name in REQUIRED_FLAGS:
    try:
      flags.DEFINE_string(flag_name, None, help_text)
    except flags.DuplicateFlagError:
      # Already defined elsewhere -- keep that definition untouched.
      continue
61
62
def set_random_test_dir():
  """Sets the default of --model_dir to a fresh directory under --test_dir_base.

  The directory name is a random UUID4 hex string, so each call yields a
  distinct path. Only the flag's default is changed (FLAGS.set_default), so a
  --model_dir passed explicitly on the command line is expected to take
  precedence.

  Raises:
    ValueError: if --test_dir_base is unset, which happens when the flag is
      not passed and $TEST_UNDECLARED_OUTPUTS_DIR is not defined in the
      environment. Previously this surfaced as an opaque TypeError from
      os.path.join(None, ...).
  """
  base_dir = FLAGS.test_dir_base
  if base_dir is None:
    raise ValueError(
        '--test_dir_base is not set and $TEST_UNDECLARED_OUTPUTS_DIR is not '
        'defined; pass --test_dir_base=<dir> to choose where test files go.')
  path = os.path.join(base_dir, uuid.uuid4().hex)
  FLAGS.set_default('model_dir', path)
67
68
def calculate_parent_python_path(test_filepath, repo_root=None):
  """Returns the absolute import path for the containing directory.

  Args:
    test_filepath: The filepath which Bazel invoked
      (ex: /filesystem/path/tensorflow/tensorflow/python/tpu/tpu_test).
    repo_root: Substring of `test_filepath` at which the Python absolute import
      path begins (ex: "tensorflow/python"). Defaults to the value of the
      --bazel_repo_root flag, so existing callers behave exactly as before.

  Returns:
    Absolute import path of parent (ex: tensorflow.python.tpu).

  Raises:
    ValueError: if repo_root does not appear within test_filepath.
  """
  if repo_root is None:
    repo_root = FLAGS.bazel_repo_root

  # Use the last occurrence of repo_root and drop everything before it; an
  # earlier match could belong to the on-disk checkout location rather than
  # to the package path.
  split_path = test_filepath.rsplit(repo_root, 1)
  if len(split_path) < 2:
    raise ValueError(
        f'Filepath "{test_filepath}" does not contain repo root "{repo_root}"'
    )

  path = repo_root + split_path[1]

  # Drop the last path component, which is the test binary's own name.
  path = path.rsplit('/', 1)[0]

  # Convert directory separators into package separators.
  return path.replace('/', '.')
96
97
def import_user_module():
  """Imports and returns the test module named by the wrapper's flags.

  --wrapped_tpu_test_module_relative holds a relative module name (ex:
  ".test"). It is resolved against the package that contains the running
  binary, as computed from sys.argv[0] by calculate_parent_python_path.
  Importing executes every top-level statement of the user module, in
  particular its flag definitions.

  Returns:
    The imported user test module.
  """
  # Read the flag before computing the anchor package, matching the original
  # argument-evaluation order.
  relative_name = FLAGS.wrapped_tpu_test_module_relative
  anchor_package = calculate_parent_python_path(sys.argv[0])
  return importlib.import_module(relative_name, package=anchor_package)
109
110
def _is_test_class(obj):
  """Reports whether `obj` is a test *class* (not a test instance).

  Args:
    obj: Any value found in a module's namespace.

  Returns:
    True iff `obj` is a class and some class in its MRO (including `obj`
    itself) is named "TestCase". Matching on the name, rather than on one
    specific base class, lets tests written against different underlying
    test libraries all be detected.
  """
  if not tf_inspect.isclass(obj):
    return False
  ancestor_names = (klass.__name__ for klass in tf_inspect.getmro(obj))
  return 'TestCase' in ancestor_names
124
125
# This module's global namespace (at module scope `vars()` and `globals()` are
# the same dict). Writing a key here creates a module-level attribute, which is
# how move_test_classes_into_scope exposes the user's test classes.
module_variables = globals()
127
128
def move_test_classes_into_scope(wrapped_test_module):
  """Copies every test class of the wrapped module into this module.

  The test runner works by inspecting the main module for TestCase classes.
  Each test class found in `wrapped_test_module` is therefore re-bound here
  under the name `tpu_test_imported_<name>`, so the runner executes it.

  Args:
    wrapped_test_module: The user-provided test code to run.
  """
  # Lazily filtered, so each class is checked and then registered in turn.
  test_classes = (
      (attr_name, attr_value)
      for attr_name, attr_value in wrapped_test_module.__dict__.items()
      if _is_test_class(attr_value))
  for attr_name, test_class in test_classes:
    module_variables['tpu_test_imported_%s' % attr_name] = test_class
142
143
def _find_main_block(tree):
  """Returns the last top-level `if __name__ == "__main__":` node, or None.

  Args:
    tree: An `ast.Module`, as produced by `ast.parse`.

  Returns:
    The matching `ast.If` node, or None if `tree` has no such statement.
  """
  # AST dump of just the condition `__name__ == "__main__"`. ast.dump omits
  # position attributes by default and does not record quote style, so equal
  # conditions compare equal wherever and however they were written.
  target = ast.dump(ast.parse('if __name__ == "__main__": pass').body[0].test)

  # `tree.body` is the list of top-level statements (imports, class
  # definitions, ...). Search from the end, where the main block usually is.
  for statement in reversed(tree.body):
    if isinstance(statement, ast.If) and ast.dump(statement.test) == target:
      return statement
  return None


def run_user_main(wrapped_test_module):
  """Runs the "if __name__ == '__main__'" at the bottom of a module.

  TensorFlow practice is to have a main if at the bottom of the module which
  might call an API compat function before calling test.main().

  Since this is a statement, not a function, we can't cleanly reference it, but
  we can inspect it from the user module and run it in the context of that
  module so all imports and variables are available to it.

  Args:
    wrapped_test_module: The user-provided test code to run.

  Raises:
    NotImplementedError: If main block was not found in module. This should not
      be caught, as it is likely an error on the user's part -- absltest is all
      too happy to report a successful status (and zero tests executed) if a
      user forgets to end a module with "test.main()".
  """
  tree = ast.parse(tf_inspect.getsource(wrapped_test_module))

  main_block = _find_main_block(tree)
  if main_block is None:
    raise NotImplementedError(
        'Could not find `if __name__ == "__main__":` block in '
        f'{wrapped_test_module.__name__}.')

  # Compile only the body of the main block. Its nodes keep the line numbers
  # they had in the full module source, so compiling under the module's real
  # filename (when it has one) makes tracebacks point at the right place.
  main_body = ast.Module(body=main_block.body, type_ignores=[])
  filename = getattr(wrapped_test_module, '__file__', None) or '<ast>'
  exec(  # pylint:disable=exec-used
      compile(main_body, filename, 'exec'),
      # Globals: this wrapper module's namespace. Locals: the user module's
      # namespace, so its imports and variables resolve first.
      globals(),
      wrapped_test_module.__dict__,
  )
185
186
if __name__ == '__main__':
  # Partially parse flags, since module to import is specified by flag.
  # With known_only=True, arguments for flags that are not defined yet are
  # returned in `unparsed` and handled by the second parse below.
  unparsed = FLAGS(sys.argv, known_only=True)
  # Importing the user module runs its top-level code, including any flag
  # definitions it contains.
  user_module = import_user_module()
  # Must run after the import: flags the user module already defined are kept,
  # and only the missing REQUIRED_FLAGS are added.
  maybe_define_flags()
  # Parse remaining flags.
  FLAGS(unparsed)
  # Reads the now-parsed --test_dir_base and sets the --model_dir default.
  set_random_test_dir()

  # Expose the user's test classes in this (__main__) module, then execute the
  # user's own main block, which is expected to invoke the test runner.
  move_test_classes_into_scope(user_module)
  run_user_main(user_module)
198