# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tpu_test_wrapper.py."""

import importlib.util  # Python 3 only.
import os
import tempfile

from absl.testing import flagsaver

from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_test_wrapper


class TPUTestWrapperTest(test.TestCase):
  """Unit tests for the helper functions exposed by `tpu_test_wrapper`."""

  @flagsaver.flagsaver()
  def test_flags_undefined(self):
    # With none of the flags pre-defined, maybe_define_flags() must create
    # every flag the wrapper reads.
    tpu_test_wrapper.maybe_define_flags()

    self.assertIn('tpu', flags.FLAGS)
    self.assertIn('zone', flags.FLAGS)
    self.assertIn('project', flags.FLAGS)
    self.assertIn('model_dir', flags.FLAGS)

  @flagsaver.flagsaver()
  def test_flags_already_defined_not_overridden(self):
    # A flag the user already defined must keep its value and must not be
    # redefined by the wrapper.
    flags.DEFINE_string('tpu', 'tpuname', 'helpstring')
    tpu_test_wrapper.maybe_define_flags()

    self.assertIn('tpu', flags.FLAGS)
    self.assertIn('zone', flags.FLAGS)
    self.assertIn('project', flags.FLAGS)
    self.assertIn('model_dir', flags.FLAGS)
    self.assertEqual(flags.FLAGS.tpu, 'tpuname')

  @flagsaver.flagsaver(bazel_repo_root='tensorflow/python')
  def test_parent_path(self):
    # A runfiles path for a test binary maps to the dotted package path of
    # the directory that contains the test.
    filepath = '/filesystem/path/tensorflow/python/tpu/example_test.runfiles/tensorflow/python/tpu/example_test'  # pylint: disable=line-too-long
    self.assertEqual(
        tpu_test_wrapper.calculate_parent_python_path(filepath),
        'tensorflow.python.tpu')

  @flagsaver.flagsaver(bazel_repo_root='tensorflow/python')
  def test_parent_path_raises(self):
    # A path that does not contain the repo root cannot be mapped.
    filepath = '/bad/path'
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'Filepath "/bad/path" does not contain repo root "tensorflow/python"'):
      tpu_test_wrapper.calculate_parent_python_path(filepath)

  def test_is_test_class_positive(self):

    class A(test.TestCase):
      pass

    self.assertTrue(tpu_test_wrapper._is_test_class(A))

  def test_is_test_class_negative(self):

    class A(object):
      pass

    self.assertFalse(tpu_test_wrapper._is_test_class(A))

  @flagsaver.flagsaver(
      wrapped_tpu_test_module_relative='.tpu_test_wrapper_test')
  def test_move_test_classes_into_scope(self):
    # Test the class importer by having the wrapper module import this test
    # into itself.
    with test.mock.patch.object(
        tpu_test_wrapper, 'calculate_parent_python_path') as mock_parent_path:
      mock_parent_path.return_value = (
          tpu_test_wrapper.__name__.rpartition('.')[0])

      module = tpu_test_wrapper.import_user_module()
      tpu_test_wrapper.move_test_classes_into_scope(module)
      # move_test_classes_into_scope() adds the imported class to the wrapper
      # module's namespace. Remove it afterwards so the injected class does
      # not outlive this test; pop() with a default is a no-op if absent.
      self.addCleanup(vars(tpu_test_wrapper).pop,
                      'tpu_test_imported_TPUTestWrapperTest', None)

    self.assertEqual(
        tpu_test_wrapper.tpu_test_imported_TPUTestWrapperTest.__name__,
        self.__class__.__name__)

  @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles')
  def test_set_random_test_dir(self):
    # model_dir is placed under test_dir_base with an extra suffix appended.
    tpu_test_wrapper.maybe_define_flags()
    tpu_test_wrapper.set_random_test_dir()

    self.assertStartsWith(flags.FLAGS.model_dir,
                          'gs://example-bucket/tempfiles')
    self.assertGreater(
        len(flags.FLAGS.model_dir), len('gs://example-bucket/tempfiles'))

  @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles')
  def test_set_random_test_dir_repeatable(self):
    # Each call must pick a fresh directory rather than reuse the last one.
    tpu_test_wrapper.maybe_define_flags()
    tpu_test_wrapper.set_random_test_dir()
    first = flags.FLAGS.model_dir
    tpu_test_wrapper.set_random_test_dir()
    second = flags.FLAGS.model_dir

    self.assertNotEqual(first, second)

  def test_run_user_main(self):
    test_module = _write_and_load_module("""
VARS = 1

if 'unrelated_if' == 'should_be_ignored':
  VARS = 2

if __name__ == '__main__':
  VARS = 3

if 'extra_if_at_bottom' == 'should_be_ignored':
  VARS = 4
""")

    # Importing must not run the __main__ branch; run_user_main() must run
    # only that branch.
    self.assertEqual(test_module.VARS, 1)
    tpu_test_wrapper.run_user_main(test_module)
    self.assertEqual(test_module.VARS, 3)

  def test_run_user_main_missing_if(self):
    test_module = _write_and_load_module("""
VARS = 1
""")

    self.assertEqual(test_module.VARS, 1)
    with self.assertRaises(NotImplementedError):
      tpu_test_wrapper.run_user_main(test_module)

  def test_run_user_main_double_quotes(self):
    test_module = _write_and_load_module("""
VARS = 1

if "unrelated_if" == "should_be_ignored":
  VARS = 2

if __name__ == "__main__":
  VARS = 3

if "extra_if_at_bottom" == "should_be_ignored":
  VARS = 4
""")

    self.assertEqual(test_module.VARS, 1)
    tpu_test_wrapper.run_user_main(test_module)
    self.assertEqual(test_module.VARS, 3)

  def test_run_user_main_test(self):
    test_module = _write_and_load_module("""
from tensorflow.python.platform import test as unique_name

class DummyTest(unique_name.TestCase):
  def test_fail(self):
    self.fail()

if __name__ == '__main__':
  unique_name.main()
""")

    # We're actually limited in what we can test here -- we can't call
    # test.main() without deleting this current test from locals(), or we'll
    # recurse infinitely. We settle for testing that the test imports and calls
    # the right test module.

    with test.mock.patch.object(test, 'main') as mock_main:
      tpu_test_wrapper.run_user_main(test_module)
    mock_main.assert_called_once()


def _write_and_load_module(source):
  """Writes `source` to a new .py file and executes it as module `testmodule`.

  Every call writes to a fresh, uniquely named file. Reusing one fixed path
  is unsafe: CPython validates a cached .pyc only against the source file's
  mtime (whole seconds) and byte size. Two same-length sources written to the
  same path within one second can therefore silently execute the previous
  test's bytecode. `test_run_user_main` and `test_run_user_main_double_quotes`
  write sources of identical length and run back to back.

  Args:
    source: Python source text for the module.

  Returns:
    The freshly executed module object.
  """
  fd, fp = tempfile.mkstemp(
      prefix='testmod_', suffix='.py', dir=test.get_temp_dir())
  # UTF-8 matches Python's default source encoding, independent of locale.
  with os.fdopen(fd, 'w', encoding='utf-8') as f:
    f.write(source)
  spec = importlib.util.spec_from_file_location('testmodule', fp)
  test_module = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(test_module)
  return test_module


if __name__ == '__main__':
  test.main()