xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/multi_process_lib.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for multi-process testing."""
16
17import multiprocessing
18import os
19import platform
20import sys
21import unittest
22from absl import app
23from absl import logging
24
25from tensorflow.python.eager import test
26
27
def is_oss():
  """Returns True when the program path in `sys.argv[0]` mentions 'bazel'.

  This is the heuristic used to decide that the test runs in the open-source
  (bazel) environment. An empty `sys.argv` yields False.
  """
  if not sys.argv:
    return False
  return 'bazel' in sys.argv[0]
31
32
def _is_enabled():
  """Returns whether multi-process support is enabled for this process.

  Support is disabled when any of the following holds:
    * running under OSS (see `is_oss`) with a command-line argument that starts
      with `--tpu`;
    * running Python 3.8 on Linux (b/171242147);
    * running on Windows (`sys.platform == 'win32'`).

  Returns:
    True if multi-process support should be set up, False otherwise.
  """
  # Note that flags may not be parsed at this point and simply importing the
  # flags module causes a variety of unusual errors.
  tpu_args = [arg for arg in sys.argv if arg.startswith('--tpu')]
  if is_oss() and tpu_args:
    return False
  # `sys.version_info` is a five-field tuple (major, minor, micro,
  # releaselevel, serial), so comparing the whole tuple against `(3, 8)` can
  # never be true. Only the (major, minor) prefix is compared.
  if sys.version_info[:2] == (3, 8) and platform.system() == 'Linux':
    return False  # TODO(b/171242147)
  return sys.platform != 'win32'
42
43
class _AbslProcess:
  """Mixin that makes a process execute its `run()` through `absl.app.run`."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Instance-level monkey-patch: keep the original `run` reachable as
    # `_run_impl` and expose the absl wrapper under the name `run`. Because
    # both are instance attributes, pickle carries them over into the spawned
    # process.
    self._run_impl = self.run
    self.run = self._run_with_absl

  def _run_with_absl(self):
    """Invokes the saved `run` implementation as the absl main function."""

    def _main(unused_argv):
      return self._run_impl()

    app.run(_main)
55
56
if _is_enabled():

  class AbslForkServerProcess(_AbslProcess,
                              multiprocessing.context.ForkServerProcess):
    """Forkserver process whose `run()` is executed through absl.

    Note: the forkserver start method is not available on Windows.
    """

  # Forkserver context that hands out `AbslForkServerProcess` as its process
  # class.
  class AbslForkServerContext(multiprocessing.context.ForkServerContext):
    _name = 'absl_forkserver'
    Process = AbslForkServerProcess  # pylint: disable=invalid-name

  # From this point on the module-level name `multiprocessing` refers to the
  # custom context instance instead of the standard-library module.
  multiprocessing = AbslForkServerContext()
  Process = multiprocessing.Process

else:

  class Process:
    """Placeholder process: constructing it skips the test.

    Used whenever `_is_enabled()` is False (until windows is supported).
    """

    def __init__(self, *unused_args, **unused_kwargs):
      raise unittest.SkipTest(
          'TODO(b/150264776): Windows is not supported in MultiProcessRunner.')
82
83
# Whether `test_main()` has been invoked in this process. It is set to True by
# `test_main()` and reported by `initialized()`.
_test_main_called = False
85
86
def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly. When `sys.argv[0]`
  is a `.py` file, the binary path is guessed from `sys.argv[0]` and the
  `TEST_TARGET` environment variable, and `sys.argv[0]` is overwritten with
  the guessed path.

  Raises:
    RuntimeError: If the binary path cannot be determined. This includes the
      case where `TEST_TARGET` is unset or empty.
  """
  # TODO(b/150264776): This does not work with Windows. Find a solution.
  if sys.argv[0].endswith('.py'):

    def guess_path(package_root):
      """Returns the guessed executable path, or None if there is none."""
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      if 'bazel-out' not in sys.argv[0] or package_root not in sys.argv[0]:
        return None
      # Without TEST_TARGET there is nothing to build a guess from. Returning
      # None lets the caller log diagnostics and raise the documented
      # RuntimeError instead of leaking a KeyError.
      test_target = os.environ.get('TEST_TARGET')
      if not test_target:
        return None
      # Guess the binary path under bazel. For target
      # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
      # argv[0] is in the form of
      # /.../tensorflow/python/distribute/input_lib_test.py
      # and the binary is
      # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
      package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
      # Drop the leading '//' and turn the first ':' into a path separator.
      binary = test_target[2:].replace(':', '/', 1)
      possible_path = os.path.join(package_root_base, package_root, binary)
      logging.info('Guessed test binary path: %s', possible_path)
      if os.access(possible_path, os.X_OK):
        return possible_path
      return None

    path = guess_path('org_tensorflow') or guess_path('org_keras')
    if path is None:
      logging.error(
          'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
          sys.argv[0], os.environ)
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])
127
128
def _if_spawn_run_and_exit():
  """If spawned process, run requested spawn task and exit. Else a no-op.

  The process counts as spawned when the first `-c` after `sys.argv[0]` is
  followed by an argument starting with "from multiprocessing.". In that case
  `sys.argv` is truncated to the program name, the command is executed, and
  the process exits with status 0. Otherwise the function returns without side
  effects; this includes a trailing `-c` with no following argument.
  """
  # `multiprocessing` module passes a script "from multiprocessing.x import y"
  # to subprocess, followed by a main function call. We use this to tell if
  # the process is spawned. Examples of x are "forkserver" or
  # "semaphore_tracker".
  try:
    # Search from index 1 so the program name itself is never taken for the
    # flag.
    flag_index = sys.argv.index('-c', 1)
  except ValueError:
    return
  if flag_index + 1 >= len(sys.argv):
    # A trailing '-c' carries no command, so this is not a spawned process.
    return
  cmd = sys.argv[flag_index + 1]
  if not cmd.startswith('from multiprocessing.'):
    return

  # As a subprocess, we disregard all other interpreter command line
  # arguments.
  sys.argv = sys.argv[0:1]

  # Run the specified command - this is expected to be one of:
  # 1. Spawn the process for semaphore tracker.
  # 2. Spawn the initial process for forkserver.
  # 3. Spawn any process as requested by the "spawn" method.
  exec(cmd)  # pylint: disable=exec-used
  sys.exit(0)  # Semaphore tracker doesn't explicitly sys.exit.
153
154
def test_main():
  """Entry point for test files; call it from the file's `__main__` section.

  Marks the module as initialized, sets `TF_FORCE_GPU_ALLOW_GROWTH` to 'true'
  in the environment and, when multi-process support is enabled, configures
  the spawn executable and services a spawn request if this process is a
  spawned one. Finally hands control over to `test.main()`.
  """
  global _test_main_called

  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
  _test_main_called = True

  if _is_enabled():
    _set_spawn_exe_path()
    # Exits the process here if it is a spawned one.
    _if_spawn_run_and_exit()

  # Reached only when this is not a spawned process.
  test.main()
168
169
def initialized():
  """Returns True once `test_main()` has been called in this process."""
  return _test_main_called
173