xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/multi_process_runner.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Multi-process runner for testing purpose."""
16
17import collections
18import contextlib
19import json
20import os
21import signal
22import sys
23import threading
24import time
25import unittest
26import weakref
27
28from absl import logging
29import six
30from six.moves import queue as Queue
31
32from tensorflow.python import tf2
33from tensorflow.python.compat import v2_compat
34from tensorflow.python.distribute import multi_worker_util
35from tensorflow.python.distribute import multi_process_lib
36from tensorflow.python.eager import context
37from tensorflow.python.framework import test_util
38from tensorflow.python.util.tf_export import tf_export
39
40multiprocessing = multi_process_lib.multiprocessing
41
42# pylint: disable=g-import-not-at-top
43try:
44  # `faulthandler` is not available in py2.
45  import faulthandler
46except ImportError:
47  faulthandler = None
48
49# TODO(b/150264776): Remove after resolving CI issue.
50try:
51  import dill
52except ImportError:
53  dill = None
54
55# TODO(b/150264776): Remove after resolving CI issue.
56try:
57  import tblib.pickling_support
58  # For pickling traceback objects.
59  tblib.pickling_support.install()
60except ImportError:
61  pass
62
63
# Status that a subprocess reports back to the parent process. `is_successful`
# is True when the subprocess finished without error; when it is False,
# `exc_info` carries the exception information that the parent re-raises.
_ProcessStatusInfo = collections.namedtuple(
    '_ProcessStatusInfo',
    ('task_type', 'task_id', 'is_successful', 'exc_info', 'return_value'))

# Result handed back from a successful MultiProcessRunner run.
MultiProcessRunnerResult = collections.namedtuple(
    'MultiProcessRunnerResult', 'return_value stdout')

# Per-task environment description. When `visible_gpus` is not None,
# CUDA_VISIBLE_DEVICES is set to it.
TestEnvironment = collections.namedtuple(
    'TestEnvironment',
    ('task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast',
     'v2_enabled', 'executing_eagerly', 'visible_gpus'))

# Handles shared between the main process and the worker processes.
#
# `process_status_queue`: used internally by `multi_process_runner` so that
#   subprocesses can tell the parent whether they succeeded and, if not, what
#   the error stack trace is.
# `parent_to_sub_queue`: messages from the parent to subprocesses; at present
#   only used to terminate subprocesses.
# TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
# `streaming_pipe_w`: write end of the pipe that streams a subprocess's stdout
#   and stderr to the parent process.
# `barrier`: a barrier whose party is all subprocesses.
Resources = collections.namedtuple(
    'Resources',
    ('process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w',
     'barrier'))

# Default timeout (seconds) for `join()`; chosen so that it fires before the
# default "medium" timeout of the test runs.
_DEFAULT_TIMEOUT_SEC = 200

# Seconds to wait before force killing a child process. On timeout a child is
# first sent SIGTERM so it gets a chance to dump stacktraces, which itself can
# take a long time.
_FORCE_KILL_WAIT_SEC = 30
105
106
class MultiProcessRunner(object):
  """A utility class to start multiple processes to simulate a cluster.

  We need to use multiple processes to simulate a cluster in TF 2.0 tests
  because TF 2.0 has some process-global data structures that have to be
  separated by processes. We also need child processes to test out our fault
  tolerance because shutting down a standard TensorFlow server within its
  process is not supported.

  Note: the main test program that uses this runner class must run main program
  via `test_main` defined in this file. Using this runner in non-test binaries
  is not supported yet.

  This class is not thread-safe. Child processes will inherit TF2 behavior flag.
  """

  def __init__(self,
               fn,
               cluster_spec,
               rpc_layer=None,
               max_run_time=None,
               grpc_fail_fast=None,
               stream_output=True,
               return_output=False,
               use_dill_for_args=True,
               daemon=False,
               dependence_on_chief=True,
               auto_restart=False,
               share_gpu=True,
               args=None,
               kwargs=None):
    """Instantiation of a `MultiProcessRunner`.

    Args:
      fn: Function to be run on child processes. This will be run on processes
        for all task types.
      cluster_spec: Dict for cluster spec. The utility function
        `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`
        can be conveniently used to create such dict. The following is an
        example of cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      rpc_layer: RPC layer to use. Default value is 'grpc'.
      max_run_time: `None` or integer. If not `None`, child processes are forced
        to exit at approximately this many seconds after this utility is called.
        We achieve this through `signal.alarm()` api. Note that this is best
        effort at Python level since Python signal handler does not get executed
        when it runs lower level C/C++ code. So it can be delayed for
        arbitrarily long time. If any of the child process is still running when
        `max_run_time` is up, they will be force-terminated and an
        `UnexpectedSubprocessExitError` may be raised. If `None`, child
        processes are not forced to exit.
      grpc_fail_fast: Whether GRPC connection between processes should fail
        without retrying. Defaults to None, in which case the environment
        variable is not explicitly set.
      stream_output: True if the output/error from the subprocesses should be
        streamed to be printed in parent process' log. Defaults to True.
      return_output: If True, the output/error from the subprocesses should be
        collected to be attached to the resulting namedtuple returned from
        `join()`. The list of output can be retrieved via `stdout` attribute.
        Defaults to False.
      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill
        can pickle more objects, but doesn't work with types in
        `multiprocessing` library like `Mutex`.
      daemon: Whether to start processes as daemons.
      dependence_on_chief: Whether to terminate the cluster if the chief exits.
        If auto_restart is True, it only terminates the cluster if the chief
        exits with a zero exit code.
      auto_restart: Whether to automatically restart processes that exit with
        non-zero exit code.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a roundrobin fashion. This should be True
        whenever possible for better test execution coverage; some situations
        that need it to be False are tests that run NCCL.
      args: Positional arguments to be sent to `fn` run on subprocesses.
      kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.
      SkipTest: if thread sanitizer is enabled (which is incompatible with MPR).
    """
    if test_util.is_tsan_enabled():
      raise unittest.SkipTest(
          'ThreadSanitizer is not compatible with MultiProcessRunner.')

    assert cluster_spec is not None
    if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1:
      raise ValueError('If chief exists in the cluster, there must be at most '
                       'one chief. Current `cluster_spec` has {} chiefs.'
                       .format(len(cluster_spec['chief'])))
    _check_initialization()
    if not callable(fn):
      raise ValueError('fn is not a callable')

    self._fn = fn
    self._cluster_spec = cluster_spec
    self._rpc_layer = rpc_layer or 'grpc'
    self._max_run_time = max_run_time
    self._grpc_fail_fast = grpc_fail_fast
    self._stream_output = stream_output
    # TODO(rchao): Revisit return_output argument to consider other solution.
    self._return_output = return_output
    self._dependence_on_chief = dependence_on_chief
    self._use_dill_for_args = use_dill_for_args
    self._daemon = daemon
    self._auto_restart = auto_restart
    self._args = args or ()
    self._kwargs = kwargs or {}

    self._share_gpu = share_gpu
    self._total_gpu = len(context.context().list_physical_devices('GPU'))

    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

    self._joined = False
    self._process_lock = threading.Lock()
    # Guarded by self._process_lock.
    self._processes = {}
    # Record which processes are terminated. Due to a bug in Python<3.7,
    # terminated processes return 255 exit code, which should cause an exception
    # in join().
    # https://bugs.python.org/issue30589
    # Guarded by self._process_lock.
    self._terminated = set()
    self._reading_threads = []

    self._manager = manager()
    self._process_status_queue = self._manager.Queue()
    self._parent_to_sub_queue = self._manager.Queue()
    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
    self._barrier = self._manager.Barrier(parties)

    # We use a queue to collect outputs from worker processes since it's thread
    # safe.
    self._streaming_queue = self._manager.Queue()

    # Created lazily by the first subprocess start; stays None if no subprocess
    # is ever started.
    self._watchdog_thread = None

  def set_args(self, args=None, kwargs=None):
    """Replaces `args`/`kwargs` used for subsequently started processes.

    Falsy values (e.g. `None`) keep the currently stored value.
    """
    self._args = args or self._args
    self._kwargs = kwargs or self._kwargs

  def _raise_if_joined(self):
    """Raises `ValueError` if `join()` has already been called.

    The caller is required to hold `self._process_lock`.
    """
    if self._joined:
      raise ValueError('cannot start new processes after '
                       'MultiProcessRunner.join() is called')

  def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
    """Continuously reads lines from a subprocess's output pipe.

    Each line is prefixed with `[task_type-task_id]:`, then printed in this
    process if `stream_output` is set and/or queued on `self._streaming_queue`
    if `return_output` is set. Returns once the pipe reaches end of file.
    """
    # The prefix only depends on the task, so build it once for this thread.
    task_prefix = '[{}-{}]:'.format(task_type, task_id).ljust(14)
    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
      for line in reader:
        formatted_line = '{} {}'.format(task_prefix, line)
        if self._stream_output:
          # TODO(rchao): Use a lock here to ensure the printed lines are not
          # broken.
          print(formatted_line, end='', flush=True)
        if self._return_output:
          self._streaming_queue.put(formatted_line)

  def _start_subprocess_and_reading_thread(self,
                                           task_type,
                                           task_id,
                                           cluster_spec=None,
                                           fn=None,
                                           args=None,
                                           kwargs=None):
    """Starts a subprocess and a thread that reads lines from the subprocess.

    The caller is expected to hold `self._process_lock`, since this mutates
    `self._processes` and `self._terminated`. Also (re)starts the watchdog
    thread if it is not running.
    """

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    cluster_spec = cluster_spec or self._cluster_spec
    visible_gpus = None
    if not self._share_gpu and self._total_gpu > 0:
      # Assign GPUs in a roundrobin fashion.
      id_in_cluster = multi_worker_util.id_in_cluster(cluster_spec, task_type,
                                                      task_id)
      worker_count = multi_worker_util.worker_count(cluster_spec, task_type)
      visible_gpus = list(range(id_in_cluster, self._total_gpu, worker_count))

    test_env = TestEnvironment(
        task_type=task_type,
        task_id=task_id,
        cluster_spec=cluster_spec,
        rpc_layer=self._rpc_layer,
        grpc_fail_fast=self._grpc_fail_fast,
        v2_enabled=self._v2_enabled,
        executing_eagerly=self._executing_eagerly,
        visible_gpus=visible_gpus,
    )
    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
    resources = Resources(
        process_status_queue=self._process_status_queue,
        parent_to_sub_queue=self._parent_to_sub_queue,
        streaming_pipe_w=pipe_w,
        barrier=self._barrier,
    )
    if fn is None:
      fn, args, kwargs = self._fn, self._args, self._kwargs
    # Always use dill to pickle fn so that we support more callable
    # types, e.g. lambda.
    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    if self._use_dill_for_args:
      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)

    p = _Process(
        test_env=test_env,
        target=_ProcFunc(),
        args=(resources, test_env, fn, args, kwargs, self._use_dill_for_args),
        daemon=self._daemon)
    p.start()
    self._processes[(task_type, task_id)] = p
    self._terminated.discard((task_type, task_id))

    # For each subprocess, we dedicate a thread continuously reading lines
    # from them.
    thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._continuously_readline_from_sub,
        args=(pipe_r, task_type, task_id))
    thread.start()
    self._reading_threads.append(thread)

    if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
      self._watchdog_thread = threading.Thread(target=self._process_watchdog)
      self._watchdog_thread.start()

  def start(self):
    """Starts processes, one for each task in `cluster_spec`.

    Note that this is best effort by the applicable multiprocessing library,
    and it may take up to seconds for a subprocess to be successfully started.

    Raises:
      ValueError: if processes were already started or `join()` was called.
    """
    with self._process_lock:
      if self._processes:
        raise ValueError('MultiProcessRunner already started.')
      self._raise_if_joined()

      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          self._start_subprocess_and_reading_thread(task_type, task_id)

    # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
    # without this the tests become very flaky.
    if self._max_run_time is not None:

      def handler(signum, frame):
        del signum, frame
        self.terminate_all()

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(self._max_run_time)

  def start_in_process_as(self, as_task_type, as_task_id):
    """Start the processes, with the specified task run in main process.

    This is similar to `start()` except that the task with task_type
    `as_task_type` and task_id `as_task_id` is run in the main process.
    This method is particularly useful when debugging tool such as `pdb` is
    needed in some specific task. Note that since this method is blocking until
    that specific task exits, additional actions would need a thread to be
    called:

    ```python
    def fn():
      # user code to be run
      import pdb; pdb.set_trace()

    def follow_ups():
      time.sleep(5)
      mpr.start_single_process(
          task_type='evaluator',
          task_id=0)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1))
    threading.Thread(target=follow_ups).start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
    mpr.join()
    ```

    Note that if `return_output=True`, the logs/stdout by task
    run by the main process is not available in result.stdout.

    Args:
      as_task_type: The task type to be run in the main process.
      as_task_id: The task id to be run in the main process.

    Raises:
      ValueError: if processes were already started or `join()` was called.
    """
    # Check `_processes` under the lock, as `start()` does, since it is
    # guarded by `_process_lock`.
    with self._process_lock:
      if self._processes:
        raise ValueError('MultiProcessRunner already started.')
      self._raise_if_joined()
      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          if not (task_type == as_task_type and task_id == as_task_id):
            self._start_subprocess_and_reading_thread(task_type, task_id)

    _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
                   self._rpc_layer)
    self._fn(*self._args, **self._kwargs)

  def start_single_process(self,
                           task_type,
                           task_id,
                           cluster_spec=None,
                           fn=None,
                           args=None,
                           kwargs=None):
    """Starts a single process.

    This starts a process in the cluster with the task type, task id, and the
    process function (`fn`). If process function is `None`, the function
    provided at `__init__` will be used. If `cluster_spec` is `None`, the
    cluster spec provided at `__init__` will be used.

    TODO(rchao): It is meant that all subprocesses will be updated with the new
    cluster spec, but this has yet to be implemented. At this time only the
    newly started subprocess picks up this updated cluster spec.

    Args:
      task_type: The task type.
      task_id: The task id.
      cluster_spec: The cluster spec to be used on the newly started
        process. If `None`, the cluster spec provided at `__init__` will be
        used.
      fn: The process function to be run on the newly started
        process. If specified, specify `args` and `kwargs` as well. If `None`,
        the function provided at `__init__` will be used.
      args: Optional positional arguments to be supplied in `fn`.
      kwargs: Optional keyword arguments to be supplied in `fn`.

    Raises:
      ValueError: if `join()` was already called.
    """
    with self._process_lock:
      self._raise_if_joined()
      self._start_subprocess_and_reading_thread(
          task_type,
          task_id,
          cluster_spec=cluster_spec,
          fn=fn,
          args=args or (),
          kwargs=kwargs or {})

  def _queue_to_list(self, queue_to_convert):
    """Drains `queue_to_convert` without blocking and returns the items."""
    list_to_return = []
    # Calling `queue.empty()` is not reliable.
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  def _get_process_statuses(self):
    """Returns a dict mapping (task_type, task_id) to its latest status."""
    # One worker may have multiple statuses. We only keep the last one.
    statuses = {}
    for status in self._queue_to_list(self._process_status_queue):
      statuses[(status.task_type, status.task_id)] = status
    return statuses

  def get_process_id(self, task_type, task_id):
    """Returns the subprocess id given the task type and task id."""
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
    return p.pid if p else None

  def get_process_exit_code(self, task_type, task_id):
    """Returns the subprocess exit code given the task type and task id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      The subprocess exit code; `None` if the subprocess has not exited yet.

    Raises:
      KeyError: If the corresponding subprocess is not found with `task_type`
        and `task_id`.
    """
    with self._process_lock:
      p = self._processes[(task_type, task_id)]
    return p.exitcode if p else None

  def process_exists(self, task_type, task_id):
    """Returns whether the subprocess still exists given the task type and id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      Boolean; whether the subprocess still exists. If the subprocess has
      exited, this returns False.
    """
    return self.get_process_exit_code(task_type, task_id) is None

  def _process_watchdog(self):
    """Simulates a cluster management system.

    - If auto_restart is True, it restarts processes that exit with a non-zero
      exit code. Note that when join() times out it overrides auto_restart to
      False.
    - If dependence_on_chief is True, it terminates all processes once the chief
      exits. If auto_restart is also True, it only terminates all processes if
      the chief exits with a zero exit code, otherwise it restarts the chief.

    This runs in self._watchdog_thread.
    """
    while True:
      time.sleep(1)
      with self._process_lock:
        chief = self._processes.get(('chief', 0), None)
        # Terminate the cluster when _dependence_on_chief is True if either:
        # - chief has exited with zero exit code.
        # - chief has exited with non-zero exit code and self._auto_restart is
        #   False.
        if chief and self._dependence_on_chief and chief.exitcode is not None:
          if chief.exitcode == 0 or (not self._auto_restart):
            for p in self._processes.values():
              # Give other processes a chance to exit on their own.
              p.join(timeout=3)
            self._terminate_all()
            for p in self._processes.values():
              p.join()
            return

        # Auto restart failed processes if self._auto_restart is True.
        if self._auto_restart:
          has_failure = False
          for (task_type, task_id), p in self._processes.items():
            if p.exitcode is not None and p.exitcode != 0:
              has_failure = True
              logging.info('Restarting failed %s-%d', task_type, task_id)
              self._start_subprocess_and_reading_thread(task_type, task_id)
          if has_failure:
            continue

        # Exit the thread if all processes have exited at this point.
        if all(p.exitcode is not None for p in self._processes.values()):
          return

  def _reraise_if_subprocess_error(self, process_statuses):
    """Re-raises the first unsuccessful subprocess's exception, if any.

    The re-raised exception gets an `mpr_result` attribute attached.
    """
    for process_status in process_statuses.values():
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        process_status.exc_info[1].mpr_result = self._get_mpr_result(
            process_statuses)
        six.reraise(*process_status.exc_info)

  def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
    """Joins all the processes with timeout.

    If any of the subprocesses does not exit approximately after `timeout`
    seconds has passed after `join` call, this raises a
    `SubprocessTimeoutError`.

    Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to
    log the stack traces of the subprocesses when they exit. However, this
    results in timeout when the test runs with tsan (thread sanitizer); if tsan
    is being run on the test targets that rely on timeout to assert information,
    `MultiProcessRunner.terminate_all()` must be called after `join()`, before
    the test exits, so the subprocesses are terminated with SIGKILL, and data
    race is removed.

    Args:
      timeout: optional integer or `None`. If provided as an integer, and not
      all processes report status within roughly `timeout` seconds, a
      `SubprocessTimeoutError` exception will be raised. If `None`, `join` never
      times out.

    Returns:
      A `MultiProcessRunnerResult` object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains a list of
      return values from the subprocesses, although the order is not meaningful.
      If `return_output` argument is True at `__init__`, `stdout` is available
      that contains a list of all messages from subprocesses' stdout and stderr.

    Raises:
      SubprocessTimeoutError: if not all processes report status approximately
        within `timeout` seconds. When this is raised, a
        `MultiProcessRunnerResult` object can be retrieved by
        `SubprocessTimeoutError`'s mpr_result attribute, which has the same
        structure as above 'Returns' section describes.
      UnexpectedSubprocessExitError: If any of the subprocesses did not exit
        properly (for example, they exit on SIGTERM or SIGKILL signal). When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the
        same structure as above 'Returns' section describes. If `max_run_time`
        is not `None`, it is expected that some subprocesses may be
        force-killed when `max_run_time` is up, and this is raised in those
        cases.
      Exception: if there is an Exception propagated from any subprocess. When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        the re-raised exception's mpr_result attribute, which has the same
        structure as above 'Returns' section describes.
      ValueError: if `timeout` is not an integer or `None`, or if `join()` was
        already called.
    """
    if timeout and not isinstance(timeout, int):
      raise ValueError('`timeout` must be an integer or `None`.')
    with self._process_lock:
      if self._joined:
        raise ValueError("MultiProcessRunner can't be joined twice.")
      self._joined = True

    # The watchdog thread is only created when a subprocess is started. It is
    # still None if no subprocess was ever started (e.g. `start_in_process_as()`
    # ran the only task of the cluster in the main process); then there is
    # nothing to wait for.
    has_watchdog = self._watchdog_thread is not None
    if has_watchdog:
      self._watchdog_thread.join(timeout)
    if has_watchdog and self._watchdog_thread.is_alive():
      # Timeout. Force termination to dump worker processes stack trace.
      with self._process_lock:
        self._auto_restart = False
      logging.error('Timeout when joining for child processes. Terminating...')
      self.terminate_all(sig=signal.SIGTERM)
      # Wait for the processes to terminate by themselves first, so they have a
      # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them.
      self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
      if self._watchdog_thread.is_alive():
        logging.error('Timeout when waiting for child processes to '
                      'print stacktrace. Sending SIGKILL...')
        self.terminate_all()
        self._watchdog_thread.join()
      process_statuses = self._get_process_statuses()
      self._reraise_if_subprocess_error(process_statuses)
      raise SubprocessTimeoutError(
          'One or more subprocesses timed out, where timeout was set to {}s. '
          'Please change the `timeout` argument for '
          '`MultiProcessRunner.join()` or `multi_process_runner.run()` '
          'if it should be adjusted.'.format(timeout),
          self._get_mpr_result(process_statuses))

    for (task_type, task_id), p in self._processes.items():
      logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)

    process_statuses = self._get_process_statuses()
    self._reraise_if_subprocess_error(process_statuses)

    # Checking all the processes that are expected to exit properly.
    for (task_type, task_id), p in self._processes.items():
      # Successfully exiting process has exit code 0. We ignore processes that
      # are terminated.
      assert p.exitcode is not None
      if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
        raise UnexpectedSubprocessExitError(
            'Subprocess %s-%d exited with exit code %s. See logs for details.'
            % (task_type, task_id, p.exitcode),
            self._get_mpr_result(process_statuses))

    logging.info('Joining log reading threads.')
    for thread in self._reading_threads:
      thread.join()
    logging.info('Joined log reading threads.')

    # Clear the alarm.
    signal.alarm(0)

    return self._get_mpr_result(process_statuses)

  def _get_mpr_result(self, process_statuses):
    """Builds a `MultiProcessRunnerResult` from statuses and collected output."""
    stdout = self._queue_to_list(self._streaming_queue)
    return_values = []
    for process_status in process_statuses.values():
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)
    return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)

  def terminate(self, task_type, task_id):
    """Terminates the process with `task_type` and `task_id`.

    If auto_restart=True, the terminated task will be restarted unless the chief
    has already exited with zero exit code.

    Args:
      task_type: the task type.
      task_id: the task id.

    Raises:
      ValueError: if no process exists for `task_type` and `task_id`.
    """
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
      if p is None:
        raise ValueError('{}-{} does not exist'.format(task_type, task_id))
      self._terminated.add((task_type, task_id))
      # TODO(crccw): change to use Process.terminate() as well.
      self._parent_to_sub_queue.put('terminate {} {}'.format(
          task_type, task_id))
      p.join()

  def _terminate_all(self, sig=None):
    """Terminates all subprocesses.

    The caller is required to hold self._process_lock.

    Args:
      sig: the signal used to terminate the process. The default is SIGKILL.
    """

    # Use SIGKILL as default. In systems where that's unavailable such as
    # windows, use SIGTERM.
    sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
    for (task_type, task_id), p in self._processes.items():
      if p.exitcode is not None:
        logging.info('%s-%d has already exited. Not terminating.', task_type,
                     task_id)
        continue
      try:
        os.kill(p.pid, sig)
        self._terminated.add((task_type, task_id))
        logging.info('%s-%d terminated with signal %r.', task_type, task_id,
                     sig)
      except ProcessLookupError:
        logging.info('Attempting to kill %s-%d but it does not exist.',
                     task_type, task_id)

  def terminate_all(self, sig=None):
    """Terminates all subprocesses."""
    with self._process_lock:
      self._terminate_all(sig)
731
732
class _Process(multi_process_lib.Process):
  """A `multiprocessing.Process` that prepares its environment before `run`.

  In the child, before the original `run` is invoked, `GRPC_FAIL_FAST` is set
  when `test_env.grpc_fail_fast` is not `None`, and `CUDA_VISIBLE_DEVICES` is
  set when `test_env.visible_gpus` is non-empty. `TF_CONFIG` is always set via
  `_set_tf_config`.
  """

  # TODO(crccw): consider moving other logics in _ProcFunc to _Process.

  def __init__(self, test_env, **kwargs):
    super().__init__(**kwargs)
    self._test_env = test_env
    # Keep a handle on the original `run` and shadow it on the instance, so
    # that invoking `self.run()` goes through `_run_with_setenv` first.
    self._actual_run = self.run
    self.run = self._run_with_setenv

  def _run_with_setenv(self):
    """Exports the test environment variables, then runs the original `run`."""
    # Environment variables are set before anything else happens because
    # setenv() is not thread-safe.
    env = self._test_env
    fail_fast = env.grpc_fail_fast
    if fail_fast is not None:
      os.environ['GRPC_FAIL_FAST'] = str(fail_fast)
    gpus = env.visible_gpus
    if gpus:
      os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpus))
    _set_tf_config(env.task_type, env.task_id, env.cluster_spec,
                   env.rpc_layer)
    return self._actual_run()
756
757
class _ProcFunc(object):
  """Represents a callable to run in a subprocess.

  When invoked (`__call__`) as the body of a subprocess, it does the following:
  - redirects stdout/stderr to the streaming pipe;
  - starts a daemon thread that listens for termination messages from the
    parent;
  - runs `fn`;
  - reports the outcome on `process_status_queue`.
  """

  @contextlib.contextmanager
  def _runtime_mode(self, executing_eagerly):
    """Context manager: eager mode if `executing_eagerly`, else graph mode."""
    if executing_eagerly:
      with context.eager_mode():
        yield
    else:
      with context.graph_mode():
        yield

  def _message_checking_func(self, task_type, task_id):
    """A function that regularly checks messages from parent process.

    The function polls `parent_to_sub_queue` without blocking:
    - when the queue is empty, it sleeps 0.1s;
    - when a 'terminate' message addressed to another task arrives, it puts the
      message back and sleeps 1s;
    - when the message 'terminate <task_type> <task_id>' for this task arrives,
      it puts a successful status on `process_status_queue` and ends the whole
      process with `os._exit(1)`.

    Args:
      task_type: the task type of this subprocess.
      task_id: the task id of this subprocess.

    Raises:
      ValueError: if a message not starting with 'terminate' is received. When
        this runs as the thread target in `__call__`, this ends only that
        thread.
    """
    # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
    while True:
      try:
        message = self._resources.parent_to_sub_queue.get(block=False)

        # Currently the only possible message is termination.
        if not message.startswith('terminate'):
          raise ValueError('Unrecognized message: {}'.format(message))

        if message == 'terminate {} {}'.format(task_type, task_id):
          break
        else:
          # If the message is not targeting this process, put it back to the
          # queue.
          self._resources.parent_to_sub_queue.put(message)
          time.sleep(1)
      except Queue.Empty:
        time.sleep(0.1)
    self._resources.process_status_queue.put(
        _ProcessStatusInfo(
            task_type=task_type,
            task_id=task_id,
            is_successful=True,
            exc_info=None,
            return_value=None))
    # `os._exit(1)` is used to more reliably terminate a subprocess.
    # NOTE(review): the exit code is 1 even though a successful status was just
    # reported; presumably the parent relies on the status, not the exit code.
    os._exit(1)  # pylint: disable=protected-access

  def _close_streaming(self):
    """Close stdout, stderr and streaming pipe.

    We need to explicitly close them since Tensorflow may take a while to exit,
    so that the reading threads in the main process can exit more quickly.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    sys.stdout.close()
    sys.stderr.close()
    self._resources.streaming_pipe_w.close()

  def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args):
    """The wrapper function that actually gets run in child process(es).

    Args:
      resources: object exposing `barrier`, `parent_to_sub_queue`,
        `process_status_queue` and `streaming_pipe_w`.
      test_env: object exposing `task_type`, `task_id`, `v2_enabled` and
        `executing_eagerly`.
      fn: the `dill`-serialized callable to run.
      args: positional arguments for `fn`; `dill`-serialized if
        `use_dill_for_args` is True.
      kwargs: keyword arguments for `fn`; `dill`-serialized if
        `use_dill_for_args` is True.
      use_dill_for_args: whether `args` and `kwargs` must be `dill.loads`-ed.

    Raises:
      SystemExit: with code 0 once `fn` has succeeded and streams are closed.
      Exception: whatever `fn` raised, re-raised after its status has been put
        on `process_status_queue`.
    """

    # Make the shared barrier reachable through `get_barrier()` inside `fn`.
    global _barrier

    self._resources = resources
    _barrier = self._resources.barrier
    fn = dill.loads(fn)
    if use_dill_for_args:
      args = dill.loads(args)
      kwargs = dill.loads(kwargs)

    if faulthandler is not None:
      faulthandler.enable()
      faulthandler.register(signal.SIGTERM, chain=True)

    # All logging should go to stderr to be streamed to the main process.
    logging.set_stderrthreshold(logging.DEBUG)

    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
    # print() and logging.*() write directly to `streaming_pipe_w`.
    # Unfortunately since we cannot prepend task_type and task_id information to
    # the streamed logs we will need a thread per subprocess to distinguish
    # where the piece of message is from.
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())

    pid = os.getpid()
    logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
                 test_env.task_type, test_env.task_id)
    logging.info('TF_CONFIG: %r', os.environ['TF_CONFIG'])

    # The thread will be dedicated to checking messages from the parent process.
    threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._message_checking_func,
        args=(test_env.task_type, test_env.task_id),
        daemon=True).start()

    if test_env.v2_enabled:
      v2_compat.enable_v2_behavior()

    with self._runtime_mode(test_env.executing_eagerly):
      info = _run_contained(test_env.task_type, test_env.task_id, fn, args,
                            kwargs)
      self._resources.process_status_queue.put(info)

      # Re-raise the exception in addition to reporting it to the parent
      # process, so that even if `--test_timeout` flag is set and the
      # error doesn't make it to be shown in parent process before bazel's
      # timeout, the log would still show what happens in this subprocess,
      # instead of silently suppressing the error due to early bazel
      # timeout. Raising an error in the subprocess produces stack trace in
      # the log, but the program continues running.
      if not info.is_successful:
        six.reraise(*info.exc_info)

      # Only reached when `fn` succeeded; a failure re-raises above instead.
      self._close_streaming()

    # Exit with code 0 as it's considered successful exit at this point.
    sys.exit(0)
872
873
# Active MultiProcessPoolRunners. They need to be shut down when the program
# exits, which is done by setting `tearDownModule` on the `__main__` module
# (see `test_main()`). Note this is set in both the parent process and the
# subprocesses. A WeakSet is used so registration alone does not keep a
# runner alive.
_active_pool_runners = weakref.WeakSet()
878
879
def _shutdown_all_pool_runners():
  """Calls `shutdown()` on every pool runner still registered as active."""
  for runner in _active_pool_runners:
    runner.shutdown()
883
884
def is_oss():
  """Returns whether the test is run under OSS.

  This is a heuristic: it reports True when 'bazel' appears in `sys.argv[0]`,
  and False when `sys.argv` is empty.
  """
  if not sys.argv:
    return False
  return 'bazel' in sys.argv[0]
888
889
class MultiProcessPoolRunner(object):
  """A utility class to start a process pool to simulate a cluster.

  It's similar to MultiProcessRunner, but uses a pool of processes to avoid the
  expensive initialization cost of Tensorflow. The pool is started lazily by the
  first call to `run` and reused by later calls until `shutdown` is called.
  """

  def __init__(self, cluster_spec, initializer=None, share_gpu=True):
    """Creates a multi-process pool runner.

    No worker process is started here; the pool is started by the first `run`.

    Args:
      cluster_spec: Dict for cluster spec. The following is an example of
        cluster with three workers.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"]}
      initializer: a callable to be called at the startup of worker processes.
        It is serialized with `dill` when the pool is started.
      share_gpu: Whether to share GPUs among workers. If False, each worker is
        assigned different GPUs in a roundrobin fashion.
    """
    # Registered (weakly) so the `tearDownModule` installed by `test_main()`
    # can shut down pools that are still alive when the test module finishes.
    _active_pool_runners.add(self)
    self._cluster_spec = cluster_spec
    self._initializer = initializer
    self._share_gpu = share_gpu
    # Maps (task_type, task_id) to the parent's end of that worker's pipe.
    self._conn = {}
    # The underlying MultiProcessRunner; None until the pool is started.
    self._runner = None

  def __del__(self):
    self.shutdown()

  def shutdown(self):
    """Shuts down the worker pool.

    Closes the parent's end of every worker pipe, then joins the underlying
    `MultiProcessRunner`. Exceptions raised by the join are logged and ignored.
    Calling this again after it has completed is a no-op.
    """
    # Closing the parent ends is expected to make workers see EOFError in
    # `_pool_runner_worker` and leave their receive loop.
    for conn in self._conn.values():
      conn.close()
    self._conn = {}
    if self._runner is not None:
      try:
        self._runner.join()
      except Exception as e:  # pylint: disable=broad-except
        logging.error(
            'Ignoring exception when shutting down MultiProcessPoolRunner: %s',
            e)
      self._runner = None

  def _start(self):
    """Starts the worker pool.

    Creates the underlying `MultiProcessRunner` and starts one worker process
    per task in the cluster spec. Each worker is connected to this object
    through a duplex pipe.

    Raises:
      unittest.SkipTest: if `dill` is not available.
    """
    # We need different arguments for different processes so we're passing a
    # no-op fn here and use start_single_process instead.

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    self._runner = MultiProcessRunner(
        fn=lambda: None,
        cluster_spec=self._cluster_spec,
        use_dill_for_args=False,
        share_gpu=self._share_gpu)
    if self._initializer:
      initializer = dill.dumps(self._initializer, dill.HIGHEST_PROTOCOL)
    else:
      initializer = None
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        # The parent keeps conn1; conn2 is handed to the worker.
        conn1, conn2 = multiprocessing.Pipe(duplex=True)
        self._conn[(task_type, task_id)] = conn1
        self._runner.start_single_process(
            task_type,
            task_id,
            fn=_pool_runner_worker,
            args=(task_type, task_id, initializer, conn2))

  def run(self, fn, args=None, kwargs=None):
    """Runs `fn` with `args` and `kwargs` on all jobs.

    The first call starts the pool. `fn` is serialized with `dill` and sent to
    every worker. This call then waits for a result from each worker. Any
    errors raised while constructing the underlying `MultiProcessRunner` (for
    example, for an invalid `cluster_spec`) surface from the first call.

    Args:
      fn: The function to be run.
      args: Optional positional arguments to be supplied in `fn`.
      kwargs: Optional keyword arguments to be supplied in `fn`.

    Returns:
      A list of the non-`None` return values of `fn`, in the task order of
      `cluster_spec`.

    Raises:
      NotInitializedError: if `multi_process_runner.test_main()` was not
        called.
      unittest.SkipTest: if `dill` is not available when the pool is started.
      RuntimeError: if a worker's connection reaches EOF unexpectedly, e.g.
        the worker process died. The pool is shut down before raising.
      Exception: an exception raised by `fn` in a worker is re-raised here.
    """
    _check_initialization()
    # TODO(b/150264776): skip in OSS until it's implemented.
    # NOTE(review): presumably constructing a Process triggers
    # multi_process_lib's OSS skip; the instance itself is discarded.
    multi_process_lib.Process()
    if self._runner is None:
      self._start()

    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    for conn in self._conn.values():
      conn.send((fn, args or [], kwargs or {}))

    process_statuses = []
    for (task_type, task_id), conn in self._conn.items():
      logging.info('Waiting for the result from %s-%d', task_type, task_id)
      try:
        process_statuses.append(conn.recv())
      except EOFError:
        # This shouldn't happen due to exceptions in fn. This usually
        # means bugs in the runner.
        self.shutdown()
        raise RuntimeError('Unexpected EOF. Worker process may have died. '
                           'Please report a bug')

    return_values = []
    for process_status in process_statuses:
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        six.reraise(*process_status.exc_info)
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)

    return return_values
1008
1009
def _pool_runner_worker(task_type, task_id, initializer, conn):
  """Serves callables sent over `conn` until the connection is closed.

  The function runs on each worker process of a `MultiProcessPoolRunner`. Each
  message is a `(dill-serialized fn, args, kwargs)` tuple. The call is executed
  with `_run_contained`, so exceptions raised by `fn` are captured and sent
  back through `conn` as part of the resulting `_ProcessStatusInfo`.

  Args:
    task_type: the task type.
    task_id: the task index.
    initializer: an optional dill-serialized callable, run once at startup.
    conn: a multiprocessing.Connection object to listen for tasks and send
      results.
  """
  if initializer:
    dill.loads(initializer)()
  while True:
    try:
      serialized_fn, fn_args, fn_kwargs = conn.recv()
    except EOFError:
      # The other end of the pipe is closed: the pool is shutting down.
      return
    status = _run_contained(task_type, task_id, dill.loads(serialized_fn),
                            fn_args, fn_kwargs)
    # Flush any output produced by `fn` before reporting its result.
    sys.stdout.flush()
    sys.stderr.flush()
    conn.send(status)
1037
1038
def _run_contained(task_type, task_id, fn, args, kwargs):
  """Calls `fn(*args, **kwargs)` and packages the outcome.

  Args:
    task_type: the task type.
    task_id: the task index.
    fn: the function to be run.
    args: positional arguments to be supplied in `fn`.
    kwargs: keyword arguments to be supplied in `fn`.

  Returns:
    A `_ProcessStatusInfo`. On success it carries `fn`'s return value and
    `exc_info=None`. If `fn` raised an `Exception`, it carries the
    `sys.exc_info()` triple, `is_successful=False` and `return_value=None`.
  """
  return_value = None
  exc_info = None
  try:
    return_value = fn(*args, **kwargs)
  # Only `Exception` is caught: a `SystemExit` from `sys.exit()` inside `fn`
  # (or any other BaseException) propagates to the caller.
  except Exception:  # pylint: disable=broad-except
    exc_info = sys.exc_info()
  return _ProcessStatusInfo(
      task_type=task_type,
      task_id=task_id,
      is_successful=exc_info is None,
      exc_info=exc_info,
      return_value=return_value)
1079
1080
@tf_export('__internal__.distribute.multi_process_runner'
           '.SubprocessTimeoutError',
           v1=[])
class SubprocessTimeoutError(RuntimeError):
  """Raised when at least one subprocess times out.

  The namedtuple describing the multi-process run result is available through
  this exception's `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.

  Args:
    msg: the error message.
    mpr_result: the multi-process run result to attach.
  """

  def __init__(self, msg, mpr_result):
    super().__init__(msg)
    self.mpr_result = mpr_result
1097
1098
@tf_export('__internal__.distribute.multi_process_runner'
           '.UnexpectedSubprocessExitError',
           v1=[])
class UnexpectedSubprocessExitError(RuntimeError):
  """Raised when at least one subprocess exits unexpectedly.

  The namedtuple describing the multi-process run result is available through
  this exception's `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for more information.

  Args:
    msg: the error message.
    mpr_result: the multi-process run result to attach.
  """

  def __init__(self, msg, mpr_result):
    super().__init__(msg)
    self.mpr_result = mpr_result
1116
1117
@tf_export(
    '__internal__.distribute.multi_process_runner.NotInitializedError', v1=[])
class NotInitializedError(RuntimeError):
  """Raised when `multi_process_runner.run` is used without initialization.

  To avoid it, call
  `tf.__internal__.distribute.multi_process_runner.test_main()` inside the
  test module's `if __name__ == '__main__':` block. That call properly
  initializes `multi_process_runner.run`.
  """
1129
1130
def _check_initialization():
  """Raises `NotInitializedError` unless `multi_process_lib` is initialized."""
  if multi_process_lib.initialized():
    return
  raise NotInitializedError(
      '`multi_process_runner` is not initialized. '
      'Please call `tf.__internal__.distribute.multi_process_runner.'
      'test_main()` within `if __name__ == \'__main__\':` block '
      'in your python module to properly initialize '
      '`multi_process_runner`.')
1139
1140
def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None):
  """Serializes the cluster and task description into `TF_CONFIG`.

  Args:
    task_type: the task type, stored as `task.type`.
    task_id: the task index, stored as `task.index`.
    cluster_spec: the cluster spec, stored as `cluster`.
    rpc_layer: optional RPC layer; the `rpc_layer` key is only written when
      this is not `None`.
  """
  task = {'type': task_type, 'index': task_id}
  config = {'cluster': cluster_spec, 'task': task}
  if rpc_layer is not None:
    config.update(rpc_layer=rpc_layer)
  os.environ['TF_CONFIG'] = json.dumps(config)
1153
1154
@tf_export('__internal__.distribute.multi_process_runner.run', v1=[])
def run(fn,
        cluster_spec,
        rpc_layer=None,
        max_run_time=None,
        return_output=False,
        timeout=_DEFAULT_TIMEOUT_SEC,
        args=None,
        kwargs=None):
  """Run `fn` in multiple processes according to `cluster_spec`.

  Given a callable `fn`, `tf.__internal__.distribute.multi_process_runner.run`
  launches multiple processes, each of which runs `fn`. These processes are
  referred to as "subprocesses" or "child processes". Each of those subprocesses
  will have their `TF_CONFIG` environment variable set, according to
  `cluster_spec` and their task types. The stdout of the subprocesses is
  streamed to the main process and is thus available in logs, with a [type-id]
  prefix. `run` does not expose a parameter to control this streaming, so the
  defaults of `MultiProcessRunner` apply.

  `tf.__internal__.distribute.multi_process_runner.run` will block until all
  subprocesses have successfully exited, and return a namedtuple object that
  represents the run result. This object has a `return_value` attribute, which
  is a list that contains subprocesses `fn`'s return values, for those
  subprocesses that successfully returned from `fn`. The order of `return_value`
  list is not meaningful. The namedtuple also has a `stdout` attribute, which
  holds a list with the stdout of the subprocesses when the optional arg
  `return_output` (default to False) is set to True. If any subprocess' `fn`
  ends up raising an error, that error will be reraised from
  `tf.__internal__.distribute.multi_process_runner.run`, and the aforementioned
  namedtuple object will be available through the exception's
  `mpr_result` attribute.

  This utility is used for simulating running TensorFlow programs across
  multiple task types, and each of the task type may contain more than one task
  (except for "chief" where more than one task is prohibited). Test coverage of
  multi-worker training is the main application of this utility, where code
  written for multi-worker training can be realistically covered in unit tests.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call
  `tf.__internal__.distribute.multi_process_runner.test_main()` instead of
  regular `test.main()` inside `if __name__ == '__main__':` block for proper
  initialization.

  Args:
    fn: Function to be run on child processes. This will be run on processes for
      all task types.
    cluster_spec: Dict for cluster spec. The utility function
      `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` can
      be conveniently used to create such dict. The following is an example of
      cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use, e.g. 'grpc'. It is forwarded to
      `MultiProcessRunner`. If `None` (the default), no RPC layer is requested
      explicitly and TensorFlow's default (gRPC) is expected to apply.
    max_run_time: `None` or integer. If not `None`, child processes are forced
      to exit at approximately this many seconds after this utility is called.
      We achieve this through `signal.alarm()` api. Note that this is best
      effort at Python level since Python signal handler does not get executed
      when it runs lower level C/C++ code. So it can be delayed for arbitrarily
      long time. If any of the child process is still running when
      `max_run_time` is up, they will be force-terminated and a
      `tf.__internal__.distribute.multi_process_runner.UnexpectedSubprocessExitError`
      may be raised. If `None`, child processes are not forced to exit.
    return_output: If True, the output/error from the subprocesses should be
      collected to be attached to the resulting namedtuple returned from this
      utility. The list of output can be retrieved via `stdout` attribute.
      Defaults to False.
    timeout: optional integer or `None`. If provided as an integer, and not all
      processes report status within roughly `timeout` seconds, a
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`
      exception will be raised. If `None`,
      `tf.__internal__.distribute.multi_process_runner.run` never times out.
      Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in
      `multi_process_runner` module.
    args: Positional arguments to be sent to `fn` run on subprocesses.
    kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

  Returns:
      A namedtuple object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains a list of
      return values from the subprocesses, although the order is not
      meaningful. If `return_output` argument is True, `stdout` is available
      that contains a list of all messages from subprocesses' stdout and
      stderr, and the order is mostly chronological.

  Raises:
    RuntimeError: if
      `tf.__internal__.distribute.multi_process_runner.test_main()` is
      not called in test's `if __name__ == '__main__':` block.
    ValueError: if there are more than one chief in the `cluster_spec`.
    tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if
      not all processes report status approximately
      within `timeout` seconds. When this is raised, a
      namedtuple object can be retrieved by the exception's
      `mpr_result` attribute, which has the same
      structure as above 'Returns' section describes.
    tf.__internal__.distribute.multi_process_runner.UnexpectedSubprocessExitError:
      If any of the subprocesses did not exit
      properly (for example, they exit on SIGTERM or SIGKILL signal). When
      this is raised, a namedtuple object can be retrieved by the exception's
      `mpr_result` attribute, which has the
      same structure as above 'Returns' section describes. If `max_run_time`
      is not `None`, it is expected that some subprocesses may be
      force-killed when `max_run_time` is up, and this is raised in those
      cases.
    Exception: if there is an Exception propagated from any subprocess. When
      this is raised, a namedtuple object can be retrieved through the raised
      exception's `mpr_result` attribute, which has the
      same structure as above 'Returns' section describes.

  Examples:

  ```python
  class SimpleMultiProcessTest(tf.test.TestCase):

    def test_simple_printing_and_return(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

        # This will print "[chief-0]:     Task type: chief, task id: 0"
        # for chief, for example.
        logging.info('Task type: %s, task id: %d',
                     resolver.task_type, resolver.task_id)

        return resolver.task_type

      result = tf.__internal__.distribute.multi_process_runner.run(
          fn=fn,
          cluster_spec=(
              tf.__internal__
              .distribute.multi_process_runner.create_cluster_spec(
                  has_chief=True, num_workers=2)))
      assert sorted(result.return_value) == ['chief', 'worker', 'worker']

    def test_error_from_fn(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
        raise ValueError('Task type {}, task id {} errors out'.format(
            resolver.task_type, resolver.task_id))

      with self.assertRaisesRegex(ValueError,
                                  'Task type worker, task id 0 errors out'):
        cluster_spec = (
            tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
                num_workers=1))
        tf.__internal__.distribute.multi_process_runner.run(
            fn=fn, cluster_spec=cluster_spec)


  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  runner = MultiProcessRunner(
      fn,
      cluster_spec,
      rpc_layer,
      max_run_time=max_run_time,
      return_output=return_output,
      args=args,
      kwargs=kwargs)
  runner.start()
  return runner.join(timeout)
1330
1331
# The barrier shared by the subprocesses of a run. It is assigned by
# `_ProcFunc.__call__` inside each subprocess and stays `None` in the main
# process. Read it through `get_barrier()`.
_barrier = None
1334
1335
@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[])
def get_barrier():
  """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`.

  `tf.__internal__.distribute.multi_process_runner.get_barrier()` returns
  a `multiprocessing.Barrier` object that can be used within the `fn` passed
  to `tf.__internal__.distribute.multi_process_runner.run`. A task calling
  `barrier.wait()` waits until all other tasks have also reached their
  `barrier.wait()` call, and then all tasks proceed individually.

  Note that all tasks (subprocesses) have to reach `barrier.wait()` call to
  proceed. Currently it is not supported to block on only a subset of tasks
  in the cluster.

  Example:
  ```python

  def fn():
    some_work_to_be_done_by_all_tasks()

    tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

    # The barrier guarantees that at this point, all tasks have finished
    # `some_work_to_be_done_by_all_tasks()`
    some_other_work_to_be_done_by_all_tasks()

  result = tf.__internal__.distribute.multi_process_runner.run(
      fn=fn,
      cluster_spec=(
          tf.__internal__
          .distribute.multi_process_runner.create_cluster_spec(
              num_workers=2)))
  ```

  Returns:
    A `multiprocessing.Barrier` for `multi_process_runner.run`.

  Raises:
    ValueError: if no barrier has been set, e.g. when called in the main
      process. The barrier is only available in the subprocesses.
  """
  if _barrier is None:
    raise ValueError(
        'barrier is not defined. It is likely because you are calling '
        'get_barrier() in the main process. get_barrier() can only be called '
        'in the subprocesses.'
    )
  return _barrier
1381
1382
# Guards the lazy creation of `_manager` in `manager()`.
_manager_lock = threading.Lock()
# Multiprocessing manager shared by callers of `manager()`; created on first use.
_manager = None
1385
1386
def manager():
  """Returns the multiprocessing manager object for concurrency tools.

  The manager object is useful as it controls a server process that holds
  the python objects that can be shared across processes. This can be used
  for parent-subprocess communication:

  ```python
  manager = multi_process_runner.manager()
  some_event_happening_in_subprocess = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(fn, cluster_spec,
      args=(some_event_happening_in_subprocess,))
  mpr.start()
  some_event_happening_in_subprocess.wait()
  # Do something that should only happen after the event in the subprocess.
  ```

  Note that the user of multi_process_runner should not create additional
  `multiprocessing.Manager()` objects; doing so can result in segfault in
  some cases.

  This method should only be called after multi_process_runner.test_main() is
  called.

  Returns:
    The process-wide `multiprocessing.Manager()`. It is created on the first
    call (under a lock) and the same object is returned afterwards.

  Raises:
    NotInitializedError: if `multi_process_runner.test_main()` was not called.
  """
  _check_initialization()
  global _manager
  with _manager_lock:
    if _manager is None:
      _manager = multiprocessing.Manager()
    return _manager
1417
1418
@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[])
def test_main():
  """Main function to be called within `__main__` of a test file.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call this
  instead of the regular `test.main()` inside its `if __name__ == '__main__':`
  block. Otherwise, an error is raised when
  `tf.__internal__.distribute.multi_process_runner.run()` is used. This method
  takes care of the initialization needed for launching multiple subprocesses.

  It also installs a `tearDownModule` on the `__main__` module. That function
  shuts down every active `MultiProcessPoolRunner` and then calls the
  previously defined `tearDownModule`, if there was one.

  Example:
  ```python
  class MyTestClass(tf.test.TestCase):
    def testSomething(self):
      # Testing code making use of
      # `tf.__internal__.distribute.multi_process_runner.run()`.

  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  main_module = sys.modules['__main__']
  # Active pool runners would block the program from exiting, which matters for
  # global pool runners. `atexit` was tried in the past but does not work in
  # some deployments. The shutdown therefore hooks into `tearDownModule` and
  # chains to any `tearDownModule` the test module already defines.
  previous_tear_down = getattr(main_module, 'tearDownModule', None)

  def tear_down_module():
    _shutdown_all_pool_runners()
    if previous_tear_down is not None:
      previous_tear_down()

  main_module.tearDownModule = tear_down_module
  multi_process_lib.test_main()
1456