# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Wrapper for Python TPU tests.

The py_tpu_test macro will actually use this file as its main, building and
executing the user-provided test file as a py_binary instead. This lets us do
important work behind the scenes, without complicating the tests themselves.

The main responsibilities of this file are:
  - Define a standard set of model flags if the test did not. This allows us
    to safely set flags at the Bazel invocation level using --test_arg.
  - Pick a random directory under --test_dir_base for each test case, and set
    it as the default value of --model_dir. This is similar to how Bazel
    provides each test with a fresh local directory in $TEST_TMPDIR.
"""

import ast
import importlib
import os
import sys
import uuid

from tensorflow.python.platform import flags
from tensorflow.python.util import tf_inspect

FLAGS = flags.FLAGS
flags.DEFINE_string(
    'wrapped_tpu_test_module_relative', None,
    'The Python-style relative path to the user-given test. If test is in same '
    'directory as BUILD file as is common, then "test.py" would be ".test".')
flags.DEFINE_string('test_dir_base',
                    os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'GCS path to root directory for temporary test files.')
flags.DEFINE_string(
    'bazel_repo_root', 'tensorflow/python',
    'Substring of a bazel filepath beginning the python absolute import path.')

# List of flags which all TPU tests should accept.
REQUIRED_FLAGS = ['tpu', 'zone', 'project', 'model_dir']


def maybe_define_flags():
  """Defines any of REQUIRED_FLAGS that the user module did not define.

  Each missing flag is defined as a string flag defaulting to None. Flags the
  user module already defined are left untouched (the duplicate-definition
  error is the signal that it exists).
  """
  for f in REQUIRED_FLAGS:
    try:
      flags.DEFINE_string(f, None, 'flag defined by test lib')
    except flags.DuplicateFlagError:
      pass


def set_random_test_dir():
  """Picks a random directory under --test_dir_base and sets it as --model_dir.

  Only the *default* of --model_dir is changed (via `FLAGS.set_default`), so a
  fresh uuid4-named subdirectory is used for every run.

  Raises:
    ValueError: if --test_dir_base is unset. Its default is read from the
      TEST_UNDECLARED_OUTPUTS_DIR environment variable, so this happens when
      that variable is absent and the flag was not passed explicitly.
  """
  if FLAGS.test_dir_base is None:
    # Without this check os.path.join(None, ...) fails with an opaque
    # TypeError that does not mention which flag is missing.
    raise ValueError(
        '--test_dir_base is not set and TEST_UNDECLARED_OUTPUTS_DIR is not '
        'defined in the environment; cannot choose a default --model_dir.')
  path = os.path.join(FLAGS.test_dir_base, uuid.uuid4().hex)
  FLAGS.set_default('model_dir', path)


def calculate_parent_python_path(test_filepath):
  """Returns the absolute import path for the containing directory.

  Args:
    test_filepath: The filepath which Bazel invoked
      (ex: /filesystem/path/tensorflow/tensorflow/python/tpu/tpu_test)

  Returns:
    Absolute import path of parent (ex: tensorflow.python.tpu).

  Raises:
    ValueError: if bazel_repo_root does not appear within test_filepath.
  """
  # Keep everything from the LAST occurrence of bazel_repo_root onwards; the
  # prefix before it is filesystem/runfiles noise, not part of the import path.
  split_path = test_filepath.rsplit(FLAGS.bazel_repo_root, 1)
  if len(split_path) < 2:
    raise ValueError(
        f'Filepath "{test_filepath}" does not contain repo root "{FLAGS.bazel_repo_root}"'
    )

  path = FLAGS.bazel_repo_root + split_path[1]

  # Drop the last path component, which is the name of the test wrapper.
  path = path.rsplit('/', 1)[0]

  # Convert directory separators into dots to get a Python package path.
  return path.replace('/', '.')


def import_user_module():
  """Imports the flag-specified user test code.

  The module named by --wrapped_tpu_test_module_relative is imported relative
  to the package computed from sys.argv[0]. Importing runs all top-level
  statements in the user module, specifically its flag definitions.

  Returns:
    The user test module.
  """
  return importlib.import_module(FLAGS.wrapped_tpu_test_module_relative,
                                 calculate_parent_python_path(sys.argv[0]))


def _is_test_class(obj):
  """Check if arbitrary object is a test class (not a test object!).

  Args:
    obj: An arbitrary object from within a module.

  Returns:
    True iff obj is a class with some class named "TestCase" in its MRO.
    Matching by name rather than identity is deliberate, because tests are
    written using different underlying test libraries.
  """
  return (tf_inspect.isclass(obj)
          and 'TestCase' in (p.__name__ for p in tf_inspect.getmro(obj)))


# At module scope vars() is this module's globals dict; writing into it adds
# module-level names.
module_variables = vars()


def move_test_classes_into_scope(wrapped_test_module):
  """Add all test classes defined in wrapped module to our module.

  The test runner works by inspecting the main module for TestCase classes, so
  by adding a module-level reference to the TestCase we cause it to execute the
  wrapped TestCase.

  Args:
    wrapped_test_module: The user-provided test code to run.
  """
  for name, obj in wrapped_test_module.__dict__.items():
    if _is_test_class(obj):
      # Prefix avoids clobbering names already defined in this wrapper.
      module_variables['tpu_test_imported_%s' % name] = obj


def _find_main_block(tree):
  """Returns the last top-level `if __name__ == "__main__":` node, or None.

  Args:
    tree: An `ast.Module` parsed from the user module's source.
  """
  # String form of just the condition `__name__ == "__main__"`; ast.dump does
  # not record quote style, so '__main__' and "__main__" both match.
  target = ast.dump(ast.parse('if __name__ == "__main__": pass').body[0].test)

  # `tree.body` is the list of top-level statements (imports, classes, ...).
  # Search from the end, since the main block is conventionally last.
  for node in reversed(tree.body):
    if isinstance(node, ast.If) and ast.dump(node.test) == target:
      return node
  return None


def run_user_main(wrapped_test_module):
  """Runs the "if __name__ == '__main__'" at the bottom of a module.

  TensorFlow practice is to have a main if at the bottom of the module which
  might call an API compat function before calling test.main().

  Since this is a statement, not a function, we can't cleanly reference it, but
  we can inspect it from the user module and run it in the context of that
  module so all imports and variables are available to it.

  Args:
    wrapped_test_module: The user-provided test code to run.

  Raises:
    NotImplementedError: If main block was not found in module. This should not
      be caught, as it is likely an error on the user's part -- absltest is all
      too happy to report a successful status (and zero tests executed) if a
      user forgets to end a module with "test.main()".
  """
  tree = ast.parse(tf_inspect.getsource(wrapped_test_module))
  main_block = _find_main_block(tree)
  if main_block is None:
    raise NotImplementedError(
        'Could not find `if __name__ == "__main__":` block in '
        f'{wrapped_test_module.__name__}.')

  new_ast = ast.Module(body=main_block.body, type_ignores=[])
  # Globals are this wrapper's (which now hold the imported test classes);
  # locals are the user module's namespace so its imports/variables resolve.
  exec(  # pylint:disable=exec-used
      compile(new_ast, '<ast>', 'exec'),
      globals(),
      wrapped_test_module.__dict__,
  )


if __name__ == '__main__':
  # Partially parse flags, since module to import is specified by flag.
  unparsed = FLAGS(sys.argv, known_only=True)
  user_module = import_user_module()
  maybe_define_flags()
  # Parse remaining flags (now that the user module's flags are defined).
  FLAGS(unparsed)
  set_random_test_dir()

  move_test_classes_into_scope(user_module)
  run_user_main(user_module)