# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for multi-process testing."""

import multiprocessing
import os
import platform
import sys
import unittest
from absl import app
from absl import logging

from tensorflow.python.eager import test


def is_oss():
  """Returns whether the test is run under OSS.

  This is a heuristic: the run is treated as OSS when the path of the running
  program (`sys.argv[0]`) contains the string 'bazel'.
  """
  return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]


def _is_enabled():
  """Returns whether multi-process testing is enabled in this environment.

  It is disabled when any of the following holds:
    * the run is OSS (see `is_oss`) and a `--tpu*` argument was passed,
    * the interpreter is Python 3.8 on Linux (TODO(b/171242147)),
    * the platform is Windows.
  """
  # Note that flags may not be parsed at this point and simply importing the
  # flags module causes a variety of unusual errors, so inspect raw argv.
  tpu_args = [arg for arg in sys.argv if arg.startswith('--tpu')]
  if is_oss() and tpu_args:
    return False
  # `sys.version_info` has five fields (major, minor, micro, releaselevel,
  # serial); slice it before comparing with (major, minor), otherwise the
  # comparison can never be true.
  if sys.version_info[:2] == (3, 8) and platform.system() == 'Linux':
    return False  # TODO(b/171242147)
  return sys.platform != 'win32'


class _AbslProcess:
  """A process that runs using absl.app.run.

  Used as a mixin listed before a `multiprocessing` process class: it replaces
  the instance's `run` with a wrapper that invokes the original `run` inside
  `absl.app.run`.
  """

  def __init__(self, *args, **kwargs):
    super(_AbslProcess, self).__init__(*args, **kwargs)
    # Monkey-patch that is carried over into the spawned process by pickle.
    self._run_impl = self.run
    self.run = self._run_with_absl

  def _run_with_absl(self):
    """Runs the original `run` method under `absl.app.run`."""
    app.run(lambda _: self._run_impl())


if _is_enabled():

  class AbslForkServerProcess(_AbslProcess,
                              multiprocessing.context.ForkServerProcess):
    """An absl-compatible Forkserver process.

    Note: Forkserver is not available in windows.
    """

  class AbslForkServerContext(multiprocessing.context.ForkServerContext):
    """A forkserver context whose processes are `AbslForkServerProcess`."""
    _name = 'absl_forkserver'
    Process = AbslForkServerProcess  # pylint: disable=invalid-name

  # Deliberately rebinds the module-level name `multiprocessing` from the
  # stdlib module to an absl-aware forkserver context (presumably so callers
  # using `<this module>.multiprocessing` get absl-compatible processes).
  # Every later reference to `multiprocessing` in this file resolves to this
  # context object, not the stdlib module.
  multiprocessing = AbslForkServerContext()
  Process = multiprocessing.Process

else:

  class Process(object):
    """Placeholder process used when multi-process testing is disabled.

    Constructing it skips the current test. `_is_enabled` returns False on
    Windows, but also for OSS TPU runs and Python 3.8 on Linux.
    """

    def __init__(self, *args, **kwargs):
      del args, kwargs
      raise unittest.SkipTest(
          'TODO(b/150264776): Windows is not supported in MultiProcessRunner.')


_test_main_called = False


def _guess_binary_path(package_root):
  """Guesses the bazel test binary path when `sys.argv[0]` is a `.py` file.

  For target //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
  argv[0] is in the form of
    /.../tensorflow/python/distribute/input_lib_test.py
  and the binary is
    /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu

  Args:
    package_root: Path component (e.g. 'org_tensorflow') that must appear in
      `sys.argv[0]`; the binary path is rebuilt relative to its last
      occurrence.

  Returns:
    The guessed path if it exists and is executable, otherwise None. Also
    returns None when `sys.argv[0]` does not look like a bazel output path or
    when the `TEST_TARGET` environment variable is not set.
  """
  argv0 = sys.argv[0]
  if 'bazel-out' not in argv0 or package_root not in argv0:
    return None
  test_target = os.environ.get('TEST_TARGET')
  if not test_target:
    # Without the target label there is nothing to derive the binary name
    # from; let the caller report the failure.
    return None
  package_root_base = argv0[:argv0.rfind(package_root)]
  # Turn a label such as '//a/b:c' into the relative path 'a/b/c'.
  binary = test_target[2:].replace(':', '/', 1)
  possible_path = os.path.join(package_root_base, package_root, binary)
  logging.info('Guessed test binary path: %s', possible_path)
  if os.access(possible_path, os.X_OK):
    return possible_path
  return None


def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly. When `sys.argv[0]`
  is a `.py` file, it is replaced with the guessed binary path.

  Raises:
    RuntimeError: If the binary path cannot be determined.
  """
  # TODO(b/150264776): This does not work with Windows. Find a solution.
  if sys.argv[0].endswith('.py'):
    path = (_guess_binary_path('org_tensorflow') or
            _guess_binary_path('org_keras'))
    if path is None:
      logging.error(
          'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
          sys.argv[0], os.environ)
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])


def _if_spawn_run_and_exit():
  """If spawned process, run requested spawn task and exit. Else a no-op."""

  # `multiprocessing` module passes a script "from multiprocessing.x import y"
  # to subprocess, followed by a main function call. We use this to tell if
  # the process is spawned. Examples of x are "forkserver" or
  # "semaphore_tracker".
  interpreter_args = sys.argv[1:]
  try:
    cmd = interpreter_args[interpreter_args.index('-c') + 1]
  except (ValueError, IndexError):
    # No '-c' argument, or '-c' is the last argument: not a spawned process.
    return
  if not cmd.startswith('from multiprocessing.'):
    return
  # As a subprocess, we disregard all other interpreter command line
  # arguments.
  sys.argv = sys.argv[0:1]

  # Run the specified command - this is expected to be one of:
  # 1. Spawn the process for semaphore tracker.
  # 2. Spawn the initial process for forkserver.
  # 3. Spawn any process as requested by the "spawn" method.
  exec(cmd)  # pylint: disable=exec-used
  sys.exit(0)  # Semaphore tracker doesn't explicitly sys.exit.


def test_main():
  """Main function to be called within `__main__` of a test file.

  Marks the module as initialized and sets `TF_FORCE_GPU_ALLOW_GROWTH=true`.
  If multi-process testing is enabled, it also configures the executable for
  spawned processes. In a spawned process it then runs the requested
  multiprocessing command and exits. Otherwise it runs the tests through
  `test.main()`.
  """
  global _test_main_called
  _test_main_called = True

  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

  if _is_enabled():
    _set_spawn_exe_path()
    _if_spawn_run_and_exit()

  # Only runs test.main() if not spawned process.
  test.main()


def initialized():
  """Returns whether the module is initialized (i.e. `test_main` was called)."""
  return _test_main_called