xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_test_wrapper_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for tpu_test_wrapper.py."""
16
import importlib.util  # Python 3 only.
import os
import tempfile

from absl.testing import flagsaver

from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_test_wrapper
25
26
class TPUTestWrapperTest(test.TestCase):
  """Unit tests for the helpers exposed by `tpu_test_wrapper`."""

  @flagsaver.flagsaver()
  def test_flags_undefined(self):
    """`maybe_define_flags` makes the tpu/zone/project/model_dir flags exist."""
    tpu_test_wrapper.maybe_define_flags()

    for flag_name in ('tpu', 'zone', 'project', 'model_dir'):
      self.assertIn(flag_name, flags.FLAGS)

  @flagsaver.flagsaver()
  def test_flags_already_defined_not_overridden(self):
    """A flag defined before `maybe_define_flags` keeps its own default."""
    flags.DEFINE_string('tpu', 'tpuname', 'helpstring')
    tpu_test_wrapper.maybe_define_flags()

    for flag_name in ('tpu', 'zone', 'project', 'model_dir'):
      self.assertIn(flag_name, flags.FLAGS)
    self.assertEqual(flags.FLAGS.tpu, 'tpuname')

  @flagsaver.flagsaver(bazel_repo_root='tensorflow/python')
  def test_parent_path(self):
    """A runfiles filepath resolves to the dotted name of its parent package."""
    runfiles_path = ('/filesystem/path/tensorflow/python/tpu/'
                     'example_test.runfiles/tensorflow/python/tpu/example_test')
    parent_package = tpu_test_wrapper.calculate_parent_python_path(
        runfiles_path)
    self.assertEqual(parent_package, 'tensorflow.python.tpu')

  @flagsaver.flagsaver(bazel_repo_root='tensorflow/python')
  def test_parent_path_raises(self):
    """A filepath that lacks the repo root raises a descriptive ValueError."""
    expected_message = (
        'Filepath "/bad/path" does not contain repo root "tensorflow/python"')
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
      tpu_test_wrapper.calculate_parent_python_path('/bad/path')

  def test_is_test_class_positive(self):
    """A `test.TestCase` subclass is recognized as a test class."""

    class A(test.TestCase):
      pass

    self.assertTrue(tpu_test_wrapper._is_test_class(A))

  def test_is_test_class_negative(self):
    """A plain class that does not derive from TestCase is rejected."""

    class A:
      pass

    self.assertFalse(tpu_test_wrapper._is_test_class(A))

  @flagsaver.flagsaver(
      wrapped_tpu_test_module_relative='.tpu_test_wrapper_test')
  def test_move_test_classes_into_scope(self):
    """The wrapper re-exports an imported module's test classes.

    The wrapper is pointed at this very file, so afterwards it should expose
    this class under the `tpu_test_imported_` prefixed name.
    """
    wrapper_package = tpu_test_wrapper.__name__.rpartition('.')[0]
    with test.mock.patch.object(
        tpu_test_wrapper, 'calculate_parent_python_path',
        return_value=wrapper_package):
      user_module = tpu_test_wrapper.import_user_module()
      tpu_test_wrapper.move_test_classes_into_scope(user_module)

    imported_class = tpu_test_wrapper.tpu_test_imported_TPUTestWrapperTest
    self.assertEqual(imported_class.__name__, type(self).__name__)

  @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles')
  def test_set_random_test_dir(self):
    """`model_dir` becomes a strictly longer path under `test_dir_base`."""
    base_dir = 'gs://example-bucket/tempfiles'
    tpu_test_wrapper.maybe_define_flags()
    tpu_test_wrapper.set_random_test_dir()

    model_dir = flags.FLAGS.model_dir
    self.assertStartsWith(model_dir, base_dir)
    self.assertGreater(len(model_dir), len(base_dir))

  @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles')
  def test_set_random_test_dir_repeatable(self):
    """Calling `set_random_test_dir` twice yields two different `model_dir`s.

    "Repeatable" here means the helper may be called again; each call is
    expected to pick a fresh directory, not to reproduce the previous one.
    """
    tpu_test_wrapper.maybe_define_flags()
    chosen_dirs = []
    for _ in range(2):
      tpu_test_wrapper.set_random_test_dir()
      chosen_dirs.append(flags.FLAGS.model_dir)

    self.assertNotEqual(chosen_dirs[0], chosen_dirs[1])

  def test_run_user_main(self):
    """`run_user_main` executes the module's single-quoted `__main__` block."""
    user_module = _write_and_load_module("""
VARS = 1

if 'unrelated_if' == 'should_be_ignored':
  VARS = 2

if __name__ == '__main__':
  VARS = 3

if 'extra_if_at_bottom' == 'should_be_ignored':
  VARS = 4
""")

    self.assertEqual(user_module.VARS, 1)
    tpu_test_wrapper.run_user_main(user_module)
    self.assertEqual(user_module.VARS, 3)

  def test_run_user_main_missing_if(self):
    """A module with no `__main__` block makes `run_user_main` raise."""
    user_module = _write_and_load_module("""
VARS = 1
""")

    self.assertEqual(user_module.VARS, 1)
    with self.assertRaises(NotImplementedError):
      tpu_test_wrapper.run_user_main(user_module)

  def test_run_user_main_double_quotes(self):
    """A `__main__` block written with double quotes is also found and run."""
    user_module = _write_and_load_module("""
VARS = 1

if "unrelated_if" == "should_be_ignored":
  VARS = 2

if __name__ == "__main__":
  VARS = 3

if "extra_if_at_bottom" == "should_be_ignored":
  VARS = 4
""")

    self.assertEqual(user_module.VARS, 1)
    tpu_test_wrapper.run_user_main(user_module)
    self.assertEqual(user_module.VARS, 3)

  def test_run_user_main_test(self):
    """A user `__main__` block that calls `test.main` reaches the real module."""
    user_module = _write_and_load_module("""
from tensorflow.python.platform import test as unique_name

class DummyTest(unique_name.TestCase):
  def test_fail(self):
    self.fail()

if __name__ == '__main__':
  unique_name.main()
""")

    # Letting the real test.main() run here would recurse infinitely unless
    # this very test were first removed from locals(). So patch it out and
    # settle for checking that the user module's main block imports and calls
    # the right test module exactly once.
    with test.mock.patch.object(test, 'main') as patched_main:
      tpu_test_wrapper.run_user_main(user_module)
    patched_main.assert_called_once()
180
181
def _write_and_load_module(source):
  """Writes `source` to a new .py file and executes it as module 'testmodule'.

  Every call writes to its own uniquely named file under the test temp dir.
  With one fixed path, two fixtures of equal length written in quick
  succession cannot be told apart by importlib's timestamp-based
  `__pycache__` check, which compares source size and whole-second mtime.
  (The single- and double-quoted `run_user_main` sources are exactly the same
  size.) Presumably the same applies to `linecache`, if `run_user_main` reads
  the source back through `inspect`. Either way, a later test could silently
  execute or parse an earlier test's code.

  Args:
    source: Python source text for the module.

  Returns:
    The executed module object. It is not added to `sys.modules`.
  """
  # A fresh file per call keeps bytecode/source caches from serving stale text.
  fd, path = tempfile.mkstemp(
      suffix='.py', prefix='testmod_', dir=test.get_temp_dir())
  # UTF-8 matches Python's default source encoding, independent of the locale.
  with os.fdopen(fd, 'w', encoding='utf-8') as f:
    f.write(source)
  spec = importlib.util.spec_from_file_location('testmodule', path)
  module = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(module)
  return module
190
191
if __name__ == '__main__':
  # Only when executed as a script: hand control to TensorFlow's test runner.
  test.main()
194